diff options
Diffstat (limited to 'net')
517 files changed, 19336 insertions, 10626 deletions
diff --git a/net/6lowpan/debugfs.c b/net/6lowpan/debugfs.c index 6c152f9ea26e..536aae52eead 100644 --- a/net/6lowpan/debugfs.c +++ b/net/6lowpan/debugfs.c @@ -41,9 +41,9 @@ static int lowpan_ctx_flag_active_get(void *data, u64 *val) return 0; } -DEFINE_SIMPLE_ATTRIBUTE(lowpan_ctx_flag_active_fops, - lowpan_ctx_flag_active_get, - lowpan_ctx_flag_active_set, "%llu\n"); +DEFINE_DEBUGFS_ATTRIBUTE(lowpan_ctx_flag_active_fops, + lowpan_ctx_flag_active_get, + lowpan_ctx_flag_active_set, "%llu\n"); static int lowpan_ctx_flag_c_set(void *data, u64 val) { @@ -66,8 +66,8 @@ static int lowpan_ctx_flag_c_get(void *data, u64 *val) return 0; } -DEFINE_SIMPLE_ATTRIBUTE(lowpan_ctx_flag_c_fops, lowpan_ctx_flag_c_get, - lowpan_ctx_flag_c_set, "%llu\n"); +DEFINE_DEBUGFS_ATTRIBUTE(lowpan_ctx_flag_c_fops, lowpan_ctx_flag_c_get, + lowpan_ctx_flag_c_set, "%llu\n"); static int lowpan_ctx_plen_set(void *data, u64 val) { @@ -97,8 +97,8 @@ static int lowpan_ctx_plen_get(void *data, u64 *val) return 0; } -DEFINE_SIMPLE_ATTRIBUTE(lowpan_ctx_plen_fops, lowpan_ctx_plen_get, - lowpan_ctx_plen_set, "%llu\n"); +DEFINE_DEBUGFS_ATTRIBUTE(lowpan_ctx_plen_fops, lowpan_ctx_plen_get, + lowpan_ctx_plen_set, "%llu\n"); static int lowpan_ctx_pfx_show(struct seq_file *file, void *offset) { @@ -184,15 +184,15 @@ static int lowpan_dev_debugfs_ctx_init(struct net_device *dev, if (!root) return -EINVAL; - dentry = debugfs_create_file("active", 0644, root, - &ldev->ctx.table[id], - &lowpan_ctx_flag_active_fops); + dentry = debugfs_create_file_unsafe("active", 0644, root, + &ldev->ctx.table[id], + &lowpan_ctx_flag_active_fops); if (!dentry) return -EINVAL; - dentry = debugfs_create_file("compression", 0644, root, - &ldev->ctx.table[id], - &lowpan_ctx_flag_c_fops); + dentry = debugfs_create_file_unsafe("compression", 0644, root, + &ldev->ctx.table[id], + &lowpan_ctx_flag_c_fops); if (!dentry) return -EINVAL; @@ -202,9 +202,9 @@ static int lowpan_dev_debugfs_ctx_init(struct net_device *dev, if (!dentry) return -EINVAL; - dentry = debugfs_create_file("prefix_len", 0644, root, - &ldev->ctx.table[id], - &lowpan_ctx_plen_fops); + dentry = debugfs_create_file_unsafe("prefix_len", 0644, root, + &ldev->ctx.table[id], + &lowpan_ctx_plen_fops); if (!dentry) return -EINVAL; @@ -245,8 +245,8 @@ static int lowpan_short_addr_get(void *data, u64 *val) return 0; } -DEFINE_SIMPLE_ATTRIBUTE(lowpan_short_addr_fops, lowpan_short_addr_get, - NULL, "0x%04llx\n"); +DEFINE_DEBUGFS_ATTRIBUTE(lowpan_short_addr_fops, lowpan_short_addr_get, NULL, + "0x%04llx\n"); static int lowpan_dev_debugfs_802154_init(const struct net_device *dev, struct lowpan_dev *ldev) @@ -260,9 +260,9 @@ static int lowpan_dev_debugfs_802154_init(const struct net_device *dev, if (!root) return -EINVAL; - dentry = debugfs_create_file("short_addr", 0444, root, - lowpan_802154_dev(dev)->wdev->ieee802154_ptr, - &lowpan_short_addr_fops); + dentry = debugfs_create_file_unsafe("short_addr", 0444, root, + lowpan_802154_dev(dev)->wdev->ieee802154_ptr, + &lowpan_short_addr_fops); if (!dentry) return -EINVAL; diff --git a/net/8021q/vlan_dev.c b/net/8021q/vlan_dev.c index b2d9c8f27cd7..15293c2a5dd8 100644 --- a/net/8021q/vlan_dev.c +++ b/net/8021q/vlan_dev.c @@ -31,7 +31,6 @@ #include <linux/ethtool.h> #include <linux/phy.h> #include <net/arp.h> -#include <net/switchdev.h> #include "vlan.h" #include "vlanproc.h" diff --git a/net/Kconfig b/net/Kconfig index 5cb9de1aaf88..1efe1f9ee492 100644 --- a/net/Kconfig +++ b/net/Kconfig @@ -403,7 +403,7 @@ config LWTUNNEL config LWTUNNEL_BPF bool "Execute BPF program as route nexthop action" - depends on LWTUNNEL + depends on LWTUNNEL && INET default y if LWTUNNEL=y ---help--- Allows to run BPF programs as a nexthop action following a route @@ -429,21 +429,12 @@ config NET_SOCK_MSG with the help of BPF programs. config NET_DEVLINK - tristate "Network physical/parent device Netlink interface" + bool "Network physical/parent device Netlink interface" help Network physical/parent device Netlink interface provides infrastructure to support access to physical chip-wide config and monitoring. -config MAY_USE_DEVLINK - tristate - default m if NET_DEVLINK=m - default y if NET_DEVLINK=y || NET_DEVLINK=n - help - Drivers using the devlink infrastructure should have a dependency - on MAY_USE_DEVLINK to ensure they do not cause link errors when - devlink is a loadable module and the driver using it is built-in. - config PAGE_POOL bool diff --git a/net/Makefile b/net/Makefile index bdaf53925acd..449fc0b221f8 100644 --- a/net/Makefile +++ b/net/Makefile @@ -18,7 +18,7 @@ obj-$(CONFIG_NETFILTER) += netfilter/ obj-$(CONFIG_INET) += ipv4/ obj-$(CONFIG_TLS) += tls/ obj-$(CONFIG_XFRM) += xfrm/ -obj-$(CONFIG_UNIX) += unix/ +obj-$(CONFIG_UNIX_SCM) += unix/ obj-$(CONFIG_NET) += ipv6/ obj-$(CONFIG_BPFILTER) += bpfilter/ obj-$(CONFIG_PACKET) += packet/ diff --git a/net/appletalk/atalk_proc.c b/net/appletalk/atalk_proc.c index 8006295f8bd7..77f203f1febc 100644 --- a/net/appletalk/atalk_proc.c +++ b/net/appletalk/atalk_proc.c @@ -210,56 +210,34 @@ static const struct seq_operations atalk_seq_socket_ops = { .show = atalk_seq_socket_show, }; -static struct proc_dir_entry *atalk_proc_dir; - int __init atalk_proc_init(void) { - struct proc_dir_entry *p; - int rc = -ENOMEM; + if (!proc_mkdir("atalk", init_net.proc_net)) + return -ENOMEM; - atalk_proc_dir = proc_mkdir("atalk", init_net.proc_net); - if (!atalk_proc_dir) + if (!proc_create_seq("atalk/interface", 0444, init_net.proc_net, + &atalk_seq_interface_ops)) goto out; - p = proc_create_seq("interface", 0444, atalk_proc_dir, - &atalk_seq_interface_ops); - if (!p) - goto out_interface; - - p = proc_create_seq("route", 0444, atalk_proc_dir, - &atalk_seq_route_ops); - if (!p) - goto out_route; + if (!proc_create_seq("atalk/route", 0444, init_net.proc_net, + &atalk_seq_route_ops)) + goto out; - p = proc_create_seq("socket", 0444, atalk_proc_dir, - &atalk_seq_socket_ops); - if (!p) - goto out_socket; + if (!proc_create_seq("atalk/socket", 0444, init_net.proc_net, + &atalk_seq_socket_ops)) + goto out; - p = proc_create_seq_private("arp", 0444, atalk_proc_dir, &aarp_seq_ops, - sizeof(struct aarp_iter_state), NULL); - if (!p) - goto out_arp; + if (!proc_create_seq_private("atalk/arp", 0444, init_net.proc_net, + &aarp_seq_ops, + sizeof(struct aarp_iter_state), NULL)) + goto out; - rc = 0; out: - return rc; -out_arp: - remove_proc_entry("socket", atalk_proc_dir); -out_socket: - remove_proc_entry("route", atalk_proc_dir); -out_route: - remove_proc_entry("interface", atalk_proc_dir); -out_interface: - remove_proc_entry("atalk", init_net.proc_net); - goto out; + remove_proc_subtree("atalk", init_net.proc_net); + return -ENOMEM; } -void __exit atalk_proc_exit(void) +void atalk_proc_exit(void) { - remove_proc_entry("interface", atalk_proc_dir); - remove_proc_entry("route", atalk_proc_dir); - remove_proc_entry("socket", atalk_proc_dir); - remove_proc_entry("arp", atalk_proc_dir); - remove_proc_entry("atalk", init_net.proc_net); + remove_proc_subtree("atalk", init_net.proc_net); } diff --git a/net/appletalk/ddp.c b/net/appletalk/ddp.c index 9b6bc5abe946..795fbc6c06aa 100644 --- a/net/appletalk/ddp.c +++ b/net/appletalk/ddp.c @@ -1910,12 +1910,16 @@ static const char atalk_err_snap[] __initconst = /* Called by proto.c on kernel start up */ static int __init atalk_init(void) { - int rc = proto_register(&ddp_proto, 0); + int rc; - if (rc != 0) + rc = proto_register(&ddp_proto, 0); + if (rc) goto out; - (void)sock_register(&atalk_family_ops); + rc = sock_register(&atalk_family_ops); + if (rc) + goto out_proto; + ddp_dl = register_snap_client(ddp_snap_id, atalk_rcv); if (!ddp_dl) printk(atalk_err_snap); @@ -1923,12 +1927,33 @@ static int __init atalk_init(void) dev_add_pack(<alk_packet_type); dev_add_pack(&ppptalk_packet_type); - register_netdevice_notifier(&ddp_notifier); + rc = register_netdevice_notifier(&ddp_notifier); + if (rc) + goto out_sock; + aarp_proto_init(); - atalk_proc_init(); - atalk_register_sysctl(); + rc = atalk_proc_init(); + if (rc) + goto out_aarp; + + rc = atalk_register_sysctl(); + if (rc) + goto out_proc; out: return rc; +out_proc: + atalk_proc_exit(); +out_aarp: + aarp_cleanup_module(); + unregister_netdevice_notifier(&ddp_notifier); +out_sock: + dev_remove_pack(&ppptalk_packet_type); + dev_remove_pack(<alk_packet_type); + unregister_snap_client(ddp_dl); + sock_unregister(PF_APPLETALK); +out_proto: + proto_unregister(&ddp_proto); + goto out; } module_init(atalk_init); diff --git a/net/appletalk/sysctl_net_atalk.c b/net/appletalk/sysctl_net_atalk.c index c744a853fa5f..d945b7c0176d 100644 --- a/net/appletalk/sysctl_net_atalk.c +++ b/net/appletalk/sysctl_net_atalk.c @@ -45,9 +45,12 @@ static struct ctl_table atalk_table[] = { static struct ctl_table_header *atalk_table_header; -void atalk_register_sysctl(void) +int __init atalk_register_sysctl(void) { atalk_table_header = register_net_sysctl(&init_net, "net/appletalk", atalk_table); + if (!atalk_table_header) + return -ENOMEM; + return 0; } void atalk_unregister_sysctl(void) diff --git a/net/atm/proc.c b/net/atm/proc.c index 0b0495a41bbe..d79221fd4dae 100644 --- a/net/atm/proc.c +++ b/net/atm/proc.c @@ -134,7 +134,8 @@ static void vcc_seq_stop(struct seq_file *seq, void *v) static void *vcc_seq_next(struct seq_file *seq, void *v, loff_t *pos) { v = vcc_walk(seq, 1); - *pos += !!PTR_ERR(v); + if (v) + (*pos)++; return v; } diff --git a/net/atm/resources.c b/net/atm/resources.c index bada395ecdb1..889349c6d90d 100644 --- a/net/atm/resources.c +++ b/net/atm/resources.c @@ -203,13 +203,9 @@ int atm_dev_ioctl(unsigned int cmd, void __user *arg, int compat) int __user *sioc_len; int __user *iobuf_len; -#ifndef CONFIG_COMPAT - compat = 0; /* Just so the compiler _knows_ */ -#endif - switch (cmd) { case ATM_GETNAMES: - if (compat) { + if (IS_ENABLED(CONFIG_COMPAT) && compat) { #ifdef CONFIG_COMPAT struct compat_atm_iobuf __user *ciobuf = arg; compat_uptr_t cbuf; @@ -253,7 +249,7 @@ int atm_dev_ioctl(unsigned int cmd, void __user *arg, int compat) break; } - if (compat) { + if (IS_ENABLED(CONFIG_COMPAT) && compat) { #ifdef CONFIG_COMPAT struct compat_atmif_sioc __user *csioc = arg; compat_uptr_t carg; @@ -417,7 +413,7 @@ int atm_dev_ioctl(unsigned int cmd, void __user *arg, int compat) } /* fall through */ default: - if (compat) { + if (IS_ENABLED(CONFIG_COMPAT) && compat) { #ifdef CONFIG_COMPAT if (!dev->ops->compat_ioctl) { error = -EINVAL; diff --git a/net/ax25/ax25_ip.c b/net/ax25/ax25_ip.c index 70417e9b932d..314bbc8010fb 100644 --- a/net/ax25/ax25_ip.c +++ b/net/ax25/ax25_ip.c @@ -114,6 +114,7 @@ netdev_tx_t ax25_ip_xmit(struct sk_buff *skb) dst = (ax25_address *)(bp + 1); src = (ax25_address *)(bp + 8); + ax25_route_lock_use(); route = ax25_get_route(dst, NULL); if (route) { digipeat = route->digipeat; @@ -206,9 +207,8 @@ netdev_tx_t ax25_ip_xmit(struct sk_buff *skb) ax25_queue_xmit(skb, dev); put: - if (route) - ax25_put_route(route); + ax25_route_lock_unuse(); return NETDEV_TX_OK; } diff --git a/net/ax25/ax25_route.c b/net/ax25/ax25_route.c index a0eff323af12..66f74c85cf6b 100644 --- a/net/ax25/ax25_route.c +++ b/net/ax25/ax25_route.c @@ -40,7 +40,7 @@ #include <linux/export.h> static ax25_route *ax25_route_list; -static DEFINE_RWLOCK(ax25_route_lock); +DEFINE_RWLOCK(ax25_route_lock); void ax25_rt_device_down(struct net_device *dev) { @@ -335,6 +335,7 @@ const struct seq_operations ax25_rt_seqops = { * Find AX.25 route * * Only routes with a reference count of zero can be destroyed. + * Must be called with ax25_route_lock read locked. */ ax25_route *ax25_get_route(ax25_address *addr, struct net_device *dev) { @@ -342,7 +343,6 @@ ax25_route *ax25_get_route(ax25_address *addr, struct net_device *dev) ax25_route *ax25_def_rt = NULL; ax25_route *ax25_rt; - read_lock(&ax25_route_lock); /* * Bind to the physical interface we heard them on, or the default * route if none is found; @@ -365,11 +365,6 @@ ax25_route *ax25_get_route(ax25_address *addr, struct net_device *dev) if (ax25_spe_rt != NULL) ax25_rt = ax25_spe_rt; - if (ax25_rt != NULL) - ax25_hold_route(ax25_rt); - - read_unlock(&ax25_route_lock); - return ax25_rt; } @@ -400,9 +395,12 @@ int ax25_rt_autobind(ax25_cb *ax25, ax25_address *addr) ax25_route *ax25_rt; int err = 0; - if ((ax25_rt = ax25_get_route(addr, NULL)) == NULL) + ax25_route_lock_use(); + ax25_rt = ax25_get_route(addr, NULL); + if (!ax25_rt) { + ax25_route_lock_unuse(); return -EHOSTUNREACH; - + } if ((ax25->ax25_dev = ax25_dev_ax25dev(ax25_rt->dev)) == NULL) { err = -EHOSTUNREACH; goto put; @@ -437,8 +435,7 @@ int ax25_rt_autobind(ax25_cb *ax25, ax25_address *addr) } put: - ax25_put_route(ax25_rt); - + ax25_route_lock_unuse(); return err; } diff --git a/net/batman-adv/Kconfig b/net/batman-adv/Kconfig index c386e6981416..a31db5e9ac8e 100644 --- a/net/batman-adv/Kconfig +++ b/net/batman-adv/Kconfig @@ -1,5 +1,5 @@ # SPDX-License-Identifier: GPL-2.0 -# Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +# Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: # # Marek Lindner, Simon Wunderlich # diff --git a/net/batman-adv/Makefile b/net/batman-adv/Makefile index 9b58160fe485..a887ecc3efa1 100644 --- a/net/batman-adv/Makefile +++ b/net/batman-adv/Makefile @@ -1,5 +1,5 @@ # SPDX-License-Identifier: GPL-2.0 -# Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +# Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: # # Marek Lindner, Simon Wunderlich # diff --git a/net/batman-adv/bat_algo.c b/net/batman-adv/bat_algo.c index ea309ad06175..7b7e15641fef 100644 --- a/net/batman-adv/bat_algo.c +++ b/net/batman-adv/bat_algo.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/batman-adv/bat_algo.h b/net/batman-adv/bat_algo.h index 534b790c3753..25e7bb51928c 100644 --- a/net/batman-adv/bat_algo.h +++ b/net/batman-adv/bat_algo.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2011-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2011-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Linus Lüssing * diff --git a/net/batman-adv/bat_iv_ogm.c b/net/batman-adv/bat_iv_ogm.c index f97e566f0402..de61091af666 100644 --- a/net/batman-adv/bat_iv_ogm.c +++ b/net/batman-adv/bat_iv_ogm.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/batman-adv/bat_iv_ogm.h b/net/batman-adv/bat_iv_ogm.h index 3dc6a7a43eb7..785f6666273c 100644 --- a/net/batman-adv/bat_iv_ogm.h +++ b/net/batman-adv/bat_iv_ogm.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/batman-adv/bat_v.c b/net/batman-adv/bat_v.c index 90e33f84d37a..445594ed58af 100644 --- a/net/batman-adv/bat_v.c +++ b/net/batman-adv/bat_v.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2013-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2013-2019 B.A.T.M.A.N. contributors: * * Linus Lüssing, Marek Lindner * diff --git a/net/batman-adv/bat_v.h b/net/batman-adv/bat_v.h index ec4a2a569750..465a4fc23354 100644 --- a/net/batman-adv/bat_v.h +++ b/net/batman-adv/bat_v.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2011-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2011-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Linus Lüssing * diff --git a/net/batman-adv/bat_v_elp.c b/net/batman-adv/bat_v_elp.c index e8090f099eb8..a9b7919c9de5 100644 --- a/net/batman-adv/bat_v_elp.c +++ b/net/batman-adv/bat_v_elp.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2011-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2011-2019 B.A.T.M.A.N. contributors: * * Linus Lüssing, Marek Lindner * @@ -104,6 +104,9 @@ static u32 batadv_v_elp_get_throughput(struct batadv_hardif_neigh_node *neigh) ret = cfg80211_get_station(real_netdev, neigh->addr, &sinfo); + /* free the TID stats immediately */ + cfg80211_sinfo_release_content(&sinfo); + dev_put(real_netdev); if (ret == -ENOENT) { /* Node is not associated anymore! It would be diff --git a/net/batman-adv/bat_v_elp.h b/net/batman-adv/bat_v_elp.h index e8c7b7fd290d..75f189ee4a1c 100644 --- a/net/batman-adv/bat_v_elp.h +++ b/net/batman-adv/bat_v_elp.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2013-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2013-2019 B.A.T.M.A.N. contributors: * * Linus Lüssing, Marek Lindner * diff --git a/net/batman-adv/bat_v_ogm.c b/net/batman-adv/bat_v_ogm.c index 2948b41b06d4..c9698ad41854 100644 --- a/net/batman-adv/bat_v_ogm.c +++ b/net/batman-adv/bat_v_ogm.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2013-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2013-2019 B.A.T.M.A.N. contributors: * * Antonio Quartulli * diff --git a/net/batman-adv/bat_v_ogm.h b/net/batman-adv/bat_v_ogm.h index e5be14c908c6..f67cf7ee06b2 100644 --- a/net/batman-adv/bat_v_ogm.h +++ b/net/batman-adv/bat_v_ogm.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2013-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2013-2019 B.A.T.M.A.N. contributors: * * Antonio Quartulli * diff --git a/net/batman-adv/bitarray.c b/net/batman-adv/bitarray.c index a296a4d851f5..63e134e763e3 100644 --- a/net/batman-adv/bitarray.c +++ b/net/batman-adv/bitarray.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2006-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2006-2019 B.A.T.M.A.N. contributors: * * Simon Wunderlich, Marek Lindner * diff --git a/net/batman-adv/bitarray.h b/net/batman-adv/bitarray.h index 48f683289531..f3a05ad9afad 100644 --- a/net/batman-adv/bitarray.h +++ b/net/batman-adv/bitarray.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2006-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2006-2019 B.A.T.M.A.N. contributors: * * Simon Wunderlich, Marek Lindner * diff --git a/net/batman-adv/bridge_loop_avoidance.c b/net/batman-adv/bridge_loop_avoidance.c index 5fdde2947802..ef39aabdb694 100644 --- a/net/batman-adv/bridge_loop_avoidance.c +++ b/net/batman-adv/bridge_loop_avoidance.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2011-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2011-2019 B.A.T.M.A.N. contributors: * * Simon Wunderlich * diff --git a/net/batman-adv/bridge_loop_avoidance.h b/net/batman-adv/bridge_loop_avoidance.h index 71f95a3e4d3f..31771c751efb 100644 --- a/net/batman-adv/bridge_loop_avoidance.h +++ b/net/batman-adv/bridge_loop_avoidance.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2011-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2011-2019 B.A.T.M.A.N. contributors: * * Simon Wunderlich * diff --git a/net/batman-adv/debugfs.c b/net/batman-adv/debugfs.c index d4a7702e48d8..3b9d1ad2f467 100644 --- a/net/batman-adv/debugfs.c +++ b/net/batman-adv/debugfs.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2010-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2010-2019 B.A.T.M.A.N. contributors: * * Marek Lindner * diff --git a/net/batman-adv/debugfs.h b/net/batman-adv/debugfs.h index 8de018e5c577..c0b8694041ec 100644 --- a/net/batman-adv/debugfs.h +++ b/net/batman-adv/debugfs.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2010-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2010-2019 B.A.T.M.A.N. contributors: * * Marek Lindner * diff --git a/net/batman-adv/distributed-arp-table.c b/net/batman-adv/distributed-arp-table.c index b9ffe1826527..310a4f353008 100644 --- a/net/batman-adv/distributed-arp-table.c +++ b/net/batman-adv/distributed-arp-table.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2011-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2011-2019 B.A.T.M.A.N. contributors: * * Antonio Quartulli * @@ -19,6 +19,7 @@ #include "distributed-arp-table.h" #include "main.h" +#include <asm/unaligned.h> #include <linux/atomic.h> #include <linux/bitops.h> #include <linux/byteorder/generic.h> @@ -29,6 +30,7 @@ #include <linux/if_ether.h> #include <linux/if_vlan.h> #include <linux/in.h> +#include <linux/ip.h> #include <linux/jiffies.h> #include <linux/kernel.h> #include <linux/kref.h> @@ -42,6 +44,7 @@ #include <linux/spinlock.h> #include <linux/stddef.h> #include <linux/string.h> +#include <linux/udp.h> #include <linux/workqueue.h> #include <net/arp.h> #include <net/genetlink.h> @@ -60,6 +63,49 @@ #include "translation-table.h" #include "tvlv.h" +enum batadv_bootpop { + BATADV_BOOTREPLY = 2, +}; + +enum batadv_boothtype { + BATADV_HTYPE_ETHERNET = 1, +}; + +enum batadv_dhcpoptioncode { + BATADV_DHCP_OPT_PAD = 0, + BATADV_DHCP_OPT_MSG_TYPE = 53, + BATADV_DHCP_OPT_END = 255, +}; + +enum batadv_dhcptype { + BATADV_DHCPACK = 5, +}; + +/* { 99, 130, 83, 99 } */ +#define BATADV_DHCP_MAGIC 1669485411 + +struct batadv_dhcp_packet { + __u8 op; + __u8 htype; + __u8 hlen; + __u8 hops; + __be32 xid; + __be16 secs; + __be16 flags; + __be32 ciaddr; + __be32 yiaddr; + __be32 siaddr; + __be32 giaddr; + __u8 chaddr[16]; + __u8 sname[64]; + __u8 file[128]; + __be32 magic; + __u8 options[0]; +}; + +#define BATADV_DHCP_YIADDR_LEN sizeof(((struct batadv_dhcp_packet *)0)->yiaddr) +#define BATADV_DHCP_CHADDR_LEN sizeof(((struct batadv_dhcp_packet *)0)->chaddr) + static void batadv_dat_purge(struct work_struct *work); /** @@ -1440,6 +1486,361 @@ out: } /** + * batadv_dat_check_dhcp_ipudp() - check skb for IP+UDP headers valid for DHCP + * @skb: the packet to check + * @ip_src: a buffer to store the IPv4 source address in + * + * Checks whether the given skb has an IP and UDP header valid for a DHCP + * message from a DHCP server. And if so, stores the IPv4 source address in + * the provided buffer. + * + * Return: True if valid, false otherwise. + */ +static bool +batadv_dat_check_dhcp_ipudp(struct sk_buff *skb, __be32 *ip_src) +{ + unsigned int offset = skb_network_offset(skb); + struct udphdr *udphdr, _udphdr; + struct iphdr *iphdr, _iphdr; + + iphdr = skb_header_pointer(skb, offset, sizeof(_iphdr), &_iphdr); + if (!iphdr || iphdr->version != 4 || iphdr->ihl * 4 < sizeof(_iphdr)) + return false; + + if (iphdr->protocol != IPPROTO_UDP) + return false; + + offset += iphdr->ihl * 4; + skb_set_transport_header(skb, offset); + + udphdr = skb_header_pointer(skb, offset, sizeof(_udphdr), &_udphdr); + if (!udphdr || udphdr->source != htons(67)) + return false; + + *ip_src = get_unaligned(&iphdr->saddr); + + return true; +} + +/** + * batadv_dat_check_dhcp() - examine packet for valid DHCP message + * @skb: the packet to check + * @proto: ethernet protocol hint (behind a potential vlan) + * @ip_src: a buffer to store the IPv4 source address in + * + * Checks whether the given skb is a valid DHCP packet. And if so, stores the + * IPv4 source address in the provided buffer. + * + * Caller needs to ensure that the skb network header is set correctly. + * + * Return: If skb is a valid DHCP packet, then returns its op code + * (e.g. BOOTREPLY vs. BOOTREQUEST). Otherwise returns -EINVAL. + */ +static int +batadv_dat_check_dhcp(struct sk_buff *skb, __be16 proto, __be32 *ip_src) +{ + __be32 *magic, _magic; + unsigned int offset; + struct { + __u8 op; + __u8 htype; + __u8 hlen; + __u8 hops; + } *dhcp_h, _dhcp_h; + + if (proto != htons(ETH_P_IP)) + return -EINVAL; + + if (!batadv_dat_check_dhcp_ipudp(skb, ip_src)) + return -EINVAL; + + offset = skb_transport_offset(skb) + sizeof(struct udphdr); + if (skb->len < offset + sizeof(struct batadv_dhcp_packet)) + return -EINVAL; + + dhcp_h = skb_header_pointer(skb, offset, sizeof(_dhcp_h), &_dhcp_h); + if (!dhcp_h || dhcp_h->htype != BATADV_HTYPE_ETHERNET || + dhcp_h->hlen != ETH_ALEN) + return -EINVAL; + + offset += offsetof(struct batadv_dhcp_packet, magic); + + magic = skb_header_pointer(skb, offset, sizeof(_magic), &_magic); + if (!magic || get_unaligned(magic) != htonl(BATADV_DHCP_MAGIC)) + return -EINVAL; + + return dhcp_h->op; +} + +/** + * batadv_dat_get_dhcp_message_type() - get message type of a DHCP packet + * @skb: the DHCP packet to parse + * + * Iterates over the DHCP options of the given DHCP packet to find a + * DHCP Message Type option and parse it. + * + * Caller needs to ensure that the given skb is a valid DHCP packet and + * that the skb transport header is set correctly. + * + * Return: The found DHCP message type value, if found. -EINVAL otherwise. + */ +static int batadv_dat_get_dhcp_message_type(struct sk_buff *skb) +{ + unsigned int offset = skb_transport_offset(skb) + sizeof(struct udphdr); + u8 *type, _type; + struct { + u8 type; + u8 len; + } *tl, _tl; + + offset += sizeof(struct batadv_dhcp_packet); + + while ((tl = skb_header_pointer(skb, offset, sizeof(_tl), &_tl))) { + if (tl->type == BATADV_DHCP_OPT_MSG_TYPE) + break; + + if (tl->type == BATADV_DHCP_OPT_END) + break; + + if (tl->type == BATADV_DHCP_OPT_PAD) + offset++; + else + offset += tl->len + sizeof(_tl); + } + + /* Option Overload Code not supported */ + if (!tl || tl->type != BATADV_DHCP_OPT_MSG_TYPE || + tl->len != sizeof(_type)) + return -EINVAL; + + offset += sizeof(_tl); + + type = skb_header_pointer(skb, offset, sizeof(_type), &_type); + if (!type) + return -EINVAL; + + return *type; +} + +/** + * batadv_dat_get_dhcp_yiaddr() - get yiaddr from a DHCP packet + * @skb: the DHCP packet to parse + * @buf: a buffer to store the yiaddr in + * + * Caller needs to ensure that the given skb is a valid DHCP packet and + * that the skb transport header is set correctly. + * + * Return: True on success, false otherwise. + */ +static bool batadv_dat_dhcp_get_yiaddr(struct sk_buff *skb, __be32 *buf) +{ + unsigned int offset = skb_transport_offset(skb) + sizeof(struct udphdr); + __be32 *yiaddr; + + offset += offsetof(struct batadv_dhcp_packet, yiaddr); + yiaddr = skb_header_pointer(skb, offset, BATADV_DHCP_YIADDR_LEN, buf); + + if (!yiaddr) + return false; + + if (yiaddr != buf) + *buf = get_unaligned(yiaddr); + + return true; +} + +/** + * batadv_dat_get_dhcp_chaddr() - get chaddr from a DHCP packet + * @skb: the DHCP packet to parse + * @buf: a buffer to store the chaddr in + * + * Caller needs to ensure that the given skb is a valid DHCP packet and + * that the skb transport header is set correctly. + * + * Return: True on success, false otherwise + */ +static bool batadv_dat_get_dhcp_chaddr(struct sk_buff *skb, u8 *buf) +{ + unsigned int offset = skb_transport_offset(skb) + sizeof(struct udphdr); + u8 *chaddr; + + offset += offsetof(struct batadv_dhcp_packet, chaddr); + chaddr = skb_header_pointer(skb, offset, BATADV_DHCP_CHADDR_LEN, buf); + + if (!chaddr) + return false; + + if (chaddr != buf) + memcpy(buf, chaddr, BATADV_DHCP_CHADDR_LEN); + + return true; +} + +/** + * batadv_dat_put_dhcp() - puts addresses from a DHCP packet into the DHT and + * DAT cache + * @bat_priv: the bat priv with all the soft interface information + * @chaddr: the DHCP client MAC address + * @yiaddr: the DHCP client IP address + * @hw_dst: the DHCP server MAC address + * @ip_dst: the DHCP server IP address + * @vid: VLAN identifier + * + * Adds given MAC/IP pairs to the local DAT cache and propagates them further + * into the DHT. + * + * For the DHT propagation, client MAC + IP will appear as the ARP Reply + * transmitter (and hw_dst/ip_dst as the target). + */ +static void batadv_dat_put_dhcp(struct batadv_priv *bat_priv, u8 *chaddr, + __be32 yiaddr, u8 *hw_dst, __be32 ip_dst, + unsigned short vid) +{ + struct sk_buff *skb; + + skb = batadv_dat_arp_create_reply(bat_priv, yiaddr, ip_dst, chaddr, + hw_dst, vid); + if (!skb) + return; + + skb_set_network_header(skb, ETH_HLEN); + + batadv_dat_entry_add(bat_priv, yiaddr, chaddr, vid); + batadv_dat_entry_add(bat_priv, ip_dst, hw_dst, vid); + + batadv_dat_send_data(bat_priv, skb, yiaddr, vid, BATADV_P_DAT_DHT_PUT); + batadv_dat_send_data(bat_priv, skb, ip_dst, vid, BATADV_P_DAT_DHT_PUT); + + consume_skb(skb); + + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "Snooped from outgoing DHCPACK (server address): %pI4, %pM (vid: %i)\n", + &ip_dst, hw_dst, batadv_print_vid(vid)); + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "Snooped from outgoing DHCPACK (client address): %pI4, %pM (vid: %i)\n", + &yiaddr, chaddr, batadv_print_vid(vid)); +} + +/** + * batadv_dat_check_dhcp_ack() - examine packet for valid DHCP message + * @skb: the packet to check + * @proto: ethernet protocol hint (behind a potential vlan) + * @ip_src: a buffer to store the IPv4 source address in + * @chaddr: a buffer to store the DHCP Client Hardware Address in + * @yiaddr: a buffer to store the DHCP Your IP Address in + * + * Checks whether the given skb is a valid DHCPACK. And if so, stores the + * IPv4 server source address (ip_src), client MAC address (chaddr) and client + * IPv4 address (yiaddr) in the provided buffers. + * + * Caller needs to ensure that the skb network header is set correctly. + * + * Return: True if the skb is a valid DHCPACK. False otherwise. + */ +static bool +batadv_dat_check_dhcp_ack(struct sk_buff *skb, __be16 proto, __be32 *ip_src, + u8 *chaddr, __be32 *yiaddr) +{ + int type; + + type = batadv_dat_check_dhcp(skb, proto, ip_src); + if (type != BATADV_BOOTREPLY) + return false; + + type = batadv_dat_get_dhcp_message_type(skb); + if (type != BATADV_DHCPACK) + return false; + + if (!batadv_dat_dhcp_get_yiaddr(skb, yiaddr)) + return false; + + if (!batadv_dat_get_dhcp_chaddr(skb, chaddr)) + return false; + + return true; +} + +/** + * batadv_dat_snoop_outgoing_dhcp_ack() - snoop DHCPACK and fill DAT with it + * @bat_priv: the bat priv with all the soft interface information + * @skb: the packet to snoop + * @proto: ethernet protocol hint (behind a potential vlan) + * @vid: VLAN identifier + * + * This function first checks whether the given skb is a valid DHCPACK. If + * so then its source MAC and IP as well as its DHCP Client Hardware Address + * field and DHCP Your IP Address field are added to the local DAT cache and + * propagated into the DHT. + * + * Caller needs to ensure that the skb mac and network headers are set + * correctly. + */ +void batadv_dat_snoop_outgoing_dhcp_ack(struct batadv_priv *bat_priv, + struct sk_buff *skb, + __be16 proto, + unsigned short vid) +{ + u8 chaddr[BATADV_DHCP_CHADDR_LEN]; + __be32 ip_src, yiaddr; + + if (!atomic_read(&bat_priv->distributed_arp_table)) + return; + + if (!batadv_dat_check_dhcp_ack(skb, proto, &ip_src, chaddr, &yiaddr)) + return; + + batadv_dat_put_dhcp(bat_priv, chaddr, yiaddr, eth_hdr(skb)->h_source, + ip_src, vid); +} + +/** + * batadv_dat_snoop_incoming_dhcp_ack() - snoop DHCPACK and fill DAT cache + * @bat_priv: the bat priv with all the soft interface information + * @skb: the packet to snoop + * @hdr_size: header size, up to the tail of the batman-adv header + * + * This function first checks whether the given skb is a valid DHCPACK. If + * so then its source MAC and IP as well as its DHCP Client Hardware Address + * field and DHCP Your IP Address field are added to the local DAT cache. + */ +void batadv_dat_snoop_incoming_dhcp_ack(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size) +{ + u8 chaddr[BATADV_DHCP_CHADDR_LEN]; + struct ethhdr *ethhdr; + __be32 ip_src, yiaddr; + unsigned short vid; + __be16 proto; + u8 *hw_src; + + if (!atomic_read(&bat_priv->distributed_arp_table)) + return; + + if (unlikely(!pskb_may_pull(skb, hdr_size + ETH_HLEN))) + return; + + ethhdr = (struct ethhdr *)(skb->data + hdr_size); + skb_set_network_header(skb, hdr_size + ETH_HLEN); + proto = ethhdr->h_proto; + + if (!batadv_dat_check_dhcp_ack(skb, proto, &ip_src, chaddr, &yiaddr)) + return; + + hw_src = ethhdr->h_source; + vid = batadv_dat_get_vid(skb, &hdr_size); + + batadv_dat_entry_add(bat_priv, yiaddr, chaddr, vid); + batadv_dat_entry_add(bat_priv, ip_src, hw_src, vid); + + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "Snooped from incoming DHCPACK (server address): %pI4, %pM (vid: %i)\n", + &ip_src, hw_src, batadv_print_vid(vid)); + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "Snooped from incoming DHCPACK (client address): %pI4, %pM (vid: %i)\n", + &yiaddr, chaddr, batadv_print_vid(vid)); +} + +/** * batadv_dat_drop_broadcast_packet() - check if an ARP request has to be * dropped (because the node has already obtained the reply via DAT) or not * @bat_priv: the bat priv with all the soft interface information diff --git a/net/batman-adv/distributed-arp-table.h b/net/batman-adv/distributed-arp-table.h index a04596028337..68c0ff321acd 100644 --- a/net/batman-adv/distributed-arp-table.h +++ b/net/batman-adv/distributed-arp-table.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2011-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2011-2019 B.A.T.M.A.N. contributors: * * Antonio Quartulli * @@ -46,6 +46,12 @@ void batadv_dat_snoop_outgoing_arp_reply(struct batadv_priv *bat_priv, struct sk_buff *skb); bool batadv_dat_snoop_incoming_arp_reply(struct batadv_priv *bat_priv, struct sk_buff *skb, int hdr_size); +void batadv_dat_snoop_outgoing_dhcp_ack(struct batadv_priv *bat_priv, + struct sk_buff *skb, + __be16 proto, + unsigned short vid); +void batadv_dat_snoop_incoming_dhcp_ack(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size); bool batadv_dat_drop_broadcast_packet(struct batadv_priv *bat_priv, struct batadv_forw_packet *forw_packet); @@ -140,6 +146,19 @@ batadv_dat_snoop_incoming_arp_reply(struct batadv_priv *bat_priv, return false; } +static inline void +batadv_dat_snoop_outgoing_dhcp_ack(struct batadv_priv *bat_priv, + struct sk_buff *skb, __be16 proto, + unsigned short vid) +{ +} + +static inline void +batadv_dat_snoop_incoming_dhcp_ack(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size) +{ +} + static inline bool batadv_dat_drop_broadcast_packet(struct batadv_priv *bat_priv, struct batadv_forw_packet *forw_packet) diff --git a/net/batman-adv/fragmentation.c b/net/batman-adv/fragmentation.c index 5b71a289d04f..b506d15b8230 100644 --- a/net/batman-adv/fragmentation.c +++ b/net/batman-adv/fragmentation.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2013-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2013-2019 B.A.T.M.A.N. contributors: * * Martin Hundebøll <martin@hundeboll.net> * diff --git a/net/batman-adv/fragmentation.h b/net/batman-adv/fragmentation.h index 944512e07782..abdac26579bf 100644 --- a/net/batman-adv/fragmentation.h +++ b/net/batman-adv/fragmentation.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2013-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2013-2019 B.A.T.M.A.N. contributors: * * Martin Hundebøll <martin@hundeboll.net> * diff --git a/net/batman-adv/gateway_client.c b/net/batman-adv/gateway_client.c index 9d8e5eda2314..f5811f61aa92 100644 --- a/net/batman-adv/gateway_client.c +++ b/net/batman-adv/gateway_client.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2009-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2009-2019 B.A.T.M.A.N. contributors: * * Marek Lindner * @@ -47,7 +47,6 @@ #include <uapi/linux/batadv_packet.h> #include <uapi/linux/batman_adv.h> -#include "gateway_common.h" #include "hard-interface.h" #include "log.h" #include "netlink.h" diff --git a/net/batman-adv/gateway_client.h b/net/batman-adv/gateway_client.h index f0b86fcb2493..b5732c8be81a 100644 --- a/net/batman-adv/gateway_client.h +++ b/net/batman-adv/gateway_client.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2009-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2009-2019 B.A.T.M.A.N. contributors: * * Marek Lindner * diff --git a/net/batman-adv/gateway_common.c b/net/batman-adv/gateway_common.c index 936c107f3199..e064de45e22c 100644 --- a/net/batman-adv/gateway_common.c +++ b/net/batman-adv/gateway_common.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2009-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2009-2019 B.A.T.M.A.N. contributors: * * Marek Lindner * @@ -28,6 +28,7 @@ #include <linux/stddef.h> #include <linux/string.h> #include <uapi/linux/batadv_packet.h> +#include <uapi/linux/batman_adv.h> #include "gateway_client.h" #include "log.h" diff --git a/net/batman-adv/gateway_common.h b/net/batman-adv/gateway_common.h index 80afb2793687..128467a0fb89 100644 --- a/net/batman-adv/gateway_common.h +++ b/net/batman-adv/gateway_common.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2009-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2009-2019 B.A.T.M.A.N. contributors: * * Marek Lindner * @@ -25,12 +25,6 @@ struct net_device; -enum batadv_gw_modes { - BATADV_GW_MODE_OFF, - BATADV_GW_MODE_CLIENT, - BATADV_GW_MODE_SERVER, -}; - /** * enum batadv_bandwidth_units - bandwidth unit types */ diff --git a/net/batman-adv/hard-interface.c b/net/batman-adv/hard-interface.c index 508f4416dfc9..96ef7c70b4d9 100644 --- a/net/batman-adv/hard-interface.c +++ b/net/batman-adv/hard-interface.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * @@ -20,7 +20,6 @@ #include "main.h" #include <linux/atomic.h> -#include <linux/bug.h> #include <linux/byteorder/generic.h> #include <linux/errno.h> #include <linux/gfp.h> @@ -179,8 +178,10 @@ static bool batadv_is_on_batman_iface(const struct net_device *net_dev) parent_dev = __dev_get_by_index((struct net *)parent_net, dev_get_iflink(net_dev)); /* if we got a NULL parent_dev there is something broken.. */ - if (WARN(!parent_dev, "Cannot find parent device")) + if (!parent_dev) { + pr_err("Cannot find parent device\n"); return false; + } if (batadv_mutual_parents(net_dev, net, parent_dev, parent_net)) return false; diff --git a/net/batman-adv/hard-interface.h b/net/batman-adv/hard-interface.h index d1c0f6189301..48de28c83401 100644 --- a/net/batman-adv/hard-interface.h +++ b/net/batman-adv/hard-interface.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/batman-adv/hash.c b/net/batman-adv/hash.c index 9194f4d891b1..56a08ce193d5 100644 --- a/net/batman-adv/hash.c +++ b/net/batman-adv/hash.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2006-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2006-2019 B.A.T.M.A.N. contributors: * * Simon Wunderlich, Marek Lindner * diff --git a/net/batman-adv/hash.h b/net/batman-adv/hash.h index 0e36fa1c7c3e..37507b6d4006 100644 --- a/net/batman-adv/hash.h +++ b/net/batman-adv/hash.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2006-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2006-2019 B.A.T.M.A.N. contributors: * * Simon Wunderlich, Marek Lindner * diff --git a/net/batman-adv/icmp_socket.c b/net/batman-adv/icmp_socket.c index 6d5859714f52..9859ababb82e 100644 --- a/net/batman-adv/icmp_socket.c +++ b/net/batman-adv/icmp_socket.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner * diff --git a/net/batman-adv/icmp_socket.h b/net/batman-adv/icmp_socket.h index 958be22beda9..5f8926522ff0 100644 --- a/net/batman-adv/icmp_socket.h +++ b/net/batman-adv/icmp_socket.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner * diff --git a/net/batman-adv/log.c b/net/batman-adv/log.c index 75f602e1ce94..3e610df8debf 100644 --- a/net/batman-adv/log.c +++ b/net/batman-adv/log.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2010-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2010-2019 B.A.T.M.A.N. contributors: * * Marek Lindner * diff --git a/net/batman-adv/log.h b/net/batman-adv/log.h index 35f4f397ed57..660e9bcc85a2 100644 --- a/net/batman-adv/log.h +++ b/net/batman-adv/log.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/batman-adv/main.c b/net/batman-adv/main.c index d1ed839fd32b..75750870cf04 100644 --- a/net/batman-adv/main.c +++ b/net/batman-adv/main.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/batman-adv/main.h b/net/batman-adv/main.h index b572066325e4..3ed669d7dc6b 100644 --- a/net/batman-adv/main.h +++ b/net/batman-adv/main.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * @@ -25,7 +25,7 @@ #define BATADV_DRIVER_DEVICE "batman-adv" #ifndef BATADV_SOURCE_VERSION -#define BATADV_SOURCE_VERSION "2019.0" +#define BATADV_SOURCE_VERSION "2019.1" #endif /* B.A.T.M.A.N. parameters */ diff --git a/net/batman-adv/multicast.c b/net/batman-adv/multicast.c index 69244e4598f5..f91b1b6265cf 100644 --- a/net/batman-adv/multicast.c +++ b/net/batman-adv/multicast.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2014-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2014-2019 B.A.T.M.A.N. contributors: * * Linus Lüssing * @@ -674,7 +674,7 @@ static void batadv_mcast_mla_update(struct work_struct *work) */ static bool batadv_mcast_is_report_ipv4(struct sk_buff *skb) { - if (ip_mc_check_igmp(skb, NULL) < 0) + if (ip_mc_check_igmp(skb) < 0) return false; switch (igmp_hdr(skb)->type) { @@ -741,7 +741,7 @@ static int batadv_mcast_forw_mode_check_ipv4(struct batadv_priv *bat_priv, */ static bool batadv_mcast_is_report_ipv6(struct sk_buff *skb) { - if (ipv6_mc_check_mld(skb, NULL) < 0) + if (ipv6_mc_check_mld(skb) < 0) return false; switch (icmp6_hdr(skb)->icmp6_type) { diff --git a/net/batman-adv/multicast.h b/net/batman-adv/multicast.h index 3b04ab13f0eb..466013fe88af 100644 --- a/net/batman-adv/multicast.h +++ b/net/batman-adv/multicast.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2014-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2014-2019 B.A.T.M.A.N. contributors: * * Linus Lüssing * diff --git a/net/batman-adv/netlink.c b/net/batman-adv/netlink.c index 2dc3304cee54..67a58da2e6a0 100644 --- a/net/batman-adv/netlink.c +++ b/net/batman-adv/netlink.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2016-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2016-2019 B.A.T.M.A.N. contributors: * * Matthias Schiffer * @@ -20,13 +20,17 @@ #include "main.h" #include <linux/atomic.h> +#include <linux/bitops.h> +#include <linux/bug.h> #include <linux/byteorder/generic.h> #include <linux/cache.h> +#include <linux/err.h> #include <linux/errno.h> #include <linux/export.h> #include <linux/genetlink.h> #include <linux/gfp.h> #include <linux/if_ether.h> +#include <linux/if_vlan.h> #include <linux/init.h> #include <linux/kernel.h> #include <linux/list.h> @@ -47,21 +51,54 @@ #include "bridge_loop_avoidance.h" #include "distributed-arp-table.h" #include "gateway_client.h" +#include "gateway_common.h" #include "hard-interface.h" +#include "log.h" #include "multicast.h" +#include "network-coding.h" #include "originator.h" #include "soft-interface.h" #include "tp_meter.h" #include "translation-table.h" +struct net; + struct genl_family batadv_netlink_family; /* multicast groups */ enum batadv_netlink_multicast_groups { + BATADV_NL_MCGRP_CONFIG, BATADV_NL_MCGRP_TPMETER, }; +/** + * enum batadv_genl_ops_flags - flags for genl_ops's internal_flags + */ +enum batadv_genl_ops_flags { + /** + * @BATADV_FLAG_NEED_MESH: request requires valid soft interface in + * attribute BATADV_ATTR_MESH_IFINDEX and expects a pointer to it to be + * saved in info->user_ptr[0] + */ + BATADV_FLAG_NEED_MESH = BIT(0), + + /** + * @BATADV_FLAG_NEED_HARDIF: request requires valid hard interface in + * attribute BATADV_ATTR_HARD_IFINDEX and expects a pointer to it to be + * saved in info->user_ptr[1] + */ + BATADV_FLAG_NEED_HARDIF = BIT(1), + + /** + * @BATADV_FLAG_NEED_VLAN: request requires valid vlan in + * attribute BATADV_ATTR_VLANID and expects a pointer to it to be + * saved in info->user_ptr[1] + */ + BATADV_FLAG_NEED_VLAN = BIT(2), +}; + static const struct genl_multicast_group batadv_netlink_mcgrps[] = { + [BATADV_NL_MCGRP_CONFIG] = { .name = BATADV_NL_MCAST_GROUP_CONFIG }, [BATADV_NL_MCGRP_TPMETER] = { .name = BATADV_NL_MCAST_GROUP_TPMETER }, }; @@ -104,6 +141,26 @@ static const struct nla_policy batadv_netlink_policy[NUM_BATADV_ATTR] = { [BATADV_ATTR_DAT_CACHE_VID] = { .type = NLA_U16 }, [BATADV_ATTR_MCAST_FLAGS] = { .type = NLA_U32 }, [BATADV_ATTR_MCAST_FLAGS_PRIV] = { .type = NLA_U32 }, + [BATADV_ATTR_VLANID] = { .type = NLA_U16 }, + [BATADV_ATTR_AGGREGATED_OGMS_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_AP_ISOLATION_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_ISOLATION_MARK] = { .type = NLA_U32 }, + [BATADV_ATTR_ISOLATION_MASK] = { .type = NLA_U32 }, + [BATADV_ATTR_BONDING_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_BRIDGE_LOOP_AVOIDANCE_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_DISTRIBUTED_ARP_TABLE_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_FRAGMENTATION_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_GW_BANDWIDTH_DOWN] = { .type = NLA_U32 }, + [BATADV_ATTR_GW_BANDWIDTH_UP] = { .type = NLA_U32 }, + [BATADV_ATTR_GW_MODE] = { .type = NLA_U8 }, + [BATADV_ATTR_GW_SEL_CLASS] = { .type = NLA_U32 }, + [BATADV_ATTR_HOP_PENALTY] = { .type = NLA_U8 }, + [BATADV_ATTR_LOG_LEVEL] = { .type = NLA_U32 }, + [BATADV_ATTR_MULTICAST_FORCEFLOOD_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_NETWORK_CODING_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_ORIG_INTERVAL] = { .type = NLA_U32 }, + [BATADV_ATTR_ELP_INTERVAL] = { .type = NLA_U32 }, + [BATADV_ATTR_THROUGHPUT_OVERRIDE] = { .type = NLA_U32 }, }; /** @@ -122,20 +179,75 @@ batadv_netlink_get_ifindex(const struct nlmsghdr *nlh, int attrtype) } /** - * batadv_netlink_mesh_info_put() - fill in generic information about mesh - * interface - * @msg: netlink message to be sent back - * @soft_iface: interface for which the data should be taken + * batadv_netlink_mesh_fill_ap_isolation() - Add ap_isolation softif attribute + * @msg: Netlink message to dump into + * @bat_priv: the bat priv with all the soft interface information * - * Return: 0 on success, < 0 on error + * Return: 0 on success or negative error number in case of failure */ -static int -batadv_netlink_mesh_info_put(struct sk_buff *msg, struct net_device *soft_iface) +static int batadv_netlink_mesh_fill_ap_isolation(struct sk_buff *msg, + struct batadv_priv *bat_priv) +{ + struct batadv_softif_vlan *vlan; + u8 ap_isolation; + + vlan = batadv_softif_vlan_get(bat_priv, BATADV_NO_FLAGS); + if (!vlan) + return 0; + + ap_isolation = atomic_read(&vlan->ap_isolation); + batadv_softif_vlan_put(vlan); + + return nla_put_u8(msg, BATADV_ATTR_AP_ISOLATION_ENABLED, + !!ap_isolation); +} + +/** + * batadv_option_set_ap_isolation() - Set ap_isolation from genl msg + * @attr: parsed BATADV_ATTR_AP_ISOLATION_ENABLED attribute + * @bat_priv: the bat priv with all the soft interface information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_set_mesh_ap_isolation(struct nlattr *attr, + struct batadv_priv *bat_priv) +{ + struct batadv_softif_vlan *vlan; + + vlan = batadv_softif_vlan_get(bat_priv, BATADV_NO_FLAGS); + if (!vlan) + return -ENOENT; + + atomic_set(&vlan->ap_isolation, !!nla_get_u8(attr)); + batadv_softif_vlan_put(vlan); + + return 0; +} + +/** + * batadv_netlink_mesh_fill() - Fill message with mesh attributes + * @msg: Netlink message to dump into + * @bat_priv: the bat priv with all the soft interface information + * @cmd: type of message to generate + * @portid: Port making netlink request + * @seq: sequence number for message + * @flags: Additional flags for message + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_mesh_fill(struct sk_buff *msg, + struct batadv_priv *bat_priv, + enum batadv_nl_commands cmd, + u32 portid, u32 seq, int flags) { - struct batadv_priv *bat_priv = netdev_priv(soft_iface); + struct net_device *soft_iface = bat_priv->soft_iface; struct batadv_hard_iface *primary_if = NULL; struct net_device *hard_iface; - int ret = -ENOBUFS; + void *hdr; + + hdr = genlmsg_put(msg, portid, seq, &batadv_netlink_family, flags, cmd); + if (!hdr) + return -ENOBUFS; if (nla_put_string(msg, BATADV_ATTR_VERSION, BATADV_SOURCE_VERSION) || nla_put_string(msg, BATADV_ATTR_ALGO_NAME, @@ -146,16 +258,16 @@ batadv_netlink_mesh_info_put(struct sk_buff *msg, struct net_device *soft_iface) soft_iface->dev_addr) || nla_put_u8(msg, BATADV_ATTR_TT_TTVN, (u8)atomic_read(&bat_priv->tt.vn))) - goto out; + goto nla_put_failure; #ifdef CONFIG_BATMAN_ADV_BLA if (nla_put_u16(msg, BATADV_ATTR_BLA_CRC, ntohs(bat_priv->bla.claim_dest.group))) - goto out; + goto nla_put_failure; #endif if (batadv_mcast_mesh_info_put(msg, bat_priv)) - goto out; + goto nla_put_failure; primary_if = batadv_primary_if_get_selected(bat_priv); if (primary_if && primary_if->if_status == BATADV_IF_ACTIVE) { @@ -167,77 +279,345 @@ batadv_netlink_mesh_info_put(struct sk_buff *msg, struct net_device *soft_iface) hard_iface->name) || nla_put(msg, BATADV_ATTR_HARD_ADDRESS, ETH_ALEN, hard_iface->dev_addr)) - goto out; + goto nla_put_failure; } - ret = 0; + if (nla_put_u8(msg, BATADV_ATTR_AGGREGATED_OGMS_ENABLED, + !!atomic_read(&bat_priv->aggregated_ogms))) + goto nla_put_failure; + + if (batadv_netlink_mesh_fill_ap_isolation(msg, bat_priv)) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_ISOLATION_MARK, + bat_priv->isolation_mark)) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_ISOLATION_MASK, + bat_priv->isolation_mark_mask)) + goto nla_put_failure; + + if (nla_put_u8(msg, BATADV_ATTR_BONDING_ENABLED, + !!atomic_read(&bat_priv->bonding))) + goto nla_put_failure; + +#ifdef CONFIG_BATMAN_ADV_BLA + if (nla_put_u8(msg, BATADV_ATTR_BRIDGE_LOOP_AVOIDANCE_ENABLED, + !!atomic_read(&bat_priv->bridge_loop_avoidance))) + goto nla_put_failure; +#endif /* CONFIG_BATMAN_ADV_BLA */ + +#ifdef CONFIG_BATMAN_ADV_DAT + if (nla_put_u8(msg, BATADV_ATTR_DISTRIBUTED_ARP_TABLE_ENABLED, + !!atomic_read(&bat_priv->distributed_arp_table))) + goto nla_put_failure; +#endif /* CONFIG_BATMAN_ADV_DAT */ + + if (nla_put_u8(msg, BATADV_ATTR_FRAGMENTATION_ENABLED, + !!atomic_read(&bat_priv->fragmentation))) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_GW_BANDWIDTH_DOWN, + atomic_read(&bat_priv->gw.bandwidth_down))) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_GW_BANDWIDTH_UP, + atomic_read(&bat_priv->gw.bandwidth_up))) + goto nla_put_failure; + + if (nla_put_u8(msg, BATADV_ATTR_GW_MODE, + atomic_read(&bat_priv->gw.mode))) + goto nla_put_failure; + + if (bat_priv->algo_ops->gw.get_best_gw_node && + bat_priv->algo_ops->gw.is_eligible) { + /* GW selection class is not available if the routing algorithm + * in use does not implement the GW API + */ + if (nla_put_u32(msg, BATADV_ATTR_GW_SEL_CLASS, + atomic_read(&bat_priv->gw.sel_class))) + goto nla_put_failure; + } + + if (nla_put_u8(msg, BATADV_ATTR_HOP_PENALTY, + atomic_read(&bat_priv->hop_penalty))) + goto nla_put_failure; + +#ifdef CONFIG_BATMAN_ADV_DEBUG + if (nla_put_u32(msg, BATADV_ATTR_LOG_LEVEL, + atomic_read(&bat_priv->log_level))) + goto nla_put_failure; +#endif /* CONFIG_BATMAN_ADV_DEBUG */ + +#ifdef CONFIG_BATMAN_ADV_MCAST + if (nla_put_u8(msg, BATADV_ATTR_MULTICAST_FORCEFLOOD_ENABLED, + !atomic_read(&bat_priv->multicast_mode))) + goto nla_put_failure; +#endif /* CONFIG_BATMAN_ADV_MCAST */ + +#ifdef CONFIG_BATMAN_ADV_NC + if (nla_put_u8(msg, BATADV_ATTR_NETWORK_CODING_ENABLED, + !!atomic_read(&bat_priv->network_coding))) + goto nla_put_failure; +#endif /* CONFIG_BATMAN_ADV_NC */ + + if (nla_put_u32(msg, BATADV_ATTR_ORIG_INTERVAL, + atomic_read(&bat_priv->orig_interval))) + goto nla_put_failure; - out: if (primary_if) batadv_hardif_put(primary_if); - return ret; + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + if (primary_if) + batadv_hardif_put(primary_if); + + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; } /** - * batadv_netlink_get_mesh_info() - handle incoming BATADV_CMD_GET_MESH_INFO - * netlink request - * @skb: received netlink message - * @info: receiver information + * batadv_netlink_notify_mesh() - send softif attributes to listener + * @bat_priv: the bat priv with all the soft interface information * * Return: 0 on success, < 0 on error */ -static int -batadv_netlink_get_mesh_info(struct sk_buff *skb, struct genl_info *info) +int batadv_netlink_notify_mesh(struct batadv_priv *bat_priv) { - struct net *net = genl_info_net(info); - struct net_device *soft_iface; - struct sk_buff *msg = NULL; - void *msg_head; - int ifindex; + struct sk_buff *msg; int ret; - if (!info->attrs[BATADV_ATTR_MESH_IFINDEX]) - return -EINVAL; - - ifindex = nla_get_u32(info->attrs[BATADV_ATTR_MESH_IFINDEX]); - if (!ifindex) - return -EINVAL; + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; - soft_iface = dev_get_by_index(net, ifindex); - if (!soft_iface || !batadv_softif_is_valid(soft_iface)) { - ret = -ENODEV; - goto out; + ret = batadv_netlink_mesh_fill(msg, bat_priv, BATADV_CMD_SET_MESH, + 0, 0, 0); + if (ret < 0) { + nlmsg_free(msg); + return ret; } + genlmsg_multicast_netns(&batadv_netlink_family, + dev_net(bat_priv->soft_iface), msg, 0, + BATADV_NL_MCGRP_CONFIG, GFP_KERNEL); + + return 0; +} + +/** + * batadv_netlink_get_mesh() - Get softif attributes + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_get_mesh(struct sk_buff *skb, struct genl_info *info) +{ + struct batadv_priv *bat_priv = info->user_ptr[0]; + struct sk_buff *msg; + int ret; + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); - if (!msg) { - ret = -ENOMEM; - goto out; + if (!msg) + return -ENOMEM; + + ret = batadv_netlink_mesh_fill(msg, bat_priv, BATADV_CMD_GET_MESH, + info->snd_portid, info->snd_seq, 0); + if (ret < 0) { + nlmsg_free(msg); + return ret; } - msg_head = genlmsg_put(msg, info->snd_portid, info->snd_seq, - &batadv_netlink_family, 0, - BATADV_CMD_GET_MESH_INFO); - if (!msg_head) { - ret = -ENOBUFS; - goto out; + ret = genlmsg_reply(msg, info); + + return ret; +} + +/** + * batadv_netlink_set_mesh() - Set softif attributes + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_set_mesh(struct sk_buff *skb, struct genl_info *info) +{ + struct batadv_priv *bat_priv = info->user_ptr[0]; + struct nlattr *attr; + + if (info->attrs[BATADV_ATTR_AGGREGATED_OGMS_ENABLED]) { + attr = info->attrs[BATADV_ATTR_AGGREGATED_OGMS_ENABLED]; + + atomic_set(&bat_priv->aggregated_ogms, !!nla_get_u8(attr)); } - ret = batadv_netlink_mesh_info_put(msg, soft_iface); + if (info->attrs[BATADV_ATTR_AP_ISOLATION_ENABLED]) { + attr = info->attrs[BATADV_ATTR_AP_ISOLATION_ENABLED]; - out: - if (soft_iface) - dev_put(soft_iface); + batadv_netlink_set_mesh_ap_isolation(attr, bat_priv); + } - if (ret) { - if (msg) - nlmsg_free(msg); - return ret; + if (info->attrs[BATADV_ATTR_ISOLATION_MARK]) { + attr = info->attrs[BATADV_ATTR_ISOLATION_MARK]; + + bat_priv->isolation_mark = nla_get_u32(attr); } - genlmsg_end(msg, msg_head); - return genlmsg_reply(msg, info); + if (info->attrs[BATADV_ATTR_ISOLATION_MASK]) { + attr = info->attrs[BATADV_ATTR_ISOLATION_MASK]; + + bat_priv->isolation_mark_mask = nla_get_u32(attr); + } + + if (info->attrs[BATADV_ATTR_BONDING_ENABLED]) { + attr = info->attrs[BATADV_ATTR_BONDING_ENABLED]; + + atomic_set(&bat_priv->bonding, !!nla_get_u8(attr)); + } + +#ifdef CONFIG_BATMAN_ADV_BLA + if (info->attrs[BATADV_ATTR_BRIDGE_LOOP_AVOIDANCE_ENABLED]) { + attr = info->attrs[BATADV_ATTR_BRIDGE_LOOP_AVOIDANCE_ENABLED]; + + atomic_set(&bat_priv->bridge_loop_avoidance, + !!nla_get_u8(attr)); + batadv_bla_status_update(bat_priv->soft_iface); + } +#endif /* CONFIG_BATMAN_ADV_BLA */ + +#ifdef CONFIG_BATMAN_ADV_DAT + if (info->attrs[BATADV_ATTR_DISTRIBUTED_ARP_TABLE_ENABLED]) { + attr = info->attrs[BATADV_ATTR_DISTRIBUTED_ARP_TABLE_ENABLED]; + + atomic_set(&bat_priv->distributed_arp_table, + !!nla_get_u8(attr)); + batadv_dat_status_update(bat_priv->soft_iface); + } +#endif /* CONFIG_BATMAN_ADV_DAT */ + + if (info->attrs[BATADV_ATTR_FRAGMENTATION_ENABLED]) { + attr = info->attrs[BATADV_ATTR_FRAGMENTATION_ENABLED]; + + atomic_set(&bat_priv->fragmentation, !!nla_get_u8(attr)); + batadv_update_min_mtu(bat_priv->soft_iface); + } + + if (info->attrs[BATADV_ATTR_GW_BANDWIDTH_DOWN]) { + attr = info->attrs[BATADV_ATTR_GW_BANDWIDTH_DOWN]; + + atomic_set(&bat_priv->gw.bandwidth_down, nla_get_u32(attr)); + batadv_gw_tvlv_container_update(bat_priv); + } + + if (info->attrs[BATADV_ATTR_GW_BANDWIDTH_UP]) { + attr = info->attrs[BATADV_ATTR_GW_BANDWIDTH_UP]; + + atomic_set(&bat_priv->gw.bandwidth_up, nla_get_u32(attr)); + batadv_gw_tvlv_container_update(bat_priv); + } + + if (info->attrs[BATADV_ATTR_GW_MODE]) { + u8 gw_mode; + + attr = info->attrs[BATADV_ATTR_GW_MODE]; + gw_mode = nla_get_u8(attr); + + if (gw_mode <= BATADV_GW_MODE_SERVER) { + /* Invoking batadv_gw_reselect() is not enough to really + * de-select the current GW. It will only instruct the + * gateway client code to perform a re-election the next + * time that this is needed. + * + * When gw client mode is being switched off the current + * GW must be de-selected explicitly otherwise no GW_ADD + * uevent is thrown on client mode re-activation. This + * is operation is performed in + * batadv_gw_check_client_stop(). + */ + batadv_gw_reselect(bat_priv); + + /* always call batadv_gw_check_client_stop() before + * changing the gateway state + */ + batadv_gw_check_client_stop(bat_priv); + atomic_set(&bat_priv->gw.mode, gw_mode); + batadv_gw_tvlv_container_update(bat_priv); + } + } + + if (info->attrs[BATADV_ATTR_GW_SEL_CLASS] && + bat_priv->algo_ops->gw.get_best_gw_node && + bat_priv->algo_ops->gw.is_eligible) { + /* setting the GW selection class is allowed only if the routing + * algorithm in use implements the GW API + */ + + u32 sel_class_max = 0xffffffffu; + u32 sel_class; + + attr = info->attrs[BATADV_ATTR_GW_SEL_CLASS]; + sel_class = nla_get_u32(attr); + + if (!bat_priv->algo_ops->gw.store_sel_class) + sel_class_max = BATADV_TQ_MAX_VALUE; + + if (sel_class >= 1 && sel_class <= sel_class_max) { + atomic_set(&bat_priv->gw.sel_class, sel_class); + batadv_gw_reselect(bat_priv); + } + } + + if (info->attrs[BATADV_ATTR_HOP_PENALTY]) { + attr = info->attrs[BATADV_ATTR_HOP_PENALTY]; + + atomic_set(&bat_priv->hop_penalty, nla_get_u8(attr)); + } + +#ifdef CONFIG_BATMAN_ADV_DEBUG + if (info->attrs[BATADV_ATTR_LOG_LEVEL]) { + attr = info->attrs[BATADV_ATTR_LOG_LEVEL]; + + atomic_set(&bat_priv->log_level, + nla_get_u32(attr) & BATADV_DBG_ALL); + } +#endif /* CONFIG_BATMAN_ADV_DEBUG */ + +#ifdef CONFIG_BATMAN_ADV_MCAST + if (info->attrs[BATADV_ATTR_MULTICAST_FORCEFLOOD_ENABLED]) { + attr = info->attrs[BATADV_ATTR_MULTICAST_FORCEFLOOD_ENABLED]; + + atomic_set(&bat_priv->multicast_mode, !nla_get_u8(attr)); + } +#endif /* CONFIG_BATMAN_ADV_MCAST */ + +#ifdef CONFIG_BATMAN_ADV_NC + if (info->attrs[BATADV_ATTR_NETWORK_CODING_ENABLED]) { + attr = info->attrs[BATADV_ATTR_NETWORK_CODING_ENABLED]; + + atomic_set(&bat_priv->network_coding, !!nla_get_u8(attr)); + batadv_nc_status_update(bat_priv->soft_iface); + } +#endif /* CONFIG_BATMAN_ADV_NC */ + + if (info->attrs[BATADV_ATTR_ORIG_INTERVAL]) { + u32 orig_interval; + + attr = info->attrs[BATADV_ATTR_ORIG_INTERVAL]; + orig_interval = nla_get_u32(attr); + + orig_interval = min_t(u32, orig_interval, INT_MAX); + orig_interval = max_t(u32, orig_interval, 2 * BATADV_JITTER); + + atomic_set(&bat_priv->orig_interval, orig_interval); + } + + batadv_netlink_notify_mesh(bat_priv); + + return 0; } /** @@ -329,40 +709,24 @@ err_genlmsg: static int batadv_netlink_tp_meter_start(struct sk_buff *skb, struct genl_info *info) { - struct net *net = genl_info_net(info); - struct net_device *soft_iface; - struct batadv_priv *bat_priv; + struct batadv_priv *bat_priv = info->user_ptr[0]; struct sk_buff *msg = NULL; u32 test_length; void *msg_head; - int ifindex; u32 cookie; u8 *dst; int ret; - if (!info->attrs[BATADV_ATTR_MESH_IFINDEX]) - return -EINVAL; - if (!info->attrs[BATADV_ATTR_ORIG_ADDRESS]) return -EINVAL; if (!info->attrs[BATADV_ATTR_TPMETER_TEST_TIME]) return -EINVAL; - ifindex = nla_get_u32(info->attrs[BATADV_ATTR_MESH_IFINDEX]); - if (!ifindex) - return -EINVAL; - dst = nla_data(info->attrs[BATADV_ATTR_ORIG_ADDRESS]); test_length = nla_get_u32(info->attrs[BATADV_ATTR_TPMETER_TEST_TIME]); - soft_iface = dev_get_by_index(net, ifindex); - if (!soft_iface || !batadv_softif_is_valid(soft_iface)) { - ret = -ENODEV; - goto out; - } - msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); if (!msg) { ret = -ENOMEM; @@ -377,15 +741,11 @@ batadv_netlink_tp_meter_start(struct sk_buff *skb, struct genl_info *info) goto out; } - bat_priv = netdev_priv(soft_iface); batadv_tp_start(bat_priv, dst, test_length, &cookie); ret = batadv_netlink_tp_meter_put(msg, cookie); out: - if (soft_iface) - dev_put(soft_iface); - if (ret) { if (msg) nlmsg_free(msg); @@ -406,65 +766,53 @@ batadv_netlink_tp_meter_start(struct sk_buff *skb, struct genl_info *info) static int batadv_netlink_tp_meter_cancel(struct sk_buff *skb, struct genl_info *info) { - struct net *net = genl_info_net(info); - struct net_device *soft_iface; - struct batadv_priv *bat_priv; - int ifindex; + struct batadv_priv *bat_priv = info->user_ptr[0]; u8 *dst; int ret = 0; - if (!info->attrs[BATADV_ATTR_MESH_IFINDEX]) - return -EINVAL; - if (!info->attrs[BATADV_ATTR_ORIG_ADDRESS]) return -EINVAL; - ifindex = nla_get_u32(info->attrs[BATADV_ATTR_MESH_IFINDEX]); - if (!ifindex) - return -EINVAL; - dst = nla_data(info->attrs[BATADV_ATTR_ORIG_ADDRESS]); - soft_iface = dev_get_by_index(net, ifindex); - if (!soft_iface || !batadv_softif_is_valid(soft_iface)) { - ret = -ENODEV; - goto out; - } - - bat_priv = netdev_priv(soft_iface); batadv_tp_stop(bat_priv, dst, BATADV_TP_REASON_CANCEL); -out: - if (soft_iface) - dev_put(soft_iface); - return ret; } /** - * batadv_netlink_dump_hardif_entry() - Dump one hard interface into a message + * batadv_netlink_hardif_fill() - Fill message with hardif attributes * @msg: Netlink message to dump into + * @bat_priv: the bat priv with all the soft interface information + * @hard_iface: hard interface which was modified + * @cmd: type of message to generate * @portid: Port making netlink request + * @seq: sequence number for message + * @flags: Additional flags for message * @cb: Control block containing additional options - * @hard_iface: Hard interface to dump * - * Return: error code, or 0 on success + * Return: 0 on success or negative error number in case of failure */ -static int -batadv_netlink_dump_hardif_entry(struct sk_buff *msg, u32 portid, - struct netlink_callback *cb, - struct batadv_hard_iface *hard_iface) +static int batadv_netlink_hardif_fill(struct sk_buff *msg, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *hard_iface, + enum batadv_nl_commands cmd, + u32 portid, u32 seq, int flags, + struct netlink_callback *cb) { struct net_device *net_dev = hard_iface->net_dev; void *hdr; - hdr = genlmsg_put(msg, portid, cb->nlh->nlmsg_seq, - &batadv_netlink_family, NLM_F_MULTI, - BATADV_CMD_GET_HARDIFS); + hdr = genlmsg_put(msg, portid, seq, &batadv_netlink_family, flags, cmd); if (!hdr) - return -EMSGSIZE; + return -ENOBUFS; + + if (cb) + genl_dump_check_consistent(cb, hdr); - genl_dump_check_consistent(cb, hdr); + if (nla_put_u32(msg, BATADV_ATTR_MESH_IFINDEX, + bat_priv->soft_iface->ifindex)) + goto nla_put_failure; if (nla_put_u32(msg, BATADV_ATTR_HARD_IFINDEX, net_dev->ifindex) || @@ -479,27 +827,137 @@ batadv_netlink_dump_hardif_entry(struct sk_buff *msg, u32 portid, goto nla_put_failure; } +#ifdef CONFIG_BATMAN_ADV_BATMAN_V + if (nla_put_u32(msg, BATADV_ATTR_ELP_INTERVAL, + atomic_read(&hard_iface->bat_v.elp_interval))) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_THROUGHPUT_OVERRIDE, + atomic_read(&hard_iface->bat_v.throughput_override))) + goto nla_put_failure; +#endif /* CONFIG_BATMAN_ADV_BATMAN_V */ + genlmsg_end(msg, hdr); return 0; - nla_put_failure: +nla_put_failure: genlmsg_cancel(msg, hdr); return -EMSGSIZE; } /** - * batadv_netlink_dump_hardifs() - Dump all hard interface into a messages + * batadv_netlink_notify_hardif() - send hardif attributes to listener + * @bat_priv: the bat priv with all the soft interface information + * @hard_iface: hard interface which was modified + * + * Return: 0 on success, < 0 on error + */ +int batadv_netlink_notify_hardif(struct batadv_priv *bat_priv, + struct batadv_hard_iface *hard_iface) +{ + struct sk_buff *msg; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = batadv_netlink_hardif_fill(msg, bat_priv, hard_iface, + BATADV_CMD_SET_HARDIF, 0, 0, 0, NULL); + if (ret < 0) { + nlmsg_free(msg); + return ret; + } + + genlmsg_multicast_netns(&batadv_netlink_family, + dev_net(bat_priv->soft_iface), msg, 0, + BATADV_NL_MCGRP_CONFIG, GFP_KERNEL); + + return 0; +} + +/** + * batadv_netlink_get_hardif() - Get hardif attributes + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_get_hardif(struct sk_buff *skb, + struct genl_info *info) +{ + struct batadv_hard_iface *hard_iface = info->user_ptr[1]; + struct batadv_priv *bat_priv = info->user_ptr[0]; + struct sk_buff *msg; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = batadv_netlink_hardif_fill(msg, bat_priv, hard_iface, + BATADV_CMD_GET_HARDIF, + info->snd_portid, info->snd_seq, 0, + NULL); + if (ret < 0) { + nlmsg_free(msg); + return ret; + } + + ret = genlmsg_reply(msg, info); + + return ret; +} + +/** + * batadv_netlink_set_hardif() - Set hardif attributes + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_set_hardif(struct sk_buff *skb, + struct genl_info *info) +{ + struct batadv_hard_iface *hard_iface = info->user_ptr[1]; + struct batadv_priv *bat_priv = info->user_ptr[0]; + +#ifdef CONFIG_BATMAN_ADV_BATMAN_V + struct nlattr *attr; + + if (info->attrs[BATADV_ATTR_ELP_INTERVAL]) { + attr = info->attrs[BATADV_ATTR_ELP_INTERVAL]; + + atomic_set(&hard_iface->bat_v.elp_interval, nla_get_u32(attr)); + } + + if (info->attrs[BATADV_ATTR_THROUGHPUT_OVERRIDE]) { + attr = info->attrs[BATADV_ATTR_THROUGHPUT_OVERRIDE]; + + atomic_set(&hard_iface->bat_v.throughput_override, + nla_get_u32(attr)); + } +#endif /* CONFIG_BATMAN_ADV_BATMAN_V */ + + batadv_netlink_notify_hardif(bat_priv, hard_iface); + + return 0; +} + +/** + * batadv_netlink_dump_hardif() - Dump all hard interface into a messages * @msg: Netlink message to dump into * @cb: Parameters from query * * Return: error code, or length of reply message on success */ static int -batadv_netlink_dump_hardifs(struct sk_buff *msg, struct netlink_callback *cb) +batadv_netlink_dump_hardif(struct sk_buff *msg, struct netlink_callback *cb) { struct net *net = sock_net(cb->skb->sk); struct net_device *soft_iface; struct batadv_hard_iface *hard_iface; + struct batadv_priv *bat_priv; int ifindex; int portid = NETLINK_CB(cb->skb).portid; int skip = cb->args[0]; @@ -519,6 +977,8 @@ batadv_netlink_dump_hardifs(struct sk_buff *msg, struct netlink_callback *cb) return -ENODEV; } + bat_priv = netdev_priv(soft_iface); + rtnl_lock(); cb->seq = batadv_hardif_generation << 1 | 1; @@ -529,8 +989,10 @@ batadv_netlink_dump_hardifs(struct sk_buff *msg, struct netlink_callback *cb) if (i++ < skip) continue; - if (batadv_netlink_dump_hardif_entry(msg, portid, cb, - hard_iface)) { + if (batadv_netlink_hardif_fill(msg, bat_priv, hard_iface, + BATADV_CMD_GET_HARDIF, + portid, cb->nlh->nlmsg_seq, + NLM_F_MULTI, cb)) { i--; break; } @@ -545,24 +1007,361 @@ batadv_netlink_dump_hardifs(struct sk_buff *msg, struct netlink_callback *cb) return msg->len; } +/** + * batadv_netlink_vlan_fill() - Fill message with vlan attributes + * @msg: Netlink message to dump into + * @bat_priv: the bat priv with all the soft interface information + * @vlan: vlan which was modified + * @cmd: type of message to generate + * @portid: Port making netlink request + * @seq: sequence number for message + * @flags: Additional flags for message + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_vlan_fill(struct sk_buff *msg, + struct batadv_priv *bat_priv, + struct batadv_softif_vlan *vlan, + enum batadv_nl_commands cmd, + u32 portid, u32 seq, int flags) +{ + void *hdr; + + hdr = genlmsg_put(msg, portid, seq, &batadv_netlink_family, flags, cmd); + if (!hdr) + return -ENOBUFS; + + if (nla_put_u32(msg, BATADV_ATTR_MESH_IFINDEX, + bat_priv->soft_iface->ifindex)) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_VLANID, vlan->vid & VLAN_VID_MASK)) + goto nla_put_failure; + + if (nla_put_u8(msg, BATADV_ATTR_AP_ISOLATION_ENABLED, + !!atomic_read(&vlan->ap_isolation))) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +/** + * batadv_netlink_notify_vlan() - send vlan attributes to listener + * @bat_priv: the bat priv with all the soft interface information + * @vlan: vlan which was modified + * + * Return: 0 on success, < 0 on error + */ +int batadv_netlink_notify_vlan(struct batadv_priv *bat_priv, + struct batadv_softif_vlan *vlan) +{ + struct sk_buff *msg; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = batadv_netlink_vlan_fill(msg, bat_priv, vlan, + BATADV_CMD_SET_VLAN, 0, 0, 0); + if (ret < 0) { + nlmsg_free(msg); + return ret; + } + + genlmsg_multicast_netns(&batadv_netlink_family, + dev_net(bat_priv->soft_iface), msg, 0, + BATADV_NL_MCGRP_CONFIG, GFP_KERNEL); + + return 0; +} + +/** + * batadv_netlink_get_vlan() - Get vlan attributes + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_get_vlan(struct sk_buff *skb, struct genl_info *info) +{ + struct batadv_softif_vlan *vlan = info->user_ptr[1]; + struct batadv_priv *bat_priv = info->user_ptr[0]; + struct sk_buff *msg; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = batadv_netlink_vlan_fill(msg, bat_priv, vlan, BATADV_CMD_GET_VLAN, + info->snd_portid, info->snd_seq, 0); + if (ret < 0) { + nlmsg_free(msg); + return ret; + } + + ret = genlmsg_reply(msg, info); + + return ret; +} + +/** + * batadv_netlink_set_vlan() - Get vlan attributes + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_set_vlan(struct sk_buff *skb, struct genl_info *info) +{ + struct batadv_softif_vlan *vlan = info->user_ptr[1]; + struct batadv_priv *bat_priv = info->user_ptr[0]; + struct nlattr *attr; + + if (info->attrs[BATADV_ATTR_AP_ISOLATION_ENABLED]) { + attr = info->attrs[BATADV_ATTR_AP_ISOLATION_ENABLED]; + + atomic_set(&vlan->ap_isolation, !!nla_get_u8(attr)); + } + + batadv_netlink_notify_vlan(bat_priv, vlan); + + return 0; +} + +/** + * batadv_get_softif_from_info() - Retrieve soft interface from genl attributes + * @net: the applicable net namespace + * @info: receiver information + * + * Return: Pointer to soft interface (with increased refcnt) on success, error + * pointer on error + */ +static struct net_device * +batadv_get_softif_from_info(struct net *net, struct genl_info *info) +{ + struct net_device *soft_iface; + int ifindex; + + if (!info->attrs[BATADV_ATTR_MESH_IFINDEX]) + return ERR_PTR(-EINVAL); + + ifindex = nla_get_u32(info->attrs[BATADV_ATTR_MESH_IFINDEX]); + + soft_iface = dev_get_by_index(net, ifindex); + if (!soft_iface) + return ERR_PTR(-ENODEV); + + if (!batadv_softif_is_valid(soft_iface)) + goto err_put_softif; + + return soft_iface; + +err_put_softif: + dev_put(soft_iface); + + return ERR_PTR(-EINVAL); +} + +/** + * batadv_get_hardif_from_info() - Retrieve hardif from genl attributes + * @bat_priv: the bat priv with all the soft interface information + * @net: the applicable net namespace + * @info: receiver information + * + * Return: Pointer to hard interface (with increased refcnt) on success, error + * pointer on error + */ +static struct batadv_hard_iface * +batadv_get_hardif_from_info(struct batadv_priv *bat_priv, struct net *net, + struct genl_info *info) +{ + struct batadv_hard_iface *hard_iface; + struct net_device *hard_dev; + unsigned int hardif_index; + + if (!info->attrs[BATADV_ATTR_HARD_IFINDEX]) + return ERR_PTR(-EINVAL); + + hardif_index = nla_get_u32(info->attrs[BATADV_ATTR_HARD_IFINDEX]); + + hard_dev = dev_get_by_index(net, hardif_index); + if (!hard_dev) + return ERR_PTR(-ENODEV); + + hard_iface = batadv_hardif_get_by_netdev(hard_dev); + if (!hard_iface) + goto err_put_harddev; + + if (hard_iface->soft_iface != bat_priv->soft_iface) + goto err_put_hardif; + + /* hard_dev is referenced by hard_iface and not needed here */ + dev_put(hard_dev); + + return hard_iface; + +err_put_hardif: + batadv_hardif_put(hard_iface); +err_put_harddev: + dev_put(hard_dev); + + return ERR_PTR(-EINVAL); +} + +/** + * batadv_get_vlan_from_info() - Retrieve vlan from genl attributes + * @bat_priv: the bat priv with all the soft interface information + * @net: the applicable net namespace + * @info: receiver information + * + * Return: Pointer to vlan on success (with increased refcnt), error pointer + * on error + */ +static struct batadv_softif_vlan * +batadv_get_vlan_from_info(struct batadv_priv *bat_priv, struct net *net, + struct genl_info *info) +{ + struct batadv_softif_vlan *vlan; + u16 vid; + + if (!info->attrs[BATADV_ATTR_VLANID]) + return ERR_PTR(-EINVAL); + + vid = nla_get_u16(info->attrs[BATADV_ATTR_VLANID]); + + vlan = batadv_softif_vlan_get(bat_priv, vid | BATADV_VLAN_HAS_TAG); + if (!vlan) + return ERR_PTR(-ENOENT); + + return vlan; +} + +/** + * batadv_pre_doit() - Prepare batman-adv genl doit request + * @ops: requested netlink operation + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_pre_doit(const struct genl_ops *ops, struct sk_buff *skb, + struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct batadv_hard_iface *hard_iface; + struct batadv_priv *bat_priv = NULL; + struct batadv_softif_vlan *vlan; + struct net_device *soft_iface; + u8 user_ptr1_flags; + u8 mesh_dep_flags; + int ret; + + user_ptr1_flags = BATADV_FLAG_NEED_HARDIF | BATADV_FLAG_NEED_VLAN; + if (WARN_ON(hweight8(ops->internal_flags & user_ptr1_flags) > 1)) + return -EINVAL; + + mesh_dep_flags = BATADV_FLAG_NEED_HARDIF | BATADV_FLAG_NEED_VLAN; + if (WARN_ON((ops->internal_flags & mesh_dep_flags) && + (~ops->internal_flags & BATADV_FLAG_NEED_MESH))) + return -EINVAL; + + if (ops->internal_flags & BATADV_FLAG_NEED_MESH) { + soft_iface = batadv_get_softif_from_info(net, info); + if (IS_ERR(soft_iface)) + return PTR_ERR(soft_iface); + + bat_priv = netdev_priv(soft_iface); + info->user_ptr[0] = bat_priv; + } + + if (ops->internal_flags & BATADV_FLAG_NEED_HARDIF) { + hard_iface = batadv_get_hardif_from_info(bat_priv, net, info); + if (IS_ERR(hard_iface)) { + ret = PTR_ERR(hard_iface); + goto err_put_softif; + } + + info->user_ptr[1] = hard_iface; + } + + if (ops->internal_flags & BATADV_FLAG_NEED_VLAN) { + vlan = batadv_get_vlan_from_info(bat_priv, net, info); + if (IS_ERR(vlan)) { + ret = PTR_ERR(vlan); + goto err_put_softif; + } + + info->user_ptr[1] = vlan; + } + + return 0; + +err_put_softif: + if (bat_priv) + dev_put(bat_priv->soft_iface); + + return ret; +} + +/** + * batadv_post_doit() - End batman-adv genl doit request + * @ops: requested netlink operation + * @skb: Netlink message with request data + * @info: receiver information + */ +static void batadv_post_doit(const struct genl_ops *ops, struct sk_buff *skb, + struct genl_info *info) +{ + struct batadv_hard_iface *hard_iface; + struct batadv_softif_vlan *vlan; + struct batadv_priv *bat_priv; + + if (ops->internal_flags & BATADV_FLAG_NEED_HARDIF && + info->user_ptr[1]) { + hard_iface = info->user_ptr[1]; + + batadv_hardif_put(hard_iface); + } + + if (ops->internal_flags & BATADV_FLAG_NEED_VLAN && info->user_ptr[1]) { + vlan = info->user_ptr[1]; + batadv_softif_vlan_put(vlan); + } + + if (ops->internal_flags & BATADV_FLAG_NEED_MESH && info->user_ptr[0]) { + bat_priv = info->user_ptr[0]; + dev_put(bat_priv->soft_iface); + } +} + static const struct genl_ops batadv_netlink_ops[] = { { - .cmd = BATADV_CMD_GET_MESH_INFO, - .flags = GENL_ADMIN_PERM, + .cmd = BATADV_CMD_GET_MESH, + /* can be retrieved by unprivileged users */ .policy = batadv_netlink_policy, - .doit = batadv_netlink_get_mesh_info, + .doit = batadv_netlink_get_mesh, + .internal_flags = BATADV_FLAG_NEED_MESH, }, { .cmd = BATADV_CMD_TP_METER, .flags = GENL_ADMIN_PERM, .policy = batadv_netlink_policy, .doit = batadv_netlink_tp_meter_start, + .internal_flags = BATADV_FLAG_NEED_MESH, }, { .cmd = BATADV_CMD_TP_METER_CANCEL, .flags = GENL_ADMIN_PERM, .policy = batadv_netlink_policy, .doit = batadv_netlink_tp_meter_cancel, + .internal_flags = BATADV_FLAG_NEED_MESH, }, { .cmd = BATADV_CMD_GET_ROUTING_ALGOS, @@ -571,10 +1370,13 @@ static const struct genl_ops batadv_netlink_ops[] = { .dumpit = batadv_algo_dump, }, { - .cmd = BATADV_CMD_GET_HARDIFS, - .flags = GENL_ADMIN_PERM, + .cmd = BATADV_CMD_GET_HARDIF, + /* can be retrieved by unprivileged users */ .policy = batadv_netlink_policy, - .dumpit = batadv_netlink_dump_hardifs, + .dumpit = batadv_netlink_dump_hardif, + .doit = batadv_netlink_get_hardif, + .internal_flags = BATADV_FLAG_NEED_MESH | + BATADV_FLAG_NEED_HARDIF, }, { .cmd = BATADV_CMD_GET_TRANSTABLE_LOCAL, @@ -630,7 +1432,37 @@ static const struct genl_ops batadv_netlink_ops[] = { .policy = batadv_netlink_policy, .dumpit = batadv_mcast_flags_dump, }, - + { + .cmd = BATADV_CMD_SET_MESH, + .flags = GENL_ADMIN_PERM, + .policy = batadv_netlink_policy, + .doit = batadv_netlink_set_mesh, + .internal_flags = BATADV_FLAG_NEED_MESH, + }, + { + .cmd = BATADV_CMD_SET_HARDIF, + .flags = GENL_ADMIN_PERM, + .policy = batadv_netlink_policy, + .doit = batadv_netlink_set_hardif, + .internal_flags = BATADV_FLAG_NEED_MESH | + BATADV_FLAG_NEED_HARDIF, + }, + { + .cmd = BATADV_CMD_GET_VLAN, + /* can be retrieved by unprivileged users */ + .policy = batadv_netlink_policy, + .doit = batadv_netlink_get_vlan, + .internal_flags = BATADV_FLAG_NEED_MESH | + BATADV_FLAG_NEED_VLAN, + }, + { + .cmd = BATADV_CMD_SET_VLAN, + .flags = GENL_ADMIN_PERM, + .policy = batadv_netlink_policy, + .doit = batadv_netlink_set_vlan, + .internal_flags = BATADV_FLAG_NEED_MESH | + BATADV_FLAG_NEED_VLAN, + }, }; struct genl_family batadv_netlink_family __ro_after_init = { @@ -639,6 +1471,8 @@ struct genl_family batadv_netlink_family __ro_after_init = { .version = 1, .maxattr = BATADV_ATTR_MAX, .netnsok = true, + .pre_doit = batadv_pre_doit, + .post_doit = batadv_post_doit, .module = THIS_MODULE, .ops = batadv_netlink_ops, .n_ops = ARRAY_SIZE(batadv_netlink_ops), diff --git a/net/batman-adv/netlink.h b/net/batman-adv/netlink.h index 571d9a5ae7aa..7273368544fc 100644 --- a/net/batman-adv/netlink.h +++ b/net/batman-adv/netlink.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2016-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2016-2019 B.A.T.M.A.N. contributors: * * Matthias Schiffer * @@ -34,6 +34,12 @@ int batadv_netlink_tpmeter_notify(struct batadv_priv *bat_priv, const u8 *dst, u8 result, u32 test_time, u64 total_bytes, u32 cookie); +int batadv_netlink_notify_mesh(struct batadv_priv *bat_priv); +int batadv_netlink_notify_hardif(struct batadv_priv *bat_priv, + struct batadv_hard_iface *hard_iface); +int batadv_netlink_notify_vlan(struct batadv_priv *bat_priv, + struct batadv_softif_vlan *vlan); + extern struct genl_family batadv_netlink_family; #endif /* _NET_BATMAN_ADV_NETLINK_H_ */ diff --git a/net/batman-adv/network-coding.c b/net/batman-adv/network-coding.c index 34caf129a9bf..278762bd94c6 100644 --- a/net/batman-adv/network-coding.c +++ b/net/batman-adv/network-coding.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2012-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2012-2019 B.A.T.M.A.N. contributors: * * Martin Hundebøll, Jeppe Ledet-Pedersen * diff --git a/net/batman-adv/network-coding.h b/net/batman-adv/network-coding.h index 65c346812bc1..96ef0a511fc7 100644 --- a/net/batman-adv/network-coding.h +++ b/net/batman-adv/network-coding.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2012-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2012-2019 B.A.T.M.A.N. contributors: * * Martin Hundebøll, Jeppe Ledet-Pedersen * diff --git a/net/batman-adv/originator.c b/net/batman-adv/originator.c index 56a981af5c92..e5cdf89ef63c 100644 --- a/net/batman-adv/originator.c +++ b/net/batman-adv/originator.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2009-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2009-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/batman-adv/originator.h b/net/batman-adv/originator.h index a8b4c7b667ec..dca1e4a34ec6 100644 --- a/net/batman-adv/originator.h +++ b/net/batman-adv/originator.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/batman-adv/routing.c b/net/batman-adv/routing.c index cc3ed93a6d51..cae0e5dd0768 100644 --- a/net/batman-adv/routing.c +++ b/net/batman-adv/routing.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * @@ -1043,6 +1043,8 @@ int batadv_recv_unicast_packet(struct sk_buff *skb, hdr_size)) goto rx_success; + batadv_dat_snoop_incoming_dhcp_ack(bat_priv, skb, hdr_size); + batadv_interface_rx(recv_if->soft_iface, skb, hdr_size, orig_node); @@ -1278,6 +1280,8 @@ int batadv_recv_bcast_packet(struct sk_buff *skb, if (batadv_dat_snoop_incoming_arp_reply(bat_priv, skb, hdr_size)) goto rx_success; + batadv_dat_snoop_incoming_dhcp_ack(bat_priv, skb, hdr_size); + /* broadcast for me */ batadv_interface_rx(recv_if->soft_iface, skb, hdr_size, orig_node); diff --git a/net/batman-adv/routing.h b/net/batman-adv/routing.h index db54c2d9b8bf..0102d69d345c 100644 --- a/net/batman-adv/routing.h +++ b/net/batman-adv/routing.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/batman-adv/send.c b/net/batman-adv/send.c index 4a35f5c2f52b..66a8b3e44501 100644 --- a/net/batman-adv/send.c +++ b/net/batman-adv/send.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/batman-adv/send.h b/net/batman-adv/send.h index 64cce07b8fe6..1f6132922e60 100644 --- a/net/batman-adv/send.h +++ b/net/batman-adv/send.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/batman-adv/soft-interface.c b/net/batman-adv/soft-interface.c index 5db5a0a4c959..2e367230376b 100644 --- a/net/batman-adv/soft-interface.c +++ b/net/batman-adv/soft-interface.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * @@ -50,13 +50,13 @@ #include <linux/string.h> #include <linux/types.h> #include <uapi/linux/batadv_packet.h> +#include <uapi/linux/batman_adv.h> #include "bat_algo.h" #include "bridge_loop_avoidance.h" #include "debugfs.h" #include "distributed-arp-table.h" #include "gateway_client.h" -#include "gateway_common.h" #include "hard-interface.h" #include "multicast.h" #include "network-coding.h" @@ -212,6 +212,7 @@ static netdev_tx_t batadv_interface_tx(struct sk_buff *skb, enum batadv_forw_mode forw_mode; struct batadv_orig_node *mcast_single_orig = NULL; int network_offset = ETH_HLEN; + __be16 proto; if (atomic_read(&bat_priv->mesh_state) != BATADV_MESH_ACTIVE) goto dropped; @@ -221,14 +222,21 @@ static netdev_tx_t batadv_interface_tx(struct sk_buff *skb, netif_trans_update(soft_iface); vid = batadv_get_vid(skb, 0); + + skb_reset_mac_header(skb); ethhdr = eth_hdr(skb); - switch (ntohs(ethhdr->h_proto)) { + proto = ethhdr->h_proto; + + switch (ntohs(proto)) { case ETH_P_8021Q: + if (!pskb_may_pull(skb, sizeof(*vhdr))) + goto dropped; vhdr = vlan_eth_hdr(skb); + proto = vhdr->h_vlan_encapsulated_proto; /* drop batman-in-batman packets to prevent loops */ - if (vhdr->h_vlan_encapsulated_proto != htons(ETH_P_BATMAN)) { + if (proto != htons(ETH_P_BATMAN)) { network_offset += VLAN_HLEN; break; } @@ -256,6 +264,9 @@ static netdev_tx_t batadv_interface_tx(struct sk_buff *skb, goto dropped; } + /* Snoop address candidates from DHCPACKs for early DAT filling */ + batadv_dat_snoop_outgoing_dhcp_ack(bat_priv, skb, proto, vid); + /* don't accept stp packets. STP does not help in meshes. * better use the bridge loop avoidance ... * diff --git a/net/batman-adv/soft-interface.h b/net/batman-adv/soft-interface.h index daf87f07fadd..538bb661878c 100644 --- a/net/batman-adv/soft-interface.h +++ b/net/batman-adv/soft-interface.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner * diff --git a/net/batman-adv/sysfs.c b/net/batman-adv/sysfs.c index 09427fc6494a..0b4b3fb778a6 100644 --- a/net/batman-adv/sysfs.c +++ b/net/batman-adv/sysfs.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2010-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2010-2019 B.A.T.M.A.N. contributors: * * Marek Lindner * @@ -40,6 +40,7 @@ #include <linux/stringify.h> #include <linux/workqueue.h> #include <uapi/linux/batadv_packet.h> +#include <uapi/linux/batman_adv.h> #include "bridge_loop_avoidance.h" #include "distributed-arp-table.h" @@ -47,6 +48,7 @@ #include "gateway_common.h" #include "hard-interface.h" #include "log.h" +#include "netlink.h" #include "network-coding.h" #include "soft-interface.h" @@ -153,9 +155,14 @@ ssize_t batadv_store_##_name(struct kobject *kobj, \ { \ struct net_device *net_dev = batadv_kobj_to_netdev(kobj); \ struct batadv_priv *bat_priv = netdev_priv(net_dev); \ + ssize_t length; \ + \ + length = __batadv_store_bool_attr(buff, count, _post_func, attr,\ + &bat_priv->_name, net_dev); \ \ - return __batadv_store_bool_attr(buff, count, _post_func, attr, \ - &bat_priv->_name, net_dev); \ + batadv_netlink_notify_mesh(bat_priv); \ + \ + return length; \ } #define BATADV_ATTR_SIF_SHOW_BOOL(_name) \ @@ -185,11 +192,16 @@ ssize_t batadv_store_##_name(struct kobject *kobj, \ { \ struct net_device *net_dev = batadv_kobj_to_netdev(kobj); \ struct batadv_priv *bat_priv = netdev_priv(net_dev); \ + ssize_t length; \ \ - return __batadv_store_uint_attr(buff, count, _min, _max, \ - _post_func, attr, \ - &bat_priv->_var, net_dev, \ - NULL); \ + length = __batadv_store_uint_attr(buff, count, _min, _max, \ + _post_func, attr, \ + &bat_priv->_var, net_dev, \ + NULL); \ + \ + batadv_netlink_notify_mesh(bat_priv); \ + \ + return length; \ } #define BATADV_ATTR_SIF_SHOW_UINT(_name, _var) \ @@ -222,6 +234,11 @@ ssize_t batadv_store_vlan_##_name(struct kobject *kobj, \ attr, &vlan->_name, \ bat_priv->soft_iface); \ \ + if (vlan->vid) \ + batadv_netlink_notify_vlan(bat_priv, vlan); \ + else \ + batadv_netlink_notify_mesh(bat_priv); \ + \ batadv_softif_vlan_put(vlan); \ return res; \ } @@ -255,6 +272,7 @@ ssize_t batadv_store_##_name(struct kobject *kobj, \ { \ struct net_device *net_dev = batadv_kobj_to_netdev(kobj); \ struct batadv_hard_iface *hard_iface; \ + struct batadv_priv *bat_priv; \ ssize_t length; \ \ hard_iface = batadv_hardif_get_by_netdev(net_dev); \ @@ -267,6 +285,11 @@ ssize_t batadv_store_##_name(struct kobject *kobj, \ hard_iface->soft_iface, \ net_dev); \ \ + if (hard_iface->soft_iface) { \ + bat_priv = netdev_priv(hard_iface->soft_iface); \ + batadv_netlink_notify_hardif(bat_priv, hard_iface); \ + } \ + \ batadv_hardif_put(hard_iface); \ return length; \ } @@ -536,6 +559,9 @@ static ssize_t batadv_store_gw_mode(struct kobject *kobj, batadv_gw_check_client_stop(bat_priv); atomic_set(&bat_priv->gw.mode, (unsigned int)gw_mode_tmp); batadv_gw_tvlv_container_update(bat_priv); + + batadv_netlink_notify_mesh(bat_priv); + return count; } @@ -562,6 +588,7 @@ static ssize_t batadv_store_gw_sel_class(struct kobject *kobj, size_t count) { struct batadv_priv *bat_priv = batadv_kobj_to_batpriv(kobj); + ssize_t length; /* setting the GW selection class is allowed only if the routing * algorithm in use implements the GW API @@ -577,10 +604,14 @@ static ssize_t batadv_store_gw_sel_class(struct kobject *kobj, return bat_priv->algo_ops->gw.store_sel_class(bat_priv, buff, count); - return __batadv_store_uint_attr(buff, count, 1, BATADV_TQ_MAX_VALUE, - batadv_post_gw_reselect, attr, - &bat_priv->gw.sel_class, - bat_priv->soft_iface, NULL); + length = __batadv_store_uint_attr(buff, count, 1, BATADV_TQ_MAX_VALUE, + batadv_post_gw_reselect, attr, + &bat_priv->gw.sel_class, + bat_priv->soft_iface, NULL); + + batadv_netlink_notify_mesh(bat_priv); + + return length; } static ssize_t batadv_show_gw_bwidth(struct kobject *kobj, @@ -600,12 +631,18 @@ static ssize_t batadv_store_gw_bwidth(struct kobject *kobj, struct attribute *attr, char *buff, size_t count) { + struct batadv_priv *bat_priv = batadv_kobj_to_batpriv(kobj); struct net_device *net_dev = batadv_kobj_to_netdev(kobj); + ssize_t length; if (buff[count - 1] == '\n') buff[count - 1] = '\0'; - return batadv_gw_bandwidth_set(net_dev, buff, count); + length = batadv_gw_bandwidth_set(net_dev, buff, count); + + batadv_netlink_notify_mesh(bat_priv); + + return length; } /** @@ -673,6 +710,8 @@ static ssize_t batadv_store_isolation_mark(struct kobject *kobj, "New skb mark for extended isolation: %#.8x/%#.8x\n", bat_priv->isolation_mark, bat_priv->isolation_mark_mask); + batadv_netlink_notify_mesh(bat_priv); + return count; } @@ -1077,6 +1116,7 @@ static ssize_t batadv_store_throughput_override(struct kobject *kobj, struct attribute *attr, char *buff, size_t count) { + struct batadv_priv *bat_priv = batadv_kobj_to_batpriv(kobj); struct net_device *net_dev = batadv_kobj_to_netdev(kobj); struct batadv_hard_iface *hard_iface; u32 tp_override; @@ -1107,6 +1147,8 @@ static ssize_t batadv_store_throughput_override(struct kobject *kobj, atomic_set(&hard_iface->bat_v.throughput_override, tp_override); + batadv_netlink_notify_hardif(bat_priv, hard_iface); + out: batadv_hardif_put(hard_iface); return count; diff --git a/net/batman-adv/sysfs.h b/net/batman-adv/sysfs.h index c1e3fb69952d..705ffbe763f4 100644 --- a/net/batman-adv/sysfs.h +++ b/net/batman-adv/sysfs.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2010-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2010-2019 B.A.T.M.A.N. contributors: * * Marek Lindner * diff --git a/net/batman-adv/tp_meter.c b/net/batman-adv/tp_meter.c index 11520de96ccb..500109bbd551 100644 --- a/net/batman-adv/tp_meter.c +++ b/net/batman-adv/tp_meter.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2012-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2012-2019 B.A.T.M.A.N. contributors: * * Edo Monticelli, Antonio Quartulli * diff --git a/net/batman-adv/tp_meter.h b/net/batman-adv/tp_meter.h index 68e600974759..6b4d0f733896 100644 --- a/net/batman-adv/tp_meter.h +++ b/net/batman-adv/tp_meter.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2012-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2012-2019 B.A.T.M.A.N. contributors: * * Edo Monticelli, Antonio Quartulli * diff --git a/net/batman-adv/trace.c b/net/batman-adv/trace.c index 8e1024217cff..f77c917ed20d 100644 --- a/net/batman-adv/trace.c +++ b/net/batman-adv/trace.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2010-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2010-2019 B.A.T.M.A.N. contributors: * * Sven Eckelmann * diff --git a/net/batman-adv/trace.h b/net/batman-adv/trace.h index 104784be94d7..5e5579051400 100644 --- a/net/batman-adv/trace.h +++ b/net/batman-adv/trace.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2010-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2010-2019 B.A.T.M.A.N. contributors: * * Sven Eckelmann * diff --git a/net/batman-adv/translation-table.c b/net/batman-adv/translation-table.c index 8dcd4968cde7..f73d79139ae7 100644 --- a/net/batman-adv/translation-table.c +++ b/net/batman-adv/translation-table.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich, Antonio Quartulli * diff --git a/net/batman-adv/translation-table.h b/net/batman-adv/translation-table.h index 01b6c8eafaf9..61bca75e5911 100644 --- a/net/batman-adv/translation-table.h +++ b/net/batman-adv/translation-table.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich, Antonio Quartulli * diff --git a/net/batman-adv/tvlv.c b/net/batman-adv/tvlv.c index 40e69c9346d2..7e947b01919d 100644 --- a/net/batman-adv/tvlv.c +++ b/net/batman-adv/tvlv.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/batman-adv/tvlv.h b/net/batman-adv/tvlv.h index ef5867f49824..c0f033b1acb8 100644 --- a/net/batman-adv/tvlv.h +++ b/net/batman-adv/tvlv.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/batman-adv/types.h b/net/batman-adv/types.h index cbe17da36fcb..a21b34ed6548 100644 --- a/net/batman-adv/types.h +++ b/net/batman-adv/types.h @@ -1,5 +1,5 @@ /* SPDX-License-Identifier: GPL-2.0 */ -/* Copyright (C) 2007-2018 B.A.T.M.A.N. contributors: +/* Copyright (C) 2007-2019 B.A.T.M.A.N. contributors: * * Marek Lindner, Simon Wunderlich * diff --git a/net/bluetooth/6lowpan.c b/net/bluetooth/6lowpan.c index 9d79c7de234a..a7cd23f00bde 100644 --- a/net/bluetooth/6lowpan.c +++ b/net/bluetooth/6lowpan.c @@ -1108,8 +1108,8 @@ static int lowpan_enable_get(void *data, u64 *val) return 0; } -DEFINE_SIMPLE_ATTRIBUTE(lowpan_enable_fops, lowpan_enable_get, - lowpan_enable_set, "%llu\n"); +DEFINE_DEBUGFS_ATTRIBUTE(lowpan_enable_fops, lowpan_enable_get, + lowpan_enable_set, "%llu\n"); static ssize_t lowpan_control_write(struct file *fp, const char __user *user_buffer, @@ -1278,9 +1278,10 @@ static struct notifier_block bt_6lowpan_dev_notifier = { static int __init bt_6lowpan_init(void) { - lowpan_enable_debugfs = debugfs_create_file("6lowpan_enable", 0644, - bt_debugfs, NULL, - &lowpan_enable_fops); + lowpan_enable_debugfs = debugfs_create_file_unsafe("6lowpan_enable", + 0644, bt_debugfs, + NULL, + &lowpan_enable_fops); lowpan_control_debugfs = debugfs_create_file("6lowpan_control", 0644, bt_debugfs, NULL, &lowpan_control_fops); diff --git a/net/bluetooth/a2mp.c b/net/bluetooth/a2mp.c index 58fc6333d412..5f918ea18b5a 100644 --- a/net/bluetooth/a2mp.c +++ b/net/bluetooth/a2mp.c @@ -174,7 +174,7 @@ static int a2mp_discover_req(struct amp_mgr *mgr, struct sk_buff *skb, num_ctrl++; } - len = num_ctrl * sizeof(struct a2mp_cl) + sizeof(*rsp); + len = struct_size(rsp, cl, num_ctrl); rsp = kmalloc(len, GFP_ATOMIC); if (!rsp) { read_unlock(&hci_dev_list_lock); diff --git a/net/bluetooth/af_bluetooth.c b/net/bluetooth/af_bluetooth.c index deacc52d7ff1..8d12198eaa94 100644 --- a/net/bluetooth/af_bluetooth.c +++ b/net/bluetooth/af_bluetooth.c @@ -154,15 +154,25 @@ void bt_sock_unlink(struct bt_sock_list *l, struct sock *sk) } EXPORT_SYMBOL(bt_sock_unlink); -void bt_accept_enqueue(struct sock *parent, struct sock *sk) +void bt_accept_enqueue(struct sock *parent, struct sock *sk, bool bh) { BT_DBG("parent %p, sk %p", parent, sk); sock_hold(sk); - lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + + if (bh) + bh_lock_sock_nested(sk); + else + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + list_add_tail(&bt_sk(sk)->accept_q, &bt_sk(parent)->accept_q); bt_sk(sk)->parent = parent; - release_sock(sk); + + if (bh) + bh_unlock_sock(sk); + else + release_sock(sk); + parent->sk_ack_backlog++; } EXPORT_SYMBOL(bt_accept_enqueue); diff --git a/net/bluetooth/hci_core.c b/net/bluetooth/hci_core.c index 7352fe85674b..d6b2540ba7f8 100644 --- a/net/bluetooth/hci_core.c +++ b/net/bluetooth/hci_core.c @@ -30,6 +30,7 @@ #include <linux/rfkill.h> #include <linux/debugfs.h> #include <linux/crypto.h> +#include <linux/property.h> #include <asm/unaligned.h> #include <net/bluetooth/bluetooth.h> @@ -1355,6 +1356,32 @@ done: return err; } +/** + * hci_dev_get_bd_addr_from_property - Get the Bluetooth Device Address + * (BD_ADDR) for a HCI device from + * a firmware node property. + * @hdev: The HCI device + * + * Search the firmware node for 'local-bd-address'. + * + * All-zero BD addresses are rejected, because those could be properties + * that exist in the firmware tables, but were not updated by the firmware. For + * example, the DTS could define 'local-bd-address', with zero BD addresses. + */ +static void hci_dev_get_bd_addr_from_property(struct hci_dev *hdev) +{ + struct fwnode_handle *fwnode = dev_fwnode(hdev->dev.parent); + bdaddr_t ba; + int ret; + + ret = fwnode_property_read_u8_array(fwnode, "local-bd-address", + (u8 *)&ba, sizeof(ba)); + if (ret < 0 || !bacmp(&ba, BDADDR_ANY)) + return; + + bacpy(&hdev->public_addr, &ba); +} + static int hci_dev_do_open(struct hci_dev *hdev) { int ret = 0; @@ -1422,6 +1449,22 @@ static int hci_dev_do_open(struct hci_dev *hdev) if (hdev->setup) ret = hdev->setup(hdev); + if (ret) + goto setup_failed; + + if (test_bit(HCI_QUIRK_USE_BDADDR_PROPERTY, &hdev->quirks)) { + if (!bacmp(&hdev->public_addr, BDADDR_ANY)) + hci_dev_get_bd_addr_from_property(hdev); + + if (bacmp(&hdev->public_addr, BDADDR_ANY) && + hdev->set_bdaddr) + ret = hdev->set_bdaddr(hdev, + &hdev->public_addr); + else + ret = -EADDRNOTAVAIL; + } + +setup_failed: /* The transport driver can set these quirks before * creating the HCI device or in its setup callback. * @@ -2578,6 +2621,9 @@ static void hci_cmd_timeout(struct work_struct *work) bt_dev_err(hdev, "command tx timeout"); } + if (hdev->cmd_timeout) + hdev->cmd_timeout(hdev); + atomic_set(&hdev->cmd_cnt, 1); queue_work(hdev->workqueue, &hdev->cmd_work); } @@ -3401,7 +3447,7 @@ EXPORT_SYMBOL(hci_resume_dev); /* Reset HCI device */ int hci_reset_dev(struct hci_dev *hdev) { - const u8 hw_err[] = { HCI_EV_HARDWARE_ERROR, 0x01, 0x00 }; + static const u8 hw_err[] = { HCI_EV_HARDWARE_ERROR, 0x01, 0x00 }; struct sk_buff *skb; skb = bt_skb_alloc(3, GFP_ATOMIC); diff --git a/net/bluetooth/hci_event.c b/net/bluetooth/hci_event.c index ac2826ce162b..609fd6871c5a 100644 --- a/net/bluetooth/hci_event.c +++ b/net/bluetooth/hci_event.c @@ -3556,8 +3556,8 @@ static void hci_num_comp_pkts_evt(struct hci_dev *hdev, struct sk_buff *skb) return; } - if (skb->len < sizeof(*ev) || skb->len < sizeof(*ev) + - ev->num_hndl * sizeof(struct hci_comp_pkts_info)) { + if (skb->len < sizeof(*ev) || + skb->len < struct_size(ev, handles, ev->num_hndl)) { BT_DBG("%s bad parameters", hdev->name); return; } @@ -3644,8 +3644,8 @@ static void hci_num_comp_blocks_evt(struct hci_dev *hdev, struct sk_buff *skb) return; } - if (skb->len < sizeof(*ev) || skb->len < sizeof(*ev) + - ev->num_hndl * sizeof(struct hci_comp_blocks_info)) { + if (skb->len < sizeof(*ev) || + skb->len < struct_size(ev, handles, ev->num_hndl)) { BT_DBG("%s bad parameters", hdev->name); return; } diff --git a/net/bluetooth/hci_sock.c b/net/bluetooth/hci_sock.c index 1506e1632394..d32077b28433 100644 --- a/net/bluetooth/hci_sock.c +++ b/net/bluetooth/hci_sock.c @@ -831,8 +831,6 @@ static int hci_sock_release(struct socket *sock) if (!sk) return 0; - hdev = hci_pi(sk)->hdev; - switch (hci_pi(sk)->channel) { case HCI_CHANNEL_MONITOR: atomic_dec(&monitor_promisc); @@ -854,6 +852,7 @@ static int hci_sock_release(struct socket *sock) bt_sock_unlink(&hci_sk_list, sk); + hdev = hci_pi(sk)->hdev; if (hdev) { if (hci_pi(sk)->channel == HCI_CHANNEL_USER) { /* When releasing a user channel exclusive access, @@ -1383,9 +1382,9 @@ static void hci_sock_cmsg(struct sock *sk, struct msghdr *msg, if (mask & HCI_CMSG_TSTAMP) { #ifdef CONFIG_COMPAT - struct compat_timeval ctv; + struct old_timeval32 ctv; #endif - struct timeval tv; + struct __kernel_old_timeval tv; void *data; int len; diff --git a/net/bluetooth/l2cap_core.c b/net/bluetooth/l2cap_core.c index 2a7fb517d460..f17e393b43b4 100644 --- a/net/bluetooth/l2cap_core.c +++ b/net/bluetooth/l2cap_core.c @@ -3337,16 +3337,22 @@ static int l2cap_parse_conf_req(struct l2cap_chan *chan, void *data, size_t data while (len >= L2CAP_CONF_OPT_SIZE) { len -= l2cap_get_conf_opt(&req, &type, &olen, &val); + if (len < 0) + break; hint = type & L2CAP_CONF_HINT; type &= L2CAP_CONF_MASK; switch (type) { case L2CAP_CONF_MTU: + if (olen != 2) + break; mtu = val; break; case L2CAP_CONF_FLUSH_TO: + if (olen != 2) + break; chan->flush_to = val; break; @@ -3354,26 +3360,30 @@ static int l2cap_parse_conf_req(struct l2cap_chan *chan, void *data, size_t data break; case L2CAP_CONF_RFC: - if (olen == sizeof(rfc)) - memcpy(&rfc, (void *) val, olen); + if (olen != sizeof(rfc)) + break; + memcpy(&rfc, (void *) val, olen); break; case L2CAP_CONF_FCS: + if (olen != 1) + break; if (val == L2CAP_FCS_NONE) set_bit(CONF_RECV_NO_FCS, &chan->conf_state); break; case L2CAP_CONF_EFS: - if (olen == sizeof(efs)) { - remote_efs = 1; - memcpy(&efs, (void *) val, olen); - } + if (olen != sizeof(efs)) + break; + remote_efs = 1; + memcpy(&efs, (void *) val, olen); break; case L2CAP_CONF_EWS: + if (olen != 2) + break; if (!(chan->conn->local_fixed_chan & L2CAP_FC_A2MP)) return -ECONNREFUSED; - set_bit(FLAG_EXT_CTRL, &chan->flags); set_bit(CONF_EWS_RECV, &chan->conf_state); chan->tx_win_max = L2CAP_DEFAULT_EXT_WINDOW; @@ -3383,7 +3393,6 @@ static int l2cap_parse_conf_req(struct l2cap_chan *chan, void *data, size_t data default: if (hint) break; - result = L2CAP_CONF_UNKNOWN; *((u8 *) ptr++) = type; break; @@ -3548,58 +3557,65 @@ static int l2cap_parse_conf_rsp(struct l2cap_chan *chan, void *rsp, int len, while (len >= L2CAP_CONF_OPT_SIZE) { len -= l2cap_get_conf_opt(&rsp, &type, &olen, &val); + if (len < 0) + break; switch (type) { case L2CAP_CONF_MTU: + if (olen != 2) + break; if (val < L2CAP_DEFAULT_MIN_MTU) { *result = L2CAP_CONF_UNACCEPT; chan->imtu = L2CAP_DEFAULT_MIN_MTU; } else chan->imtu = val; - l2cap_add_conf_opt(&ptr, L2CAP_CONF_MTU, 2, chan->imtu, endptr - ptr); + l2cap_add_conf_opt(&ptr, L2CAP_CONF_MTU, 2, chan->imtu, + endptr - ptr); break; case L2CAP_CONF_FLUSH_TO: + if (olen != 2) + break; chan->flush_to = val; - l2cap_add_conf_opt(&ptr, L2CAP_CONF_FLUSH_TO, - 2, chan->flush_to, endptr - ptr); + l2cap_add_conf_opt(&ptr, L2CAP_CONF_FLUSH_TO, 2, + chan->flush_to, endptr - ptr); break; case L2CAP_CONF_RFC: - if (olen == sizeof(rfc)) - memcpy(&rfc, (void *)val, olen); - + if (olen != sizeof(rfc)) + break; + memcpy(&rfc, (void *)val, olen); if (test_bit(CONF_STATE2_DEVICE, &chan->conf_state) && rfc.mode != chan->mode) return -ECONNREFUSED; - chan->fcs = 0; - - l2cap_add_conf_opt(&ptr, L2CAP_CONF_RFC, - sizeof(rfc), (unsigned long) &rfc, endptr - ptr); + l2cap_add_conf_opt(&ptr, L2CAP_CONF_RFC, sizeof(rfc), + (unsigned long) &rfc, endptr - ptr); break; case L2CAP_CONF_EWS: + if (olen != 2) + break; chan->ack_win = min_t(u16, val, chan->ack_win); l2cap_add_conf_opt(&ptr, L2CAP_CONF_EWS, 2, chan->tx_win, endptr - ptr); break; case L2CAP_CONF_EFS: - if (olen == sizeof(efs)) { - memcpy(&efs, (void *)val, olen); - - if (chan->local_stype != L2CAP_SERV_NOTRAFIC && - efs.stype != L2CAP_SERV_NOTRAFIC && - efs.stype != chan->local_stype) - return -ECONNREFUSED; - - l2cap_add_conf_opt(&ptr, L2CAP_CONF_EFS, sizeof(efs), - (unsigned long) &efs, endptr - ptr); - } + if (olen != sizeof(efs)) + break; + memcpy(&efs, (void *)val, olen); + if (chan->local_stype != L2CAP_SERV_NOTRAFIC && + efs.stype != L2CAP_SERV_NOTRAFIC && + efs.stype != chan->local_stype) + return -ECONNREFUSED; + l2cap_add_conf_opt(&ptr, L2CAP_CONF_EFS, sizeof(efs), + (unsigned long) &efs, endptr - ptr); break; case L2CAP_CONF_FCS: + if (olen != 1) + break; if (*result == L2CAP_CONF_PENDING) if (val == L2CAP_FCS_NONE) set_bit(CONF_RECV_NO_FCS, @@ -3728,13 +3744,18 @@ static void l2cap_conf_rfc_get(struct l2cap_chan *chan, void *rsp, int len) while (len >= L2CAP_CONF_OPT_SIZE) { len -= l2cap_get_conf_opt(&rsp, &type, &olen, &val); + if (len < 0) + break; switch (type) { case L2CAP_CONF_RFC: - if (olen == sizeof(rfc)) - memcpy(&rfc, (void *)val, olen); + if (olen != sizeof(rfc)) + break; + memcpy(&rfc, (void *)val, olen); break; case L2CAP_CONF_EWS: + if (olen != 2) + break; txwin_ext = val; break; } @@ -4244,6 +4265,7 @@ static inline int l2cap_config_rsp(struct l2cap_conn *conn, goto done; break; } + /* fall through */ default: l2cap_chan_set_err(chan, ECONNRESET); diff --git a/net/bluetooth/l2cap_sock.c b/net/bluetooth/l2cap_sock.c index 686bdc6b35b0..a3a2cd55e23a 100644 --- a/net/bluetooth/l2cap_sock.c +++ b/net/bluetooth/l2cap_sock.c @@ -1252,7 +1252,7 @@ static struct l2cap_chan *l2cap_sock_new_connection_cb(struct l2cap_chan *chan) l2cap_sock_init(sk, parent); - bt_accept_enqueue(parent, sk); + bt_accept_enqueue(parent, sk, false); release_sock(parent); diff --git a/net/bluetooth/mgmt.c b/net/bluetooth/mgmt.c index ccce954f8146..2457f408d17d 100644 --- a/net/bluetooth/mgmt.c +++ b/net/bluetooth/mgmt.c @@ -474,7 +474,6 @@ static int read_ext_index_list(struct sock *sk, struct hci_dev *hdev, { struct mgmt_rp_read_ext_index_list *rp; struct hci_dev *d; - size_t rp_len; u16 count; int err; @@ -488,8 +487,7 @@ static int read_ext_index_list(struct sock *sk, struct hci_dev *hdev, count++; } - rp_len = sizeof(*rp) + (sizeof(rp->entry[0]) * count); - rp = kmalloc(rp_len, GFP_ATOMIC); + rp = kmalloc(struct_size(rp, entry, count), GFP_ATOMIC); if (!rp) { read_unlock(&hci_dev_list_lock); return -ENOMEM; @@ -525,7 +523,6 @@ static int read_ext_index_list(struct sock *sk, struct hci_dev *hdev, } rp->num_controllers = cpu_to_le16(count); - rp_len = sizeof(*rp) + (sizeof(rp->entry[0]) * count); read_unlock(&hci_dev_list_lock); @@ -538,7 +535,8 @@ static int read_ext_index_list(struct sock *sk, struct hci_dev *hdev, hci_sock_clear_flag(sk, HCI_MGMT_UNCONF_INDEX_EVENTS); err = mgmt_cmd_complete(sk, MGMT_INDEX_NONE, - MGMT_OP_READ_EXT_INDEX_LIST, 0, rp, rp_len); + MGMT_OP_READ_EXT_INDEX_LIST, 0, rp, + struct_size(rp, entry, count)); kfree(rp); @@ -551,7 +549,8 @@ static bool is_configured(struct hci_dev *hdev) !hci_dev_test_flag(hdev, HCI_EXT_CONFIGURED)) return false; - if (test_bit(HCI_QUIRK_INVALID_BDADDR, &hdev->quirks) && + if ((test_bit(HCI_QUIRK_INVALID_BDADDR, &hdev->quirks) || + test_bit(HCI_QUIRK_USE_BDADDR_PROPERTY, &hdev->quirks)) && !bacmp(&hdev->public_addr, BDADDR_ANY)) return false; @@ -566,7 +565,8 @@ static __le32 get_missing_options(struct hci_dev *hdev) !hci_dev_test_flag(hdev, HCI_EXT_CONFIGURED)) options |= MGMT_OPTION_EXTERNAL_CONFIG; - if (test_bit(HCI_QUIRK_INVALID_BDADDR, &hdev->quirks) && + if ((test_bit(HCI_QUIRK_INVALID_BDADDR, &hdev->quirks) || + test_bit(HCI_QUIRK_USE_BDADDR_PROPERTY, &hdev->quirks)) && !bacmp(&hdev->public_addr, BDADDR_ANY)) options |= MGMT_OPTION_PUBLIC_ADDRESS; diff --git a/net/bluetooth/rfcomm/core.c b/net/bluetooth/rfcomm/core.c index 1a635df80643..3a9e9d9670be 100644 --- a/net/bluetooth/rfcomm/core.c +++ b/net/bluetooth/rfcomm/core.c @@ -483,6 +483,7 @@ static int __rfcomm_dlc_close(struct rfcomm_dlc *d, int err) /* if closing a dlc in a session that hasn't been started, * just close and unlink the dlc */ + /* fall through */ default: rfcomm_dlc_clear_timer(d); diff --git a/net/bluetooth/rfcomm/sock.c b/net/bluetooth/rfcomm/sock.c index aa0db1d1bd9b..b1f49fcc0478 100644 --- a/net/bluetooth/rfcomm/sock.c +++ b/net/bluetooth/rfcomm/sock.c @@ -988,7 +988,7 @@ int rfcomm_connect_ind(struct rfcomm_session *s, u8 channel, struct rfcomm_dlc * rfcomm_pi(sk)->channel = channel; sk->sk_state = BT_CONFIG; - bt_accept_enqueue(parent, sk); + bt_accept_enqueue(parent, sk, true); /* Accept connection and return socket DLC */ *d = rfcomm_pi(sk)->dlc; diff --git a/net/bluetooth/sco.c b/net/bluetooth/sco.c index 529b38996d8b..9a580999ca57 100644 --- a/net/bluetooth/sco.c +++ b/net/bluetooth/sco.c @@ -193,7 +193,7 @@ static void __sco_chan_add(struct sco_conn *conn, struct sock *sk, conn->sk = sk; if (parent) - bt_accept_enqueue(parent, sk); + bt_accept_enqueue(parent, sk, true); } static int sco_chan_add(struct sco_conn *conn, struct sock *sk, diff --git a/net/bpf/test_run.c b/net/bpf/test_run.c index fa2644d276ef..fab142b796ef 100644 --- a/net/bpf/test_run.c +++ b/net/bpf/test_run.c @@ -13,27 +13,13 @@ #include <net/sock.h> #include <net/tcp.h> -static __always_inline u32 bpf_test_run_one(struct bpf_prog *prog, void *ctx, - struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE]) +static int bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat, + u32 *retval, u32 *time) { - u32 ret; - - preempt_disable(); - rcu_read_lock(); - bpf_cgroup_storage_set(storage); - ret = BPF_PROG_RUN(prog, ctx); - rcu_read_unlock(); - preempt_enable(); - - return ret; -} - -static int bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat, u32 *ret, - u32 *time) -{ - struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE] = { 0 }; + struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE] = { NULL }; enum bpf_cgroup_storage_type stype; u64 time_start, time_spent = 0; + int ret = 0; u32 i; for_each_cgroup_storage_type(stype) { @@ -48,25 +34,42 @@ static int bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat, u32 *ret, if (!repeat) repeat = 1; + + rcu_read_lock(); + preempt_disable(); time_start = ktime_get_ns(); for (i = 0; i < repeat; i++) { - *ret = bpf_test_run_one(prog, ctx, storage); + bpf_cgroup_storage_set(storage); + *retval = BPF_PROG_RUN(prog, ctx); + + if (signal_pending(current)) { + ret = -EINTR; + break; + } + if (need_resched()) { - if (signal_pending(current)) - break; time_spent += ktime_get_ns() - time_start; + preempt_enable(); + rcu_read_unlock(); + cond_resched(); + + rcu_read_lock(); + preempt_disable(); time_start = ktime_get_ns(); } } time_spent += ktime_get_ns() - time_start; + preempt_enable(); + rcu_read_unlock(); + do_div(time_spent, repeat); *time = time_spent > U32_MAX ? U32_MAX : (u32)time_spent; for_each_cgroup_storage_type(stype) bpf_cgroup_storage_free(storage[stype]); - return 0; + return ret; } static int bpf_test_finish(const union bpf_attr *kattr, @@ -240,3 +243,99 @@ out: kfree(data); return ret; } + +int bpf_prog_test_run_flow_dissector(struct bpf_prog *prog, + const union bpf_attr *kattr, + union bpf_attr __user *uattr) +{ + u32 size = kattr->test.data_size_in; + u32 repeat = kattr->test.repeat; + struct bpf_flow_keys flow_keys; + u64 time_start, time_spent = 0; + struct bpf_skb_data_end *cb; + u32 retval, duration; + struct sk_buff *skb; + struct sock *sk; + void *data; + int ret; + u32 i; + + if (prog->type != BPF_PROG_TYPE_FLOW_DISSECTOR) + return -EINVAL; + + data = bpf_test_init(kattr, size, NET_SKB_PAD + NET_IP_ALIGN, + SKB_DATA_ALIGN(sizeof(struct skb_shared_info))); + if (IS_ERR(data)) + return PTR_ERR(data); + + sk = kzalloc(sizeof(*sk), GFP_USER); + if (!sk) { + kfree(data); + return -ENOMEM; + } + sock_net_set(sk, current->nsproxy->net_ns); + sock_init_data(NULL, sk); + + skb = build_skb(data, 0); + if (!skb) { + kfree(data); + kfree(sk); + return -ENOMEM; + } + skb->sk = sk; + + skb_reserve(skb, NET_SKB_PAD + NET_IP_ALIGN); + __skb_put(skb, size); + skb->protocol = eth_type_trans(skb, + current->nsproxy->net_ns->loopback_dev); + skb_reset_network_header(skb); + + cb = (struct bpf_skb_data_end *)skb->cb; + cb->qdisc_cb.flow_keys = &flow_keys; + + if (!repeat) + repeat = 1; + + rcu_read_lock(); + preempt_disable(); + time_start = ktime_get_ns(); + for (i = 0; i < repeat; i++) { + retval = __skb_flow_bpf_dissect(prog, skb, + &flow_keys_dissector, + &flow_keys); + + if (signal_pending(current)) { + preempt_enable(); + rcu_read_unlock(); + + ret = -EINTR; + goto out; + } + + if (need_resched()) { + time_spent += ktime_get_ns() - time_start; + preempt_enable(); + rcu_read_unlock(); + + cond_resched(); + + rcu_read_lock(); + preempt_disable(); + time_start = ktime_get_ns(); + } + } + time_spent += ktime_get_ns() - time_start; + preempt_enable(); + rcu_read_unlock(); + + do_div(time_spent, repeat); + duration = time_spent > U32_MAX ? U32_MAX : (u32)time_spent; + + ret = bpf_test_finish(kattr, uattr, &flow_keys, sizeof(flow_keys), + retval, duration); + +out: + kfree_skb(skb); + kfree(sk); + return ret; +} diff --git a/net/bpfilter/Makefile b/net/bpfilter/Makefile index 0947ee7f70d5..854395fb98cd 100644 --- a/net/bpfilter/Makefile +++ b/net/bpfilter/Makefile @@ -5,7 +5,7 @@ hostprogs-y := bpfilter_umh bpfilter_umh-objs := main.o -KBUILD_HOSTCFLAGS += -I. -Itools/include/ -Itools/include/uapi +KBUILD_HOSTCFLAGS += -Itools/include/ -Itools/include/uapi HOSTCC := $(CC) ifeq ($(CONFIG_BPFILTER_UMH), y) diff --git a/net/bpfilter/main.c b/net/bpfilter/main.c index 1317f108df8a..61ce8454a88e 100644 --- a/net/bpfilter/main.c +++ b/net/bpfilter/main.c @@ -6,7 +6,7 @@ #include <sys/socket.h> #include <fcntl.h> #include <unistd.h> -#include "include/uapi/linux/bpf.h" +#include "../../include/uapi/linux/bpf.h" #include <asm/unistd.h> #include "msgfmt.h" diff --git a/net/bridge/br_fdb.c b/net/bridge/br_fdb.c index 9e14767500ea..00573cc46c98 100644 --- a/net/bridge/br_fdb.c +++ b/net/bridge/br_fdb.c @@ -915,7 +915,8 @@ static int __br_fdb_add(struct ndmsg *ndm, struct net_bridge *br, /* Add new permanent fdb entry with RTM_NEWNEIGH */ int br_fdb_add(struct ndmsg *ndm, struct nlattr *tb[], struct net_device *dev, - const unsigned char *addr, u16 vid, u16 nlh_flags) + const unsigned char *addr, u16 vid, u16 nlh_flags, + struct netlink_ext_ack *extack) { struct net_bridge_vlan_group *vg; struct net_bridge_port *p = NULL; diff --git a/net/bridge/br_multicast.c b/net/bridge/br_multicast.c index 3aeff0895669..a0e369179f6d 100644 --- a/net/bridge/br_multicast.c +++ b/net/bridge/br_multicast.c @@ -14,6 +14,7 @@ #include <linux/export.h> #include <linux/if_ether.h> #include <linux/igmp.h> +#include <linux/in.h> #include <linux/jhash.h> #include <linux/kernel.h> #include <linux/log2.h> @@ -29,6 +30,7 @@ #include <net/ip.h> #include <net/switchdev.h> #if IS_ENABLED(CONFIG_IPV6) +#include <linux/icmpv6.h> #include <net/ipv6.h> #include <net/mld.h> #include <net/ip6_checksum.h> @@ -938,7 +940,7 @@ static int br_ip4_multicast_igmp3_report(struct net_bridge *br, for (i = 0; i < num; i++) { len += sizeof(*grec); - if (!pskb_may_pull(skb, len)) + if (!ip_mc_may_pull(skb, len)) return -EINVAL; grec = (void *)(skb->data + len - sizeof(*grec)); @@ -946,7 +948,7 @@ static int br_ip4_multicast_igmp3_report(struct net_bridge *br, type = grec->grec_type; len += ntohs(grec->grec_nsrcs) * 4; - if (!pskb_may_pull(skb, len)) + if (!ip_mc_may_pull(skb, len)) return -EINVAL; /* We treat this as an IGMPv2 report for now. */ @@ -985,15 +987,17 @@ static int br_ip6_multicast_mld2_report(struct net_bridge *br, struct sk_buff *skb, u16 vid) { + unsigned int nsrcs_offset; const unsigned char *src; struct icmp6hdr *icmp6h; struct mld2_grec *grec; + unsigned int grec_len; int i; int len; int num; int err = 0; - if (!pskb_may_pull(skb, sizeof(*icmp6h))) + if (!ipv6_mc_may_pull(skb, sizeof(*icmp6h))) return -EINVAL; icmp6h = icmp6_hdr(skb); @@ -1003,21 +1007,24 @@ static int br_ip6_multicast_mld2_report(struct net_bridge *br, for (i = 0; i < num; i++) { __be16 *nsrcs, _nsrcs; - nsrcs = skb_header_pointer(skb, - len + offsetof(struct mld2_grec, - grec_nsrcs), + nsrcs_offset = len + offsetof(struct mld2_grec, grec_nsrcs); + + if (skb_transport_offset(skb) + ipv6_transport_len(skb) < + nsrcs_offset + sizeof(_nsrcs)) + return -EINVAL; + + nsrcs = skb_header_pointer(skb, nsrcs_offset, sizeof(_nsrcs), &_nsrcs); if (!nsrcs) return -EINVAL; - if (!pskb_may_pull(skb, - len + sizeof(*grec) + - sizeof(struct in6_addr) * ntohs(*nsrcs))) + grec_len = struct_size(grec, grec_src, ntohs(*nsrcs)); + + if (!ipv6_mc_may_pull(skb, len + grec_len)) return -EINVAL; grec = (struct mld2_grec *)(skb->data + len); - len += sizeof(*grec) + - sizeof(struct in6_addr) * ntohs(*nsrcs); + len += grec_len; /* We treat these as MLDv1 reports for now. */ switch (grec->grec_type) { @@ -1204,14 +1211,7 @@ static void br_multicast_query_received(struct net_bridge *br, return; br_multicast_update_query_timer(br, query, max_delay); - - /* Based on RFC4541, section 2.1.1 IGMP Forwarding Rules, - * the arrival port for IGMP Queries where the source address - * is 0.0.0.0 should not be added to router port list. - */ - if ((saddr->proto == htons(ETH_P_IP) && saddr->u.ip4) || - saddr->proto == htons(ETH_P_IPV6)) - br_multicast_mark_router(br, port); + br_multicast_mark_router(br, port); } static void br_ip4_multicast_query(struct net_bridge *br, @@ -1219,6 +1219,7 @@ static void br_ip4_multicast_query(struct net_bridge *br, struct sk_buff *skb, u16 vid) { + unsigned int transport_len = ip_transport_len(skb); const struct iphdr *iph = ip_hdr(skb); struct igmphdr *ih = igmp_hdr(skb); struct net_bridge_mdb_entry *mp; @@ -1228,7 +1229,6 @@ static void br_ip4_multicast_query(struct net_bridge *br, struct br_ip saddr; unsigned long max_delay; unsigned long now = jiffies; - unsigned int offset = skb_transport_offset(skb); __be32 group; spin_lock(&br->multicast_lock); @@ -1238,14 +1238,14 @@ static void br_ip4_multicast_query(struct net_bridge *br, group = ih->group; - if (skb->len == offset + sizeof(*ih)) { + if (transport_len == sizeof(*ih)) { max_delay = ih->code * (HZ / IGMP_TIMER_SCALE); if (!max_delay) { max_delay = 10 * HZ; group = 0; } - } else if (skb->len >= offset + sizeof(*ih3)) { + } else if (transport_len >= sizeof(*ih3)) { ih3 = igmpv3_query_hdr(skb); if (ih3->nsrcs) goto out; @@ -1296,6 +1296,7 @@ static int br_ip6_multicast_query(struct net_bridge *br, struct sk_buff *skb, u16 vid) { + unsigned int transport_len = ipv6_transport_len(skb); const struct ipv6hdr *ip6h = ipv6_hdr(skb); struct mld_msg *mld; struct net_bridge_mdb_entry *mp; @@ -1315,7 +1316,7 @@ static int br_ip6_multicast_query(struct net_bridge *br, (port && port->state == BR_STATE_DISABLED)) goto out; - if (skb->len == offset + sizeof(*mld)) { + if (transport_len == sizeof(*mld)) { if (!pskb_may_pull(skb, offset + sizeof(*mld))) { err = -EINVAL; goto out; @@ -1576,17 +1577,29 @@ static void br_multicast_pim(struct net_bridge *br, br_multicast_mark_router(br, port); } +static int br_ip4_multicast_mrd_rcv(struct net_bridge *br, + struct net_bridge_port *port, + struct sk_buff *skb) +{ + if (ip_hdr(skb)->protocol != IPPROTO_IGMP || + igmp_hdr(skb)->type != IGMP_MRDISC_ADV) + return -ENOMSG; + + br_multicast_mark_router(br, port); + + return 0; +} + static int br_multicast_ipv4_rcv(struct net_bridge *br, struct net_bridge_port *port, struct sk_buff *skb, u16 vid) { - struct sk_buff *skb_trimmed = NULL; const unsigned char *src; struct igmphdr *ih; int err; - err = ip_mc_check_igmp(skb, &skb_trimmed); + err = ip_mc_check_igmp(skb); if (err == -ENOMSG) { if (!ipv4_is_local_multicast(ip_hdr(skb)->daddr)) { @@ -1594,7 +1607,10 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br, } else if (pim_ipv4_all_pim_routers(ip_hdr(skb)->daddr)) { if (ip_hdr(skb)->protocol == IPPROTO_PIM) br_multicast_pim(br, port, skb); + } else if (ipv4_is_all_snoopers(ip_hdr(skb)->daddr)) { + br_ip4_multicast_mrd_rcv(br, port, skb); } + return 0; } else if (err < 0) { br_multicast_err_count(br, port, skb->protocol); @@ -1612,19 +1628,16 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br, err = br_ip4_multicast_add_group(br, port, ih->group, vid, src); break; case IGMPV3_HOST_MEMBERSHIP_REPORT: - err = br_ip4_multicast_igmp3_report(br, port, skb_trimmed, vid); + err = br_ip4_multicast_igmp3_report(br, port, skb, vid); break; case IGMP_HOST_MEMBERSHIP_QUERY: - br_ip4_multicast_query(br, port, skb_trimmed, vid); + br_ip4_multicast_query(br, port, skb, vid); break; case IGMP_HOST_LEAVE_MESSAGE: br_ip4_multicast_leave_group(br, port, ih->group, vid, src); break; } - if (skb_trimmed && skb_trimmed != skb) - kfree_skb(skb_trimmed); - br_multicast_count(br, port, skb, BR_INPUT_SKB_CB(skb)->igmp, BR_MCAST_DIR_RX); @@ -1632,21 +1645,51 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br, } #if IS_ENABLED(CONFIG_IPV6) +static int br_ip6_multicast_mrd_rcv(struct net_bridge *br, + struct net_bridge_port *port, + struct sk_buff *skb) +{ + int ret; + + if (ipv6_hdr(skb)->nexthdr != IPPROTO_ICMPV6) + return -ENOMSG; + + ret = ipv6_mc_check_icmpv6(skb); + if (ret < 0) + return ret; + + if (icmp6_hdr(skb)->icmp6_type != ICMPV6_MRDISC_ADV) + return -ENOMSG; + + br_multicast_mark_router(br, port); + + return 0; +} + static int br_multicast_ipv6_rcv(struct net_bridge *br, struct net_bridge_port *port, struct sk_buff *skb, u16 vid) { - struct sk_buff *skb_trimmed = NULL; const unsigned char *src; struct mld_msg *mld; int err; - err = ipv6_mc_check_mld(skb, &skb_trimmed); + err = ipv6_mc_check_mld(skb); if (err == -ENOMSG) { if (!ipv6_addr_is_ll_all_nodes(&ipv6_hdr(skb)->daddr)) BR_INPUT_SKB_CB(skb)->mrouters_only = 1; + + if (ipv6_addr_is_all_snoopers(&ipv6_hdr(skb)->daddr)) { + err = br_ip6_multicast_mrd_rcv(br, port, skb); + + if (err < 0 && err != -ENOMSG) { + br_multicast_err_count(br, port, skb->protocol); + return err; + } + } + return 0; } else if (err < 0) { br_multicast_err_count(br, port, skb->protocol); @@ -1664,10 +1707,10 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br, src); break; case ICMPV6_MLD2_REPORT: - err = br_ip6_multicast_mld2_report(br, port, skb_trimmed, vid); + err = br_ip6_multicast_mld2_report(br, port, skb, vid); break; case ICMPV6_MGM_QUERY: - err = br_ip6_multicast_query(br, port, skb_trimmed, vid); + err = br_ip6_multicast_query(br, port, skb, vid); break; case ICMPV6_MGM_REDUCTION: src = eth_hdr(skb)->h_source; @@ -1675,9 +1718,6 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br, break; } - if (skb_trimmed && skb_trimmed != skb) - kfree_skb(skb_trimmed); - br_multicast_count(br, port, skb, BR_INPUT_SKB_CB(skb)->igmp, BR_MCAST_DIR_RX); @@ -1781,6 +1821,68 @@ void br_multicast_init(struct net_bridge *br) INIT_HLIST_HEAD(&br->mdb_list); } +static void br_ip4_multicast_join_snoopers(struct net_bridge *br) +{ + struct in_device *in_dev = in_dev_get(br->dev); + + if (!in_dev) + return; + + __ip_mc_inc_group(in_dev, htonl(INADDR_ALLSNOOPERS_GROUP), GFP_ATOMIC); + in_dev_put(in_dev); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void br_ip6_multicast_join_snoopers(struct net_bridge *br) +{ + struct in6_addr addr; + + ipv6_addr_set(&addr, htonl(0xff020000), 0, 0, htonl(0x6a)); + ipv6_dev_mc_inc(br->dev, &addr); +} +#else +static inline void br_ip6_multicast_join_snoopers(struct net_bridge *br) +{ +} +#endif + +static void br_multicast_join_snoopers(struct net_bridge *br) +{ + br_ip4_multicast_join_snoopers(br); + br_ip6_multicast_join_snoopers(br); +} + +static void br_ip4_multicast_leave_snoopers(struct net_bridge *br) +{ + struct in_device *in_dev = in_dev_get(br->dev); + + if (WARN_ON(!in_dev)) + return; + + __ip_mc_dec_group(in_dev, htonl(INADDR_ALLSNOOPERS_GROUP), GFP_ATOMIC); + in_dev_put(in_dev); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void br_ip6_multicast_leave_snoopers(struct net_bridge *br) +{ + struct in6_addr addr; + + ipv6_addr_set(&addr, htonl(0xff020000), 0, 0, htonl(0x6a)); + ipv6_dev_mc_dec(br->dev, &addr); +} +#else +static inline void br_ip6_multicast_leave_snoopers(struct net_bridge *br) +{ +} +#endif + +static void br_multicast_leave_snoopers(struct net_bridge *br) +{ + br_ip4_multicast_leave_snoopers(br); + br_ip6_multicast_leave_snoopers(br); +} + static void __br_multicast_open(struct net_bridge *br, struct bridge_mcast_own_query *query) { @@ -1794,6 +1896,9 @@ static void __br_multicast_open(struct net_bridge *br, void br_multicast_open(struct net_bridge *br) { + if (br_opt_get(br, BROPT_MULTICAST_ENABLED)) + br_multicast_join_snoopers(br); + __br_multicast_open(br, &br->ip4_own_query); #if IS_ENABLED(CONFIG_IPV6) __br_multicast_open(br, &br->ip6_own_query); @@ -1809,6 +1914,9 @@ void br_multicast_stop(struct net_bridge *br) del_timer_sync(&br->ip6_other_query.timer); del_timer_sync(&br->ip6_own_query.timer); #endif + + if (br_opt_get(br, BROPT_MULTICAST_ENABLED)) + br_multicast_leave_snoopers(br); } void br_multicast_dev_del(struct net_bridge *br) @@ -1944,8 +2052,10 @@ int br_multicast_toggle(struct net_bridge *br, unsigned long val) br_mc_disabled_update(br->dev, val); br_opt_toggle(br, BROPT_MULTICAST_ENABLED, !!val); - if (!br_opt_get(br, BROPT_MULTICAST_ENABLED)) + if (!br_opt_get(br, BROPT_MULTICAST_ENABLED)) { + br_multicast_leave_snoopers(br); goto unlock; + } if (!netif_running(br->dev)) goto unlock; diff --git a/net/bridge/br_netfilter_hooks.c b/net/bridge/br_netfilter_hooks.c index c93c35bb73dd..9d34de68571b 100644 --- a/net/bridge/br_netfilter_hooks.c +++ b/net/bridge/br_netfilter_hooks.c @@ -831,7 +831,8 @@ static unsigned int ip_sabotage_in(void *priv, struct nf_bridge_info *nf_bridge = nf_bridge_info_get(skb); if (nf_bridge && !nf_bridge->in_prerouting && - !netif_is_l3_master(skb->dev)) { + !netif_is_l3_master(skb->dev) && + !netif_is_l3_slave(skb->dev)) { state->okfn(state->net, state->sk, skb); return NF_STOLEN; } @@ -881,11 +882,6 @@ static const struct nf_br_ops br_ops = { .br_dev_xmit_hook = br_nf_dev_xmit, }; -void br_netfilter_enable(void) -{ -} -EXPORT_SYMBOL_GPL(br_netfilter_enable); - /* For br_nf_post_routing, we need (prio = NF_BR_PRI_LAST), because * br_dev_queue_push_xmit is called afterwards */ static const struct nf_hook_ops br_nf_ops[] = { diff --git a/net/bridge/br_private.h b/net/bridge/br_private.h index eabf8bf28a3f..00deef7fc1f3 100644 --- a/net/bridge/br_private.h +++ b/net/bridge/br_private.h @@ -573,7 +573,8 @@ void br_fdb_update(struct net_bridge *br, struct net_bridge_port *source, int br_fdb_delete(struct ndmsg *ndm, struct nlattr *tb[], struct net_device *dev, const unsigned char *addr, u16 vid); int br_fdb_add(struct ndmsg *nlh, struct nlattr *tb[], struct net_device *dev, - const unsigned char *addr, u16 vid, u16 nlh_flags); + const unsigned char *addr, u16 vid, u16 nlh_flags, + struct netlink_ext_ack *extack); int br_fdb_dump(struct sk_buff *skb, struct netlink_callback *cb, struct net_device *dev, struct net_device *fdev, int *idx); int br_fdb_get(struct sk_buff *skb, struct nlattr *tb[], struct net_device *dev, diff --git a/net/bridge/br_switchdev.c b/net/bridge/br_switchdev.c index 035ff59d9cbd..921310d3cbae 100644 --- a/net/bridge/br_switchdev.c +++ b/net/bridge/br_switchdev.c @@ -14,7 +14,7 @@ static int br_switchdev_mark_get(struct net_bridge *br, struct net_device *dev) /* dev is yet to be added to the port list. */ list_for_each_entry(p, &br->port_list, list) { - if (switchdev_port_same_parent_id(dev, p->dev)) + if (netdev_port_same_parent_id(dev, p->dev)) return p->offload_fwd_mark; } @@ -23,15 +23,12 @@ static int br_switchdev_mark_get(struct net_bridge *br, struct net_device *dev) int nbp_switchdev_mark_set(struct net_bridge_port *p) { - struct switchdev_attr attr = { - .orig_dev = p->dev, - .id = SWITCHDEV_ATTR_ID_PORT_PARENT_ID, - }; + struct netdev_phys_item_id ppid = { }; int err; ASSERT_RTNL(); - err = switchdev_port_attr_get(p->dev, &attr); + err = dev_get_port_parent_id(p->dev, &ppid, true); if (err) { if (err == -EOPNOTSUPP) return 0; @@ -67,21 +64,25 @@ int br_switchdev_set_port_flag(struct net_bridge_port *p, { struct switchdev_attr attr = { .orig_dev = p->dev, - .id = SWITCHDEV_ATTR_ID_PORT_BRIDGE_FLAGS_SUPPORT, + .id = SWITCHDEV_ATTR_ID_PORT_PRE_BRIDGE_FLAGS, + .u.brport_flags = mask, + }; + struct switchdev_notifier_port_attr_info info = { + .attr = &attr, }; int err; if (mask & ~BR_PORT_FLAGS_HW_OFFLOAD) return 0; - err = switchdev_port_attr_get(p->dev, &attr); + /* We run from atomic context here */ + err = call_switchdev_notifiers(SWITCHDEV_PORT_ATTR_SET, p->dev, + &info.info, NULL); + err = notifier_to_errno(err); if (err == -EOPNOTSUPP) return 0; - if (err) - return err; - /* Check if specific bridge flag attribute offload is supported */ - if (!(attr.u.brport_flags_support & mask)) { + if (err) { br_warn(p->br, "bridge flag offload is not supported %u(%s)\n", (unsigned int)p->port_no, p->dev->name); return -EOPNOTSUPP; @@ -90,6 +91,7 @@ int br_switchdev_set_port_flag(struct net_bridge_port *p, attr.id = SWITCHDEV_ATTR_ID_PORT_BRIDGE_FLAGS; attr.flags = SWITCHDEV_F_DEFER; attr.u.brport_flags = flags; + err = switchdev_port_attr_set(p->dev, &attr); if (err) { br_warn(p->br, "error setting offload flag on port %u(%s)\n", @@ -113,7 +115,7 @@ br_switchdev_fdb_call_notifiers(bool adding, const unsigned char *mac, info.added_by_user = added_by_user; info.offloaded = offloaded; notifier_type = adding ? SWITCHDEV_FDB_ADD_TO_DEVICE : SWITCHDEV_FDB_DEL_TO_DEVICE; - call_switchdev_notifiers(notifier_type, dev, &info.info); + call_switchdev_notifiers(notifier_type, dev, &info.info, NULL); } void diff --git a/net/bridge/netfilter/ebtables.c b/net/bridge/netfilter/ebtables.c index 5e55cef0cec3..eb15891f8b9f 100644 --- a/net/bridge/netfilter/ebtables.c +++ b/net/bridge/netfilter/ebtables.c @@ -31,10 +31,6 @@ /* needed for logical [in,out]-dev filtering */ #include "../br_private.h" -#define BUGPRINT(format, args...) printk("kernel msg: ebtables bug: please "\ - "report to author: "format, ## args) -/* #define BUGPRINT(format, args...) */ - /* Each cpu has its own set of counters, so there is no need for write_lock in * the softirq * For reading or updating the counters, the user context needs to @@ -385,7 +381,7 @@ ebt_check_match(struct ebt_entry_match *m, struct xt_mtchk_param *par, par->match = match; par->matchinfo = m->data; ret = xt_check_match(par, m->match_size, - e->ethproto, e->invflags & EBT_IPROTO); + ntohs(e->ethproto), e->invflags & EBT_IPROTO); if (ret < 0) { module_put(match->me); return ret; @@ -422,7 +418,7 @@ ebt_check_watcher(struct ebt_entry_watcher *w, struct xt_tgchk_param *par, par->target = watcher; par->targinfo = w->data; ret = xt_check_target(par, w->watcher_size, - e->ethproto, e->invflags & EBT_IPROTO); + ntohs(e->ethproto), e->invflags & EBT_IPROTO); if (ret < 0) { module_put(watcher->me); return ret; @@ -466,8 +462,6 @@ static int ebt_verify_pointers(const struct ebt_replace *repl, /* we make userspace set this right, * so there is no misunderstanding */ - BUGPRINT("EBT_ENTRY_OR_ENTRIES shouldn't be set " - "in distinguisher\n"); return -EINVAL; } if (i != NF_BR_NUMHOOKS) @@ -485,18 +479,14 @@ static int ebt_verify_pointers(const struct ebt_replace *repl, offset += e->next_offset; } } - if (offset != limit) { - BUGPRINT("entries_size too small\n"); + if (offset != limit) return -EINVAL; - } /* check if all valid hooks have a chain */ for (i = 0; i < NF_BR_NUMHOOKS; i++) { if (!newinfo->hook_entry[i] && - (valid_hooks & (1 << i))) { - BUGPRINT("Valid hook without chain\n"); + (valid_hooks & (1 << i))) return -EINVAL; - } } return 0; } @@ -523,26 +513,20 @@ ebt_check_entry_size_and_hooks(const struct ebt_entry *e, /* this checks if the previous chain has as many entries * as it said it has */ - if (*n != *cnt) { - BUGPRINT("nentries does not equal the nr of entries " - "in the chain\n"); + if (*n != *cnt) return -EINVAL; - } + if (((struct ebt_entries *)e)->policy != EBT_DROP && ((struct ebt_entries *)e)->policy != EBT_ACCEPT) { /* only RETURN from udc */ if (i != NF_BR_NUMHOOKS || - ((struct ebt_entries *)e)->policy != EBT_RETURN) { - BUGPRINT("bad policy\n"); + ((struct ebt_entries *)e)->policy != EBT_RETURN) return -EINVAL; - } } if (i == NF_BR_NUMHOOKS) /* it's a user defined chain */ (*udc_cnt)++; - if (((struct ebt_entries *)e)->counter_offset != *totalcnt) { - BUGPRINT("counter_offset != totalcnt"); + if (((struct ebt_entries *)e)->counter_offset != *totalcnt) return -EINVAL; - } *n = ((struct ebt_entries *)e)->nentries; *cnt = 0; return 0; @@ -550,15 +534,13 @@ ebt_check_entry_size_and_hooks(const struct ebt_entry *e, /* a plain old entry, heh */ if (sizeof(struct ebt_entry) > e->watchers_offset || e->watchers_offset > e->target_offset || - e->target_offset >= e->next_offset) { - BUGPRINT("entry offsets not in right order\n"); + e->target_offset >= e->next_offset) return -EINVAL; - } + /* this is not checked anywhere else */ - if (e->next_offset - e->target_offset < sizeof(struct ebt_entry_target)) { - BUGPRINT("target size too small\n"); + if (e->next_offset - e->target_offset < sizeof(struct ebt_entry_target)) return -EINVAL; - } + (*cnt)++; (*totalcnt)++; return 0; @@ -678,18 +660,15 @@ ebt_check_entry(struct ebt_entry *e, struct net *net, if (e->bitmask == 0) return 0; - if (e->bitmask & ~EBT_F_MASK) { - BUGPRINT("Unknown flag for bitmask\n"); + if (e->bitmask & ~EBT_F_MASK) return -EINVAL; - } - if (e->invflags & ~EBT_INV_MASK) { - BUGPRINT("Unknown flag for inv bitmask\n"); + + if (e->invflags & ~EBT_INV_MASK) return -EINVAL; - } - if ((e->bitmask & EBT_NOPROTO) && (e->bitmask & EBT_802_3)) { - BUGPRINT("NOPROTO & 802_3 not allowed\n"); + + if ((e->bitmask & EBT_NOPROTO) && (e->bitmask & EBT_802_3)) return -EINVAL; - } + /* what hook do we belong to? */ for (i = 0; i < NF_BR_NUMHOOKS; i++) { if (!newinfo->hook_entry[i]) @@ -748,13 +727,11 @@ ebt_check_entry(struct ebt_entry *e, struct net *net, t->u.target = target; if (t->u.target == &ebt_standard_target) { if (gap < sizeof(struct ebt_standard_target)) { - BUGPRINT("Standard target size too big\n"); ret = -EFAULT; goto cleanup_watchers; } if (((struct ebt_standard_target *)t)->verdict < -NUM_STANDARD_TARGETS) { - BUGPRINT("Invalid standard target\n"); ret = -EFAULT; goto cleanup_watchers; } @@ -767,7 +744,7 @@ ebt_check_entry(struct ebt_entry *e, struct net *net, tgpar.target = target; tgpar.targinfo = t->data; ret = xt_check_target(&tgpar, t->target_size, - e->ethproto, e->invflags & EBT_IPROTO); + ntohs(e->ethproto), e->invflags & EBT_IPROTO); if (ret < 0) { module_put(target->me); goto cleanup_watchers; @@ -813,10 +790,9 @@ static int check_chainloops(const struct ebt_entries *chain, struct ebt_cl_stack if (strcmp(t->u.name, EBT_STANDARD_TARGET)) goto letscontinue; if (e->target_offset + sizeof(struct ebt_standard_target) > - e->next_offset) { - BUGPRINT("Standard target size too big\n"); + e->next_offset) return -1; - } + verdict = ((struct ebt_standard_target *)t)->verdict; if (verdict >= 0) { /* jump to another chain */ struct ebt_entries *hlp2 = @@ -825,14 +801,12 @@ static int check_chainloops(const struct ebt_entries *chain, struct ebt_cl_stack if (hlp2 == cl_s[i].cs.chaininfo) break; /* bad destination or loop */ - if (i == udc_cnt) { - BUGPRINT("bad destination\n"); + if (i == udc_cnt) return -1; - } - if (cl_s[i].cs.n) { - BUGPRINT("loop\n"); + + if (cl_s[i].cs.n) return -1; - } + if (cl_s[i].hookmask & (1 << hooknr)) goto letscontinue; /* this can't be 0, so the loop test is correct */ @@ -865,24 +839,21 @@ static int translate_table(struct net *net, const char *name, i = 0; while (i < NF_BR_NUMHOOKS && !newinfo->hook_entry[i]) i++; - if (i == NF_BR_NUMHOOKS) { - BUGPRINT("No valid hooks specified\n"); + if (i == NF_BR_NUMHOOKS) return -EINVAL; - } - if (newinfo->hook_entry[i] != (struct ebt_entries *)newinfo->entries) { - BUGPRINT("Chains don't start at beginning\n"); + + if (newinfo->hook_entry[i] != (struct ebt_entries *)newinfo->entries) return -EINVAL; - } + /* make sure chains are ordered after each other in same order * as their corresponding hooks */ for (j = i + 1; j < NF_BR_NUMHOOKS; j++) { if (!newinfo->hook_entry[j]) continue; - if (newinfo->hook_entry[j] <= newinfo->hook_entry[i]) { - BUGPRINT("Hook order must be followed\n"); + if (newinfo->hook_entry[j] <= newinfo->hook_entry[i]) return -EINVAL; - } + i = j; } @@ -900,15 +871,11 @@ static int translate_table(struct net *net, const char *name, if (ret != 0) return ret; - if (i != j) { - BUGPRINT("nentries does not equal the nr of entries in the " - "(last) chain\n"); + if (i != j) return -EINVAL; - } - if (k != newinfo->nentries) { - BUGPRINT("Total nentries is wrong\n"); + + if (k != newinfo->nentries) return -EINVAL; - } /* get the location of the udc, put them in an array * while we're at it, allocate the chainstack @@ -942,7 +909,6 @@ static int translate_table(struct net *net, const char *name, ebt_get_udc_positions, newinfo, &i, cl_s); /* sanity check */ if (i != udc_cnt) { - BUGPRINT("i != udc_cnt\n"); vfree(cl_s); return -EFAULT; } @@ -1042,7 +1008,6 @@ static int do_replace_finish(struct net *net, struct ebt_replace *repl, goto free_unlock; if (repl->num_counters && repl->num_counters != t->private->nentries) { - BUGPRINT("Wrong nr. of counters requested\n"); ret = -EINVAL; goto free_unlock; } @@ -1118,15 +1083,12 @@ static int do_replace(struct net *net, const void __user *user, if (copy_from_user(&tmp, user, sizeof(tmp)) != 0) return -EFAULT; - if (len != sizeof(tmp) + tmp.entries_size) { - BUGPRINT("Wrong len argument\n"); + if (len != sizeof(tmp) + tmp.entries_size) return -EINVAL; - } - if (tmp.entries_size == 0) { - BUGPRINT("Entries_size never zero\n"); + if (tmp.entries_size == 0) return -EINVAL; - } + /* overflow check */ if (tmp.nentries >= ((INT_MAX - sizeof(struct ebt_table_info)) / NR_CPUS - SMP_CACHE_BYTES) / sizeof(struct ebt_counter)) @@ -1153,7 +1115,6 @@ static int do_replace(struct net *net, const void __user *user, } if (copy_from_user( newinfo->entries, tmp.entries, tmp.entries_size) != 0) { - BUGPRINT("Couldn't copy entries from userspace\n"); ret = -EFAULT; goto free_entries; } @@ -1194,10 +1155,8 @@ int ebt_register_table(struct net *net, const struct ebt_table *input_table, if (input_table == NULL || (repl = input_table->table) == NULL || repl->entries == NULL || repl->entries_size == 0 || - repl->counters != NULL || input_table->private != NULL) { - BUGPRINT("Bad table data for ebt_register_table!!!\n"); + repl->counters != NULL || input_table->private != NULL) return -EINVAL; - } /* Don't add one table to multiple lists. */ table = kmemdup(input_table, sizeof(struct ebt_table), GFP_KERNEL); @@ -1235,13 +1194,10 @@ int ebt_register_table(struct net *net, const struct ebt_table *input_table, ((char *)repl->hook_entry[i] - repl->entries); } ret = translate_table(net, repl->name, newinfo); - if (ret != 0) { - BUGPRINT("Translate_table failed\n"); + if (ret != 0) goto free_chainstack; - } if (table->check && table->check(newinfo, table->valid_hooks)) { - BUGPRINT("The table doesn't like its own initial data, lol\n"); ret = -EINVAL; goto free_chainstack; } @@ -1252,7 +1208,6 @@ int ebt_register_table(struct net *net, const struct ebt_table *input_table, list_for_each_entry(t, &net->xt.tables[NFPROTO_BRIDGE], list) { if (strcmp(t->name, table->name) == 0) { ret = -EEXIST; - BUGPRINT("Table name already exists\n"); goto free_unlock; } } @@ -1320,7 +1275,6 @@ static int do_update_counters(struct net *net, const char *name, goto free_tmp; if (num_counters != t->private->nentries) { - BUGPRINT("Wrong nr of counters\n"); ret = -EINVAL; goto unlock_mutex; } @@ -1447,10 +1401,8 @@ static int copy_counters_to_user(struct ebt_table *t, if (num_counters == 0) return 0; - if (num_counters != nentries) { - BUGPRINT("Num_counters wrong\n"); + if (num_counters != nentries) return -EINVAL; - } counterstmp = vmalloc(array_size(nentries, sizeof(*counterstmp))); if (!counterstmp) @@ -1496,15 +1448,11 @@ static int copy_everything_to_user(struct ebt_table *t, void __user *user, (tmp.num_counters ? nentries * sizeof(struct ebt_counter) : 0)) return -EINVAL; - if (tmp.nentries != nentries) { - BUGPRINT("Nentries wrong\n"); + if (tmp.nentries != nentries) return -EINVAL; - } - if (tmp.entries_size != entries_size) { - BUGPRINT("Wrong size\n"); + if (tmp.entries_size != entries_size) return -EINVAL; - } ret = copy_counters_to_user(t, oldcounters, tmp.counters, tmp.num_counters, nentries); @@ -1576,7 +1524,6 @@ static int do_ebt_get_ctl(struct sock *sk, int cmd, void __user *user, int *len) } mutex_unlock(&ebt_mutex); if (copy_to_user(user, &tmp, *len) != 0) { - BUGPRINT("c2u Didn't work\n"); ret = -EFAULT; break; } @@ -2293,9 +2240,12 @@ static int compat_do_replace(struct net *net, void __user *user, xt_compat_lock(NFPROTO_BRIDGE); - ret = xt_compat_init_offsets(NFPROTO_BRIDGE, tmp.nentries); - if (ret < 0) - goto out_unlock; + if (tmp.nentries) { + ret = xt_compat_init_offsets(NFPROTO_BRIDGE, tmp.nentries); + if (ret < 0) + goto out_unlock; + } + ret = compat_copy_entries(entries_tmp, tmp.entries_size, &state); if (ret < 0) goto out_unlock; diff --git a/net/bridge/netfilter/nft_reject_bridge.c b/net/bridge/netfilter/nft_reject_bridge.c index 419e8edf23ba..1b1856744c80 100644 --- a/net/bridge/netfilter/nft_reject_bridge.c +++ b/net/bridge/netfilter/nft_reject_bridge.c @@ -125,13 +125,10 @@ static void nft_reject_br_send_v4_unreach(struct net *net, if (pskb_trim_rcsum(oldskb, ntohs(ip_hdr(oldskb)->tot_len))) return; - if (ip_hdr(oldskb)->protocol == IPPROTO_TCP || - ip_hdr(oldskb)->protocol == IPPROTO_UDP) - proto = ip_hdr(oldskb)->protocol; - else - proto = 0; + proto = ip_hdr(oldskb)->protocol; if (!skb_csum_unnecessary(oldskb) && + nf_reject_verify_csum(proto) && nf_ip_checksum(oldskb, hook, ip_hdrlen(oldskb), proto)) return; @@ -234,6 +231,9 @@ static bool reject6_br_csum_ok(struct sk_buff *skb, int hook) if (thoff < 0 || thoff >= skb->len || (fo & htons(~0x7)) != 0) return false; + if (!nf_reject_verify_csum(proto)) + return true; + return nf_ip6_checksum(skb, hook, thoff, proto) == 0; } diff --git a/net/caif/cfpkt_skbuff.c b/net/caif/cfpkt_skbuff.c index 38c2b7a890dd..37ac5ca0ffdf 100644 --- a/net/caif/cfpkt_skbuff.c +++ b/net/caif/cfpkt_skbuff.c @@ -319,16 +319,12 @@ struct cfpkt *cfpkt_append(struct cfpkt *dstpkt, if (tmppkt == NULL) return NULL; tmp = pkt_to_skb(tmppkt); - skb_set_tail_pointer(tmp, dstlen); - tmp->len = dstlen; - memcpy(tmp->data, dst->data, dstlen); + skb_put_data(tmp, dst->data, dstlen); cfpkt_destroy(dstpkt); dst = tmp; } - memcpy(skb_tail_pointer(dst), add->data, skb_headlen(add)); + skb_put_data(dst, add->data, skb_headlen(add)); cfpkt_destroy(addpkt); - dst->tail += addlen; - dst->len += addlen; return skb_to_pkt(dst); } @@ -359,13 +355,11 @@ struct cfpkt *cfpkt_split(struct cfpkt *pkt, u16 pos) if (skb2 == NULL) return NULL; + skb_put_data(skb2, split, len2nd); + /* Reduce the length of the original packet */ - skb_set_tail_pointer(skb, pos); - skb->len = pos; + skb_trim(skb, pos); - memcpy(skb2->data, split, len2nd); - skb2->tail += len2nd; - skb2->len += len2nd; skb2->priority = skb->priority; return skb_to_pkt(skb2); } diff --git a/net/can/bcm.c b/net/can/bcm.c index 0af8f0db892a..79bb8afa9c0c 100644 --- a/net/can/bcm.c +++ b/net/can/bcm.c @@ -67,6 +67,9 @@ */ #define MAX_NFRAMES 256 +/* limit timers to 400 days for sending/timeouts */ +#define BCM_TIMER_SEC_MAX (400 * 24 * 60 * 60) + /* use of last_frames[index].flags */ #define RX_RECV 0x40 /* received data for this element */ #define RX_THR 0x80 /* element not been sent due to throttle feature */ @@ -140,6 +143,22 @@ static inline ktime_t bcm_timeval_to_ktime(struct bcm_timeval tv) return ktime_set(tv.tv_sec, tv.tv_usec * NSEC_PER_USEC); } +/* check limitations for timeval provided by user */ +static bool bcm_is_invalid_tv(struct bcm_msg_head *msg_head) +{ + if ((msg_head->ival1.tv_sec < 0) || + (msg_head->ival1.tv_sec > BCM_TIMER_SEC_MAX) || + (msg_head->ival1.tv_usec < 0) || + (msg_head->ival1.tv_usec >= USEC_PER_SEC) || + (msg_head->ival2.tv_sec < 0) || + (msg_head->ival2.tv_sec > BCM_TIMER_SEC_MAX) || + (msg_head->ival2.tv_usec < 0) || + (msg_head->ival2.tv_usec >= USEC_PER_SEC)) + return true; + + return false; +} + #define CFSIZ(flags) ((flags & CAN_FD_FRAME) ? CANFD_MTU : CAN_MTU) #define OPSIZ sizeof(struct bcm_op) #define MHSIZ sizeof(struct bcm_msg_head) @@ -873,6 +892,10 @@ static int bcm_tx_setup(struct bcm_msg_head *msg_head, struct msghdr *msg, if (msg_head->nframes < 1 || msg_head->nframes > MAX_NFRAMES) return -EINVAL; + /* check timeval limitations */ + if ((msg_head->flags & SETTIMER) && bcm_is_invalid_tv(msg_head)) + return -EINVAL; + /* check the given can_id */ op = bcm_find_op(&bo->tx_ops, msg_head, ifindex); if (op) { @@ -1053,6 +1076,10 @@ static int bcm_rx_setup(struct bcm_msg_head *msg_head, struct msghdr *msg, (!(msg_head->can_id & CAN_RTR_FLAG)))) return -EINVAL; + /* check timeval limitations */ + if ((msg_head->flags & SETTIMER) && bcm_is_invalid_tv(msg_head)) + return -EINVAL; + /* check the given can_id */ op = bcm_find_op(&bo->rx_ops, msg_head, ifindex); if (op) { diff --git a/net/ceph/messenger.c b/net/ceph/messenger.c index d5718284db57..7e71b0df1fbc 100644 --- a/net/ceph/messenger.c +++ b/net/ceph/messenger.c @@ -2058,6 +2058,8 @@ static int process_connect(struct ceph_connection *con) dout("process_connect on %p tag %d\n", con, (int)con->in_tag); if (con->auth) { + int len = le32_to_cpu(con->in_reply.authorizer_len); + /* * Any connection that defines ->get_authorizer() * should also define ->add_authorizer_challenge() and @@ -2067,8 +2069,7 @@ static int process_connect(struct ceph_connection *con) */ if (con->in_reply.tag == CEPH_MSGR_TAG_CHALLENGE_AUTHORIZER) { ret = con->ops->add_authorizer_challenge( - con, con->auth->authorizer_reply_buf, - le32_to_cpu(con->in_reply.authorizer_len)); + con, con->auth->authorizer_reply_buf, len); if (ret < 0) return ret; @@ -2078,10 +2079,12 @@ static int process_connect(struct ceph_connection *con) return 0; } - ret = con->ops->verify_authorizer_reply(con); - if (ret < 0) { - con->error_msg = "bad authorize reply"; - return ret; + if (len) { + ret = con->ops->verify_authorizer_reply(con); + if (ret < 0) { + con->error_msg = "bad authorize reply"; + return ret; + } } } @@ -3206,9 +3209,10 @@ void ceph_con_keepalive(struct ceph_connection *con) dout("con_keepalive %p\n", con); mutex_lock(&con->mutex); clear_standby(con); + con_flag_set(con, CON_FLAG_KEEPALIVE_PENDING); mutex_unlock(&con->mutex); - if (con_flag_test_and_set(con, CON_FLAG_KEEPALIVE_PENDING) == 0 && - con_flag_test_and_set(con, CON_FLAG_WRITE_PENDING) == 0) + + if (con_flag_test_and_set(con, CON_FLAG_WRITE_PENDING) == 0) queue_con(con); } EXPORT_SYMBOL(ceph_con_keepalive); diff --git a/net/ceph/osdmap.c b/net/ceph/osdmap.c index 98c0ff3d6441..48a31dc9161c 100644 --- a/net/ceph/osdmap.c +++ b/net/ceph/osdmap.c @@ -495,9 +495,8 @@ static struct crush_map *crush_decode(void *pbyval, void *end) / sizeof(struct crush_rule_step)) goto bad; #endif - r = c->rules[i] = kmalloc(sizeof(*r) + - yes*sizeof(struct crush_rule_step), - GFP_NOFS); + r = kmalloc(struct_size(r, steps, yes), GFP_NOFS); + c->rules[i] = r; if (r == NULL) goto badmem; dout(" rule %d is at %p\n", i, r); diff --git a/net/compat.c b/net/compat.c index 959d1c51826d..eeea5eb71639 100644 --- a/net/compat.c +++ b/net/compat.c @@ -209,8 +209,8 @@ int put_cmsg_compat(struct msghdr *kmsg, int level, int type, int len, void *dat { struct compat_cmsghdr __user *cm = (struct compat_cmsghdr __user *) kmsg->msg_control; struct compat_cmsghdr cmhdr; - struct compat_timeval ctv; - struct compat_timespec cts[3]; + struct old_timeval32 ctv; + struct old_timespec32 cts[3]; int cmlen; if (cm == NULL || kmsg->msg_controllen < sizeof(*cm)) { @@ -219,16 +219,16 @@ int put_cmsg_compat(struct msghdr *kmsg, int level, int type, int len, void *dat } if (!COMPAT_USE_64BIT_TIME) { - if (level == SOL_SOCKET && type == SCM_TIMESTAMP) { - struct timeval *tv = (struct timeval *)data; + if (level == SOL_SOCKET && type == SO_TIMESTAMP_OLD) { + struct __kernel_old_timeval *tv = (struct __kernel_old_timeval *)data; ctv.tv_sec = tv->tv_sec; ctv.tv_usec = tv->tv_usec; data = &ctv; len = sizeof(ctv); } if (level == SOL_SOCKET && - (type == SCM_TIMESTAMPNS || type == SCM_TIMESTAMPING)) { - int count = type == SCM_TIMESTAMPNS ? 1 : 3; + (type == SO_TIMESTAMPNS_OLD || type == SO_TIMESTAMPING_OLD)) { + int count = type == SO_TIMESTAMPNS_OLD ? 1 : 3; int i; struct timespec *ts = (struct timespec *)data; for (i = 0; i < count; i++) { @@ -348,28 +348,6 @@ static int do_set_attach_filter(struct socket *sock, int level, int optname, sizeof(struct sock_fprog)); } -static int do_set_sock_timeout(struct socket *sock, int level, - int optname, char __user *optval, unsigned int optlen) -{ - struct compat_timeval __user *up = (struct compat_timeval __user *)optval; - struct timeval ktime; - mm_segment_t old_fs; - int err; - - if (optlen < sizeof(*up)) - return -EINVAL; - if (!access_ok(up, sizeof(*up)) || - __get_user(ktime.tv_sec, &up->tv_sec) || - __get_user(ktime.tv_usec, &up->tv_usec)) - return -EFAULT; - old_fs = get_fs(); - set_fs(KERNEL_DS); - err = sock_setsockopt(sock, level, optname, (char *)&ktime, sizeof(ktime)); - set_fs(old_fs); - - return err; -} - static int compat_sock_setsockopt(struct socket *sock, int level, int optname, char __user *optval, unsigned int optlen) { @@ -377,10 +355,6 @@ static int compat_sock_setsockopt(struct socket *sock, int level, int optname, optname == SO_ATTACH_REUSEPORT_CBPF) return do_set_attach_filter(sock, level, optname, optval, optlen); - if (!COMPAT_USE_64BIT_TIME && - (optname == SO_RCVTIMEO || optname == SO_SNDTIMEO)) - return do_set_sock_timeout(sock, level, optname, optval, optlen); - return sock_setsockopt(sock, level, optname, optval, optlen); } @@ -388,8 +362,12 @@ static int __compat_sys_setsockopt(int fd, int level, int optname, char __user *optval, unsigned int optlen) { int err; - struct socket *sock = sockfd_lookup(fd, &err); + struct socket *sock; + + if (optlen > INT_MAX) + return -EINVAL; + sock = sockfd_lookup(fd, &err); if (sock) { err = security_socket_setsockopt(sock, level, optname); if (err) { @@ -417,44 +395,6 @@ COMPAT_SYSCALL_DEFINE5(setsockopt, int, fd, int, level, int, optname, return __compat_sys_setsockopt(fd, level, optname, optval, optlen); } -static int do_get_sock_timeout(struct socket *sock, int level, int optname, - char __user *optval, int __user *optlen) -{ - struct compat_timeval __user *up; - struct timeval ktime; - mm_segment_t old_fs; - int len, err; - - up = (struct compat_timeval __user *) optval; - if (get_user(len, optlen)) - return -EFAULT; - if (len < sizeof(*up)) - return -EINVAL; - len = sizeof(ktime); - old_fs = get_fs(); - set_fs(KERNEL_DS); - err = sock_getsockopt(sock, level, optname, (char *) &ktime, &len); - set_fs(old_fs); - - if (!err) { - if (put_user(sizeof(*up), optlen) || - !access_ok(up, sizeof(*up)) || - __put_user(ktime.tv_sec, &up->tv_sec) || - __put_user(ktime.tv_usec, &up->tv_usec)) - err = -EFAULT; - } - return err; -} - -static int compat_sock_getsockopt(struct socket *sock, int level, int optname, - char __user *optval, int __user *optlen) -{ - if (!COMPAT_USE_64BIT_TIME && - (optname == SO_RCVTIMEO || optname == SO_SNDTIMEO)) - return do_get_sock_timeout(sock, level, optname, optval, optlen); - return sock_getsockopt(sock, level, optname, optval, optlen); -} - int compat_sock_get_timestamp(struct sock *sk, struct timeval __user *userstamp) { struct compat_timeval __user *ctv; @@ -527,7 +467,7 @@ static int __compat_sys_getsockopt(int fd, int level, int optname, } if (level == SOL_SOCKET) - err = compat_sock_getsockopt(sock, level, + err = sock_getsockopt(sock, level, optname, optval, optlen); else if (sock->ops->compat_getsockopt) err = sock->ops->compat_getsockopt(sock, level, @@ -585,7 +525,7 @@ int compat_mc_setsockopt(struct sock *sock, int level, int optname, case MCAST_JOIN_GROUP: case MCAST_LEAVE_GROUP: { - struct compat_group_req __user *gr32 = (void *)optval; + struct compat_group_req __user *gr32 = (void __user *)optval; struct group_req __user *kgr = compat_alloc_user_space(sizeof(struct group_req)); u32 interface; @@ -606,7 +546,7 @@ int compat_mc_setsockopt(struct sock *sock, int level, int optname, case MCAST_BLOCK_SOURCE: case MCAST_UNBLOCK_SOURCE: { - struct compat_group_source_req __user *gsr32 = (void *)optval; + struct compat_group_source_req __user *gsr32 = (void __user *)optval; struct group_source_req __user *kgsr = compat_alloc_user_space( sizeof(struct group_source_req)); u32 interface; @@ -627,7 +567,7 @@ int compat_mc_setsockopt(struct sock *sock, int level, int optname, } case MCAST_MSFILTER: { - struct compat_group_filter __user *gf32 = (void *)optval; + struct compat_group_filter __user *gf32 = (void __user *)optval; struct group_filter __user *kgf; u32 interface, fmode, numsrc; @@ -665,7 +605,7 @@ int compat_mc_getsockopt(struct sock *sock, int level, int optname, char __user *optval, int __user *optlen, int (*getsockopt)(struct sock *, int, int, char __user *, int __user *)) { - struct compat_group_filter __user *gf32 = (void *)optval; + struct compat_group_filter __user *gf32 = (void __user *)optval; struct group_filter __user *kgf; int __user *koptlen; u32 interface, fmode, numsrc; @@ -822,7 +762,7 @@ COMPAT_SYSCALL_DEFINE5(recvmmsg_time64, int, fd, struct compat_mmsghdr __user *, } #ifdef CONFIG_COMPAT_32BIT_TIME -COMPAT_SYSCALL_DEFINE5(recvmmsg, int, fd, struct compat_mmsghdr __user *, mmsg, +COMPAT_SYSCALL_DEFINE5(recvmmsg_time32, int, fd, struct compat_mmsghdr __user *, mmsg, unsigned int, vlen, unsigned int, flags, struct old_timespec32 __user *, timeout) { diff --git a/net/core/Makefile b/net/core/Makefile index fccd31e0e7f7..f97d6254e564 100644 --- a/net/core/Makefile +++ b/net/core/Makefile @@ -11,7 +11,7 @@ obj-$(CONFIG_SYSCTL) += sysctl_net_core.o obj-y += dev.o ethtool.o dev_addr_lists.o dst.o netevent.o \ neighbour.o rtnetlink.o utils.o link_watch.o filter.o \ sock_diag.o dev_ioctl.o tso.o sock_reuseport.o \ - fib_notifier.o xdp.o + fib_notifier.o xdp.o flow_offload.o obj-y += net-sysfs.o obj-$(CONFIG_PAGE_POOL) += page_pool.o diff --git a/net/core/dev.c b/net/core/dev.c index 82f20022259d..2b67f2aa59dd 100644 --- a/net/core/dev.c +++ b/net/core/dev.c @@ -3421,7 +3421,7 @@ static void qdisc_pkt_len_init(struct sk_buff *skb) /* To get more precise estimation of bytes sent on wire, * we add to pkt_len the headers size of all segments */ - if (shinfo->gso_size) { + if (shinfo->gso_size && skb_transport_header_was_set(skb)) { unsigned int hdr_len; u16 gso_segs = shinfo->gso_segs; @@ -7878,6 +7878,63 @@ int dev_get_phys_port_name(struct net_device *dev, EXPORT_SYMBOL(dev_get_phys_port_name); /** + * dev_get_port_parent_id - Get the device's port parent identifier + * @dev: network device + * @ppid: pointer to a storage for the port's parent identifier + * @recurse: allow/disallow recursion to lower devices + * + * Get the devices's port parent identifier + */ +int dev_get_port_parent_id(struct net_device *dev, + struct netdev_phys_item_id *ppid, + bool recurse) +{ + const struct net_device_ops *ops = dev->netdev_ops; + struct netdev_phys_item_id first = { }; + struct net_device *lower_dev; + struct list_head *iter; + int err = -EOPNOTSUPP; + + if (ops->ndo_get_port_parent_id) + return ops->ndo_get_port_parent_id(dev, ppid); + + if (!recurse) + return err; + + netdev_for_each_lower_dev(dev, lower_dev, iter) { + err = dev_get_port_parent_id(lower_dev, ppid, recurse); + if (err) + break; + if (!first.id_len) + first = *ppid; + else if (memcmp(&first, ppid, sizeof(*ppid))) + return -ENODATA; + } + + return err; +} +EXPORT_SYMBOL(dev_get_port_parent_id); + +/** + * netdev_port_same_parent_id - Indicate if two network devices have + * the same port parent identifier + * @a: first network device + * @b: second network device + */ +bool netdev_port_same_parent_id(struct net_device *a, struct net_device *b) +{ + struct netdev_phys_item_id a_id = { }; + struct netdev_phys_item_id b_id = { }; + + if (dev_get_port_parent_id(a, &a_id, true) || + dev_get_port_parent_id(b, &b_id, true)) + return false; + + return netdev_phys_item_id_same(&a_id, &b_id); +} +EXPORT_SYMBOL(netdev_port_same_parent_id); + +/** * dev_change_proto_down - update protocol port state information * @dev: device * @proto_down: new value @@ -7897,6 +7954,25 @@ int dev_change_proto_down(struct net_device *dev, bool proto_down) } EXPORT_SYMBOL(dev_change_proto_down); +/** + * dev_change_proto_down_generic - generic implementation for + * ndo_change_proto_down that sets carrier according to + * proto_down. + * + * @dev: device + * @proto_down: new value + */ +int dev_change_proto_down_generic(struct net_device *dev, bool proto_down) +{ + if (proto_down) + netif_carrier_off(dev); + else + netif_carrier_on(dev); + dev->proto_down = proto_down; + return 0; +} +EXPORT_SYMBOL(dev_change_proto_down_generic); + u32 __dev_xdp_query(struct net_device *dev, bpf_op_t bpf_op, enum bpf_netdev_command cmd) { @@ -7976,35 +8052,41 @@ int dev_change_xdp_fd(struct net_device *dev, struct netlink_ext_ack *extack, enum bpf_netdev_command query; struct bpf_prog *prog = NULL; bpf_op_t bpf_op, bpf_chk; + bool offload; int err; ASSERT_RTNL(); - query = flags & XDP_FLAGS_HW_MODE ? XDP_QUERY_PROG_HW : XDP_QUERY_PROG; + offload = flags & XDP_FLAGS_HW_MODE; + query = offload ? XDP_QUERY_PROG_HW : XDP_QUERY_PROG; bpf_op = bpf_chk = ops->ndo_bpf; - if (!bpf_op && (flags & (XDP_FLAGS_DRV_MODE | XDP_FLAGS_HW_MODE))) + if (!bpf_op && (flags & (XDP_FLAGS_DRV_MODE | XDP_FLAGS_HW_MODE))) { + NL_SET_ERR_MSG(extack, "underlying driver does not support XDP in native mode"); return -EOPNOTSUPP; + } if (!bpf_op || (flags & XDP_FLAGS_SKB_MODE)) bpf_op = generic_xdp_install; if (bpf_op == bpf_chk) bpf_chk = generic_xdp_install; if (fd >= 0) { - if (__dev_xdp_query(dev, bpf_chk, XDP_QUERY_PROG) || - __dev_xdp_query(dev, bpf_chk, XDP_QUERY_PROG_HW)) + if (!offload && __dev_xdp_query(dev, bpf_chk, XDP_QUERY_PROG)) { + NL_SET_ERR_MSG(extack, "native and generic XDP can't be active at the same time"); return -EEXIST; + } if ((flags & XDP_FLAGS_UPDATE_IF_NOEXIST) && - __dev_xdp_query(dev, bpf_op, query)) + __dev_xdp_query(dev, bpf_op, query)) { + NL_SET_ERR_MSG(extack, "XDP program already attached"); return -EBUSY; + } prog = bpf_prog_get_type_dev(fd, BPF_PROG_TYPE_XDP, bpf_op == ops->ndo_bpf); if (IS_ERR(prog)) return PTR_ERR(prog); - if (!(flags & XDP_FLAGS_HW_MODE) && - bpf_prog_is_dev_bound(prog->aux)) { + if (!offload && bpf_prog_is_dev_bound(prog->aux)) { NL_SET_ERR_MSG(extack, "using device-bound program without HW_MODE flag is not supported"); bpf_prog_put(prog); return -EINVAL; @@ -8152,7 +8234,7 @@ static netdev_features_t netdev_sync_upper_features(struct net_device *lower, netdev_features_t feature; int feature_bit; - for_each_netdev_feature(&upper_disables, feature_bit) { + for_each_netdev_feature(upper_disables, feature_bit) { feature = __NETIF_F_BIT(feature_bit); if (!(upper->wanted_features & feature) && (features & feature)) { @@ -8172,7 +8254,7 @@ static void netdev_sync_lower_features(struct net_device *upper, netdev_features_t feature; int feature_bit; - for_each_netdev_feature(&upper_disables, feature_bit) { + for_each_netdev_feature(upper_disables, feature_bit) { feature = __NETIF_F_BIT(feature_bit); if (!(features & feature) && (lower->features & feature)) { netdev_dbg(upper, "Disabling feature %pNF on lower dev %s.\n", @@ -8712,6 +8794,9 @@ int init_dummy_netdev(struct net_device *dev) set_bit(__LINK_STATE_PRESENT, &dev->state); set_bit(__LINK_STATE_START, &dev->state); + /* napi_busy_loop stats accounting wants this */ + dev_net_set(dev, &init_net); + /* Note : We dont allocate pcpu_refcnt for dummy devices, * because users of this 'device' dont need to change * its refcount. diff --git a/net/core/devlink.c b/net/core/devlink.c index abb0da9d7b4b..78e22cea4cc7 100644 --- a/net/core/devlink.c +++ b/net/core/devlink.c @@ -81,6 +81,7 @@ struct devlink_dpipe_header devlink_dpipe_header_ipv6 = { EXPORT_SYMBOL(devlink_dpipe_header_ipv6); EXPORT_TRACEPOINT_SYMBOL_GPL(devlink_hwmsg); +EXPORT_TRACEPOINT_SYMBOL_GPL(devlink_hwerr); static LIST_HEAD(devlink_list); @@ -115,6 +116,8 @@ static struct devlink *devlink_get_from_attrs(struct net *net, busname = nla_data(attrs[DEVLINK_ATTR_BUS_NAME]); devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]); + lockdep_assert_held(&devlink_mutex); + list_for_each_entry(devlink, &devlink_list, list) { if (strcmp(devlink->dev->bus->name, busname) == 0 && strcmp(dev_name(devlink->dev), devname) == 0 && @@ -720,7 +723,7 @@ static int devlink_port_type_set(struct devlink *devlink, { int err; - if (devlink->ops && devlink->ops->port_type_set) { + if (devlink->ops->port_type_set) { if (port_type == DEVLINK_PORT_TYPE_NOTSET) return -EINVAL; if (port_type == devlink_port->type) @@ -757,7 +760,7 @@ static int devlink_port_split(struct devlink *devlink, u32 port_index, u32 count, struct netlink_ext_ack *extack) { - if (devlink->ops && devlink->ops->port_split) + if (devlink->ops->port_split) return devlink->ops->port_split(devlink, port_index, count, extack); return -EOPNOTSUPP; @@ -783,7 +786,7 @@ static int devlink_port_unsplit(struct devlink *devlink, u32 port_index, struct netlink_ext_ack *extack) { - if (devlink->ops && devlink->ops->port_unsplit) + if (devlink->ops->port_unsplit) return devlink->ops->port_unsplit(devlink, port_index, extack); return -EOPNOTSUPP; } @@ -932,6 +935,9 @@ static int devlink_nl_sb_pool_fill(struct sk_buff *msg, struct devlink *devlink, if (nla_put_u8(msg, DEVLINK_ATTR_SB_POOL_THRESHOLD_TYPE, pool_info.threshold_type)) goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_SB_POOL_CELL_SIZE, + pool_info.cell_size)) + goto nla_put_failure; genlmsg_end(msg, hdr); return 0; @@ -955,7 +961,7 @@ static int devlink_nl_cmd_sb_pool_get_doit(struct sk_buff *skb, if (err) return err; - if (!devlink->ops || !devlink->ops->sb_pool_get) + if (!devlink->ops->sb_pool_get) return -EOPNOTSUPP; msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); @@ -1011,7 +1017,7 @@ static int devlink_nl_cmd_sb_pool_get_dumpit(struct sk_buff *msg, mutex_lock(&devlink_mutex); list_for_each_entry(devlink, &devlink_list, list) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) || - !devlink->ops || !devlink->ops->sb_pool_get) + !devlink->ops->sb_pool_get) continue; mutex_lock(&devlink->lock); list_for_each_entry(devlink_sb, &devlink->sb_list, list) { @@ -1040,7 +1046,7 @@ static int devlink_sb_pool_set(struct devlink *devlink, unsigned int sb_index, { const struct devlink_ops *ops = devlink->ops; - if (ops && ops->sb_pool_set) + if (ops->sb_pool_set) return ops->sb_pool_set(devlink, sb_index, pool_index, size, threshold_type); return -EOPNOTSUPP; @@ -1145,7 +1151,7 @@ static int devlink_nl_cmd_sb_port_pool_get_doit(struct sk_buff *skb, if (err) return err; - if (!devlink->ops || !devlink->ops->sb_port_pool_get) + if (!devlink->ops->sb_port_pool_get) return -EOPNOTSUPP; msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); @@ -1207,7 +1213,7 @@ static int devlink_nl_cmd_sb_port_pool_get_dumpit(struct sk_buff *msg, mutex_lock(&devlink_mutex); list_for_each_entry(devlink, &devlink_list, list) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) || - !devlink->ops || !devlink->ops->sb_port_pool_get) + !devlink->ops->sb_port_pool_get) continue; mutex_lock(&devlink->lock); list_for_each_entry(devlink_sb, &devlink->sb_list, list) { @@ -1236,7 +1242,7 @@ static int devlink_sb_port_pool_set(struct devlink_port *devlink_port, { const struct devlink_ops *ops = devlink_port->devlink->ops; - if (ops && ops->sb_port_pool_set) + if (ops->sb_port_pool_set) return ops->sb_port_pool_set(devlink_port, sb_index, pool_index, threshold); return -EOPNOTSUPP; @@ -1349,7 +1355,7 @@ static int devlink_nl_cmd_sb_tc_pool_bind_get_doit(struct sk_buff *skb, if (err) return err; - if (!devlink->ops || !devlink->ops->sb_tc_pool_bind_get) + if (!devlink->ops->sb_tc_pool_bind_get) return -EOPNOTSUPP; msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); @@ -1433,7 +1439,7 @@ devlink_nl_cmd_sb_tc_pool_bind_get_dumpit(struct sk_buff *msg, mutex_lock(&devlink_mutex); list_for_each_entry(devlink, &devlink_list, list) { if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) || - !devlink->ops || !devlink->ops->sb_tc_pool_bind_get) + !devlink->ops->sb_tc_pool_bind_get) continue; mutex_lock(&devlink->lock); @@ -1465,7 +1471,7 @@ static int devlink_sb_tc_pool_bind_set(struct devlink_port *devlink_port, { const struct devlink_ops *ops = devlink_port->devlink->ops; - if (ops && ops->sb_tc_pool_bind_set) + if (ops->sb_tc_pool_bind_set) return ops->sb_tc_pool_bind_set(devlink_port, sb_index, tc_index, pool_type, pool_index, threshold); @@ -1513,7 +1519,7 @@ static int devlink_nl_cmd_sb_occ_snapshot_doit(struct sk_buff *skb, struct devlink_sb *devlink_sb = info->user_ptr[1]; const struct devlink_ops *ops = devlink->ops; - if (ops && ops->sb_occ_snapshot) + if (ops->sb_occ_snapshot) return ops->sb_occ_snapshot(devlink, devlink_sb->index); return -EOPNOTSUPP; } @@ -1525,7 +1531,7 @@ static int devlink_nl_cmd_sb_occ_max_clear_doit(struct sk_buff *skb, struct devlink_sb *devlink_sb = info->user_ptr[1]; const struct devlink_ops *ops = devlink->ops; - if (ops && ops->sb_occ_max_clear) + if (ops->sb_occ_max_clear) return ops->sb_occ_max_clear(devlink, devlink_sb->index); return -EOPNOTSUPP; } @@ -1588,13 +1594,9 @@ static int devlink_nl_cmd_eswitch_get_doit(struct sk_buff *skb, struct genl_info *info) { struct devlink *devlink = info->user_ptr[0]; - const struct devlink_ops *ops = devlink->ops; struct sk_buff *msg; int err; - if (!ops) - return -EOPNOTSUPP; - msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); if (!msg) return -ENOMEM; @@ -1619,9 +1621,6 @@ static int devlink_nl_cmd_eswitch_set_doit(struct sk_buff *skb, int err = 0; u16 mode; - if (!ops) - return -EOPNOTSUPP; - if (info->attrs[DEVLINK_ATTR_ESWITCH_MODE]) { if (!ops->eswitch_mode_set) return -EOPNOTSUPP; @@ -2656,6 +2655,27 @@ static int devlink_nl_cmd_reload(struct sk_buff *skb, struct genl_info *info) return devlink->ops->reload(devlink, info->extack); } +static int devlink_nl_cmd_flash_update(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + const char *file_name, *component; + struct nlattr *nla_component; + + if (!devlink->ops->flash_update) + return -EOPNOTSUPP; + + if (!info->attrs[DEVLINK_ATTR_FLASH_UPDATE_FILE_NAME]) + return -EINVAL; + file_name = nla_data(info->attrs[DEVLINK_ATTR_FLASH_UPDATE_FILE_NAME]); + + nla_component = info->attrs[DEVLINK_ATTR_FLASH_UPDATE_COMPONENT]; + component = nla_component ? nla_data(nla_component) : NULL; + + return devlink->ops->flash_update(devlink, file_name, component, + info->extack); +} + static const struct devlink_param devlink_param_generic[] = { { .id = DEVLINK_PARAM_GENERIC_ID_INT_ERR_RESET, @@ -2843,11 +2863,13 @@ nla_put_failure: } static int devlink_nl_param_fill(struct sk_buff *msg, struct devlink *devlink, + unsigned int port_index, struct devlink_param_item *param_item, enum devlink_command cmd, u32 portid, u32 seq, int flags) { union devlink_param_value param_value[DEVLINK_PARAM_CMODE_MAX + 1]; + bool param_value_set[DEVLINK_PARAM_CMODE_MAX + 1] = {}; const struct devlink_param *param = param_item->param; struct devlink_param_gset_ctx ctx; struct nlattr *param_values_list; @@ -2866,12 +2888,15 @@ static int devlink_nl_param_fill(struct sk_buff *msg, struct devlink *devlink, return -EOPNOTSUPP; param_value[i] = param_item->driverinit_value; } else { + if (!param_item->published) + continue; ctx.cmode = i; err = devlink_param_get(devlink, param, &ctx); if (err) return err; param_value[i] = ctx.val; } + param_value_set[i] = true; } hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); @@ -2880,6 +2905,13 @@ static int devlink_nl_param_fill(struct sk_buff *msg, struct devlink *devlink, if (devlink_nl_put_handle(msg, devlink)) goto genlmsg_cancel; + + if (cmd == DEVLINK_CMD_PORT_PARAM_GET || + cmd == DEVLINK_CMD_PORT_PARAM_NEW || + cmd == DEVLINK_CMD_PORT_PARAM_DEL) + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX, port_index)) + goto genlmsg_cancel; + param_attr = nla_nest_start(msg, DEVLINK_ATTR_PARAM); if (!param_attr) goto genlmsg_cancel; @@ -2899,7 +2931,7 @@ static int devlink_nl_param_fill(struct sk_buff *msg, struct devlink *devlink, goto param_nest_cancel; for (i = 0; i <= DEVLINK_PARAM_CMODE_MAX; i++) { - if (!devlink_param_cmode_is_supported(param, i)) + if (!param_value_set[i]) continue; err = devlink_nl_param_value_fill_one(msg, param->type, i, param_value[i]); @@ -2922,18 +2954,22 @@ genlmsg_cancel: } static void devlink_param_notify(struct devlink *devlink, + unsigned int port_index, struct devlink_param_item *param_item, enum devlink_command cmd) { struct sk_buff *msg; int err; - WARN_ON(cmd != DEVLINK_CMD_PARAM_NEW && cmd != DEVLINK_CMD_PARAM_DEL); + WARN_ON(cmd != DEVLINK_CMD_PARAM_NEW && cmd != DEVLINK_CMD_PARAM_DEL && + cmd != DEVLINK_CMD_PORT_PARAM_NEW && + cmd != DEVLINK_CMD_PORT_PARAM_DEL); msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); if (!msg) return; - err = devlink_nl_param_fill(msg, devlink, param_item, cmd, 0, 0, 0); + err = devlink_nl_param_fill(msg, devlink, port_index, param_item, cmd, + 0, 0, 0); if (err) { nlmsg_free(msg); return; @@ -2962,7 +2998,7 @@ static int devlink_nl_cmd_param_get_dumpit(struct sk_buff *msg, idx++; continue; } - err = devlink_nl_param_fill(msg, devlink, param_item, + err = devlink_nl_param_fill(msg, devlink, 0, param_item, DEVLINK_CMD_PARAM_GET, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, @@ -3051,7 +3087,7 @@ devlink_param_value_get_from_info(const struct devlink_param *param, } static struct devlink_param_item * -devlink_param_get_from_info(struct devlink *devlink, +devlink_param_get_from_info(struct list_head *param_list, struct genl_info *info) { char *param_name; @@ -3060,7 +3096,7 @@ devlink_param_get_from_info(struct devlink *devlink, return NULL; param_name = nla_data(info->attrs[DEVLINK_ATTR_PARAM_NAME]); - return devlink_param_find_by_name(&devlink->param_list, param_name); + return devlink_param_find_by_name(param_list, param_name); } static int devlink_nl_cmd_param_get_doit(struct sk_buff *skb, @@ -3071,7 +3107,7 @@ static int devlink_nl_cmd_param_get_doit(struct sk_buff *skb, struct sk_buff *msg; int err; - param_item = devlink_param_get_from_info(devlink, info); + param_item = devlink_param_get_from_info(&devlink->param_list, info); if (!param_item) return -EINVAL; @@ -3079,7 +3115,7 @@ static int devlink_nl_cmd_param_get_doit(struct sk_buff *skb, if (!msg) return -ENOMEM; - err = devlink_nl_param_fill(msg, devlink, param_item, + err = devlink_nl_param_fill(msg, devlink, 0, param_item, DEVLINK_CMD_PARAM_GET, info->snd_portid, info->snd_seq, 0); if (err) { @@ -3090,10 +3126,12 @@ static int devlink_nl_cmd_param_get_doit(struct sk_buff *skb, return genlmsg_reply(msg, info); } -static int devlink_nl_cmd_param_set_doit(struct sk_buff *skb, - struct genl_info *info) +static int __devlink_nl_cmd_param_set_doit(struct devlink *devlink, + unsigned int port_index, + struct list_head *param_list, + struct genl_info *info, + enum devlink_command cmd) { - struct devlink *devlink = info->user_ptr[0]; enum devlink_param_type param_type; struct devlink_param_gset_ctx ctx; enum devlink_param_cmode cmode; @@ -3102,7 +3140,7 @@ static int devlink_nl_cmd_param_set_doit(struct sk_buff *skb, union devlink_param_value value; int err = 0; - param_item = devlink_param_get_from_info(devlink, info); + param_item = devlink_param_get_from_info(param_list, info); if (!param_item) return -EINVAL; param = param_item->param; @@ -3142,17 +3180,28 @@ static int devlink_nl_cmd_param_set_doit(struct sk_buff *skb, return err; } - devlink_param_notify(devlink, param_item, DEVLINK_CMD_PARAM_NEW); + devlink_param_notify(devlink, port_index, param_item, cmd); return 0; } +static int devlink_nl_cmd_param_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + + return __devlink_nl_cmd_param_set_doit(devlink, 0, &devlink->param_list, + info, DEVLINK_CMD_PARAM_NEW); +} + static int devlink_param_register_one(struct devlink *devlink, - const struct devlink_param *param) + unsigned int port_index, + struct list_head *param_list, + const struct devlink_param *param, + enum devlink_command cmd) { struct devlink_param_item *param_item; - if (devlink_param_find_by_name(&devlink->param_list, - param->name)) + if (devlink_param_find_by_name(param_list, param->name)) return -EEXIST; if (param->supported_cmodes == BIT(DEVLINK_PARAM_CMODE_DRIVERINIT)) @@ -3165,24 +3214,111 @@ static int devlink_param_register_one(struct devlink *devlink, return -ENOMEM; param_item->param = param; - list_add_tail(¶m_item->list, &devlink->param_list); - devlink_param_notify(devlink, param_item, DEVLINK_CMD_PARAM_NEW); + list_add_tail(¶m_item->list, param_list); + devlink_param_notify(devlink, port_index, param_item, cmd); return 0; } static void devlink_param_unregister_one(struct devlink *devlink, - const struct devlink_param *param) + unsigned int port_index, + struct list_head *param_list, + const struct devlink_param *param, + enum devlink_command cmd) { struct devlink_param_item *param_item; - param_item = devlink_param_find_by_name(&devlink->param_list, - param->name); + param_item = devlink_param_find_by_name(param_list, param->name); WARN_ON(!param_item); - devlink_param_notify(devlink, param_item, DEVLINK_CMD_PARAM_DEL); + devlink_param_notify(devlink, port_index, param_item, cmd); list_del(¶m_item->list); kfree(param_item); } +static int devlink_nl_cmd_port_param_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink_param_item *param_item; + struct devlink_port *devlink_port; + struct devlink *devlink; + int start = cb->args[0]; + int idx = 0; + int err; + + mutex_lock(&devlink_mutex); + list_for_each_entry(devlink, &devlink_list, list) { + if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) + continue; + mutex_lock(&devlink->lock); + list_for_each_entry(devlink_port, &devlink->port_list, list) { + list_for_each_entry(param_item, + &devlink_port->param_list, list) { + if (idx < start) { + idx++; + continue; + } + err = devlink_nl_param_fill(msg, + devlink_port->devlink, + devlink_port->index, param_item, + DEVLINK_CMD_PORT_PARAM_GET, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI); + if (err) { + mutex_unlock(&devlink->lock); + goto out; + } + idx++; + } + } + mutex_unlock(&devlink->lock); + } +out: + mutex_unlock(&devlink_mutex); + + cb->args[0] = idx; + return msg->len; +} + +static int devlink_nl_cmd_port_param_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_port *devlink_port = info->user_ptr[0]; + struct devlink_param_item *param_item; + struct sk_buff *msg; + int err; + + param_item = devlink_param_get_from_info(&devlink_port->param_list, + info); + if (!param_item) + return -EINVAL; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_param_fill(msg, devlink_port->devlink, + devlink_port->index, param_item, + DEVLINK_CMD_PORT_PARAM_GET, + info->snd_portid, info->snd_seq, 0); + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static int devlink_nl_cmd_port_param_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_port *devlink_port = info->user_ptr[0]; + + return __devlink_nl_cmd_param_set_doit(devlink_port->devlink, + devlink_port->index, + &devlink_port->param_list, info, + DEVLINK_CMD_PORT_PARAM_NEW); +} + static int devlink_nl_region_snapshot_id_put(struct sk_buff *msg, struct devlink *devlink, struct devlink_snapshot *snapshot) @@ -3504,44 +3640,56 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, struct netlink_callback *cb) { u64 ret_offset, start_offset, end_offset = 0; - struct nlattr *attrs[DEVLINK_ATTR_MAX + 1]; const struct genl_ops *ops = cb->data; struct devlink_region *region; struct nlattr *chunks_attr; const char *region_name; struct devlink *devlink; + struct nlattr **attrs; bool dump = true; void *hdr; int err; start_offset = *((u64 *)&cb->args[0]); + attrs = kmalloc_array(DEVLINK_ATTR_MAX + 1, sizeof(*attrs), GFP_KERNEL); + if (!attrs) + return -ENOMEM; + err = nlmsg_parse(cb->nlh, GENL_HDRLEN + devlink_nl_family.hdrsize, attrs, DEVLINK_ATTR_MAX, ops->policy, cb->extack); if (err) - goto out; + goto out_free; + mutex_lock(&devlink_mutex); devlink = devlink_get_from_attrs(sock_net(cb->skb->sk), attrs); - if (IS_ERR(devlink)) - goto out; + if (IS_ERR(devlink)) { + err = PTR_ERR(devlink); + goto out_dev; + } - mutex_lock(&devlink_mutex); mutex_lock(&devlink->lock); if (!attrs[DEVLINK_ATTR_REGION_NAME] || - !attrs[DEVLINK_ATTR_REGION_SNAPSHOT_ID]) + !attrs[DEVLINK_ATTR_REGION_SNAPSHOT_ID]) { + err = -EINVAL; goto out_unlock; + } region_name = nla_data(attrs[DEVLINK_ATTR_REGION_NAME]); region = devlink_region_get_by_name(devlink, region_name); - if (!region) + if (!region) { + err = -EINVAL; goto out_unlock; + } hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, &devlink_nl_family, NLM_F_ACK | NLM_F_MULTI, DEVLINK_CMD_REGION_READ); - if (!hdr) + if (!hdr) { + err = -EMSGSIZE; goto out_unlock; + } err = devlink_nl_put_handle(skb, devlink); if (err) @@ -3552,8 +3700,10 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, goto nla_put_failure; chunks_attr = nla_nest_start(skb, DEVLINK_ATTR_REGION_CHUNKS); - if (!chunks_attr) + if (!chunks_attr) { + err = -EMSGSIZE; goto nla_put_failure; + } if (attrs[DEVLINK_ATTR_REGION_CHUNK_ADDR] && attrs[DEVLINK_ATTR_REGION_CHUNK_LEN]) { @@ -3576,8 +3726,10 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, goto nla_put_failure; /* Check if there was any progress done to prevent infinite loop */ - if (ret_offset == start_offset) + if (ret_offset == start_offset) { + err = -EINVAL; goto nla_put_failure; + } *((u64 *)&cb->args[0]) = ret_offset; @@ -3585,6 +3737,7 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, genlmsg_end(skb, hdr); mutex_unlock(&devlink->lock); mutex_unlock(&devlink_mutex); + kfree(attrs); return skb->len; @@ -3592,8 +3745,1144 @@ nla_put_failure: genlmsg_cancel(skb, hdr); out_unlock: mutex_unlock(&devlink->lock); +out_dev: + mutex_unlock(&devlink_mutex); +out_free: + kfree(attrs); + return err; +} + +struct devlink_info_req { + struct sk_buff *msg; +}; + +int devlink_info_driver_name_put(struct devlink_info_req *req, const char *name) +{ + return nla_put_string(req->msg, DEVLINK_ATTR_INFO_DRIVER_NAME, name); +} +EXPORT_SYMBOL_GPL(devlink_info_driver_name_put); + +int devlink_info_serial_number_put(struct devlink_info_req *req, const char *sn) +{ + return nla_put_string(req->msg, DEVLINK_ATTR_INFO_SERIAL_NUMBER, sn); +} +EXPORT_SYMBOL_GPL(devlink_info_serial_number_put); + +static int devlink_info_version_put(struct devlink_info_req *req, int attr, + const char *version_name, + const char *version_value) +{ + struct nlattr *nest; + int err; + + nest = nla_nest_start(req->msg, attr); + if (!nest) + return -EMSGSIZE; + + err = nla_put_string(req->msg, DEVLINK_ATTR_INFO_VERSION_NAME, + version_name); + if (err) + goto nla_put_failure; + + err = nla_put_string(req->msg, DEVLINK_ATTR_INFO_VERSION_VALUE, + version_value); + if (err) + goto nla_put_failure; + + nla_nest_end(req->msg, nest); + + return 0; + +nla_put_failure: + nla_nest_cancel(req->msg, nest); + return err; +} + +int devlink_info_version_fixed_put(struct devlink_info_req *req, + const char *version_name, + const char *version_value) +{ + return devlink_info_version_put(req, DEVLINK_ATTR_INFO_VERSION_FIXED, + version_name, version_value); +} +EXPORT_SYMBOL_GPL(devlink_info_version_fixed_put); + +int devlink_info_version_stored_put(struct devlink_info_req *req, + const char *version_name, + const char *version_value) +{ + return devlink_info_version_put(req, DEVLINK_ATTR_INFO_VERSION_STORED, + version_name, version_value); +} +EXPORT_SYMBOL_GPL(devlink_info_version_stored_put); + +int devlink_info_version_running_put(struct devlink_info_req *req, + const char *version_name, + const char *version_value) +{ + return devlink_info_version_put(req, DEVLINK_ATTR_INFO_VERSION_RUNNING, + version_name, version_value); +} +EXPORT_SYMBOL_GPL(devlink_info_version_running_put); + +static int +devlink_nl_info_fill(struct sk_buff *msg, struct devlink *devlink, + enum devlink_command cmd, u32 portid, + u32 seq, int flags, struct netlink_ext_ack *extack) +{ + struct devlink_info_req req; + void *hdr; + int err; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + err = -EMSGSIZE; + if (devlink_nl_put_handle(msg, devlink)) + goto err_cancel_msg; + + req.msg = msg; + err = devlink->ops->info_get(devlink, &req, extack); + if (err) + goto err_cancel_msg; + + genlmsg_end(msg, hdr); + return 0; + +err_cancel_msg: + genlmsg_cancel(msg, hdr); + return err; +} + +static int devlink_nl_cmd_info_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct sk_buff *msg; + int err; + + if (!devlink->ops->info_get) + return -EOPNOTSUPP; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_info_fill(msg, devlink, DEVLINK_CMD_INFO_GET, + info->snd_portid, info->snd_seq, 0, + info->extack); + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static int devlink_nl_cmd_info_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink *devlink; + int start = cb->args[0]; + int idx = 0; + int err; + + mutex_lock(&devlink_mutex); + list_for_each_entry(devlink, &devlink_list, list) { + if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) + continue; + if (idx < start) { + idx++; + continue; + } + + mutex_lock(&devlink->lock); + err = devlink_nl_info_fill(msg, devlink, DEVLINK_CMD_INFO_GET, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + cb->extack); + mutex_unlock(&devlink->lock); + if (err) + break; + idx++; + } + mutex_unlock(&devlink_mutex); + + cb->args[0] = idx; + return msg->len; +} + +struct devlink_fmsg_item { + struct list_head list; + int attrtype; + u8 nla_type; + u16 len; + int value[0]; +}; + +struct devlink_fmsg { + struct list_head item_list; +}; + +static struct devlink_fmsg *devlink_fmsg_alloc(void) +{ + struct devlink_fmsg *fmsg; + + fmsg = kzalloc(sizeof(*fmsg), GFP_KERNEL); + if (!fmsg) + return NULL; + + INIT_LIST_HEAD(&fmsg->item_list); + + return fmsg; +} + +static void devlink_fmsg_free(struct devlink_fmsg *fmsg) +{ + struct devlink_fmsg_item *item, *tmp; + + list_for_each_entry_safe(item, tmp, &fmsg->item_list, list) { + list_del(&item->list); + kfree(item); + } + kfree(fmsg); +} + +static int devlink_fmsg_nest_common(struct devlink_fmsg *fmsg, + int attrtype) +{ + struct devlink_fmsg_item *item; + + item = kzalloc(sizeof(*item), GFP_KERNEL); + if (!item) + return -ENOMEM; + + item->attrtype = attrtype; + list_add_tail(&item->list, &fmsg->item_list); + + return 0; +} + +int devlink_fmsg_obj_nest_start(struct devlink_fmsg *fmsg) +{ + return devlink_fmsg_nest_common(fmsg, DEVLINK_ATTR_FMSG_OBJ_NEST_START); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_obj_nest_start); + +static int devlink_fmsg_nest_end(struct devlink_fmsg *fmsg) +{ + return devlink_fmsg_nest_common(fmsg, DEVLINK_ATTR_FMSG_NEST_END); +} + +int devlink_fmsg_obj_nest_end(struct devlink_fmsg *fmsg) +{ + return devlink_fmsg_nest_end(fmsg); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_obj_nest_end); + +#define DEVLINK_FMSG_MAX_SIZE (GENLMSG_DEFAULT_SIZE - GENL_HDRLEN - NLA_HDRLEN) + +static int devlink_fmsg_put_name(struct devlink_fmsg *fmsg, const char *name) +{ + struct devlink_fmsg_item *item; + + if (strlen(name) + 1 > DEVLINK_FMSG_MAX_SIZE) + return -EMSGSIZE; + + item = kzalloc(sizeof(*item) + strlen(name) + 1, GFP_KERNEL); + if (!item) + return -ENOMEM; + + item->nla_type = NLA_NUL_STRING; + item->len = strlen(name) + 1; + item->attrtype = DEVLINK_ATTR_FMSG_OBJ_NAME; + memcpy(&item->value, name, item->len); + list_add_tail(&item->list, &fmsg->item_list); + + return 0; +} + +int devlink_fmsg_pair_nest_start(struct devlink_fmsg *fmsg, const char *name) +{ + int err; + + err = devlink_fmsg_nest_common(fmsg, DEVLINK_ATTR_FMSG_PAIR_NEST_START); + if (err) + return err; + + err = devlink_fmsg_put_name(fmsg, name); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_pair_nest_start); + +int devlink_fmsg_pair_nest_end(struct devlink_fmsg *fmsg) +{ + return devlink_fmsg_nest_end(fmsg); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_pair_nest_end); + +int devlink_fmsg_arr_pair_nest_start(struct devlink_fmsg *fmsg, + const char *name) +{ + int err; + + err = devlink_fmsg_pair_nest_start(fmsg, name); + if (err) + return err; + + err = devlink_fmsg_nest_common(fmsg, DEVLINK_ATTR_FMSG_ARR_NEST_START); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_arr_pair_nest_start); + +int devlink_fmsg_arr_pair_nest_end(struct devlink_fmsg *fmsg) +{ + int err; + + err = devlink_fmsg_nest_end(fmsg); + if (err) + return err; + + err = devlink_fmsg_nest_end(fmsg); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_arr_pair_nest_end); + +static int devlink_fmsg_put_value(struct devlink_fmsg *fmsg, + const void *value, u16 value_len, + u8 value_nla_type) +{ + struct devlink_fmsg_item *item; + + if (value_len > DEVLINK_FMSG_MAX_SIZE) + return -EMSGSIZE; + + item = kzalloc(sizeof(*item) + value_len, GFP_KERNEL); + if (!item) + return -ENOMEM; + + item->nla_type = value_nla_type; + item->len = value_len; + item->attrtype = DEVLINK_ATTR_FMSG_OBJ_VALUE_DATA; + memcpy(&item->value, value, item->len); + list_add_tail(&item->list, &fmsg->item_list); + + return 0; +} + +int devlink_fmsg_bool_put(struct devlink_fmsg *fmsg, bool value) +{ + return devlink_fmsg_put_value(fmsg, &value, sizeof(value), NLA_FLAG); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_bool_put); + +int devlink_fmsg_u8_put(struct devlink_fmsg *fmsg, u8 value) +{ + return devlink_fmsg_put_value(fmsg, &value, sizeof(value), NLA_U8); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_u8_put); + +int devlink_fmsg_u32_put(struct devlink_fmsg *fmsg, u32 value) +{ + return devlink_fmsg_put_value(fmsg, &value, sizeof(value), NLA_U32); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_u32_put); + +int devlink_fmsg_u64_put(struct devlink_fmsg *fmsg, u64 value) +{ + return devlink_fmsg_put_value(fmsg, &value, sizeof(value), NLA_U64); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_u64_put); + +int devlink_fmsg_string_put(struct devlink_fmsg *fmsg, const char *value) +{ + return devlink_fmsg_put_value(fmsg, value, strlen(value) + 1, + NLA_NUL_STRING); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_string_put); + +int devlink_fmsg_binary_put(struct devlink_fmsg *fmsg, const void *value, + u16 value_len) +{ + return devlink_fmsg_put_value(fmsg, value, value_len, NLA_BINARY); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_binary_put); + +int devlink_fmsg_bool_pair_put(struct devlink_fmsg *fmsg, const char *name, + bool value) +{ + int err; + + err = devlink_fmsg_pair_nest_start(fmsg, name); + if (err) + return err; + + err = devlink_fmsg_bool_put(fmsg, value); + if (err) + return err; + + err = devlink_fmsg_pair_nest_end(fmsg); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_bool_pair_put); + +int devlink_fmsg_u8_pair_put(struct devlink_fmsg *fmsg, const char *name, + u8 value) +{ + int err; + + err = devlink_fmsg_pair_nest_start(fmsg, name); + if (err) + return err; + + err = devlink_fmsg_u8_put(fmsg, value); + if (err) + return err; + + err = devlink_fmsg_pair_nest_end(fmsg); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_u8_pair_put); + +int devlink_fmsg_u32_pair_put(struct devlink_fmsg *fmsg, const char *name, + u32 value) +{ + int err; + + err = devlink_fmsg_pair_nest_start(fmsg, name); + if (err) + return err; + + err = devlink_fmsg_u32_put(fmsg, value); + if (err) + return err; + + err = devlink_fmsg_pair_nest_end(fmsg); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_u32_pair_put); + +int devlink_fmsg_u64_pair_put(struct devlink_fmsg *fmsg, const char *name, + u64 value) +{ + int err; + + err = devlink_fmsg_pair_nest_start(fmsg, name); + if (err) + return err; + + err = devlink_fmsg_u64_put(fmsg, value); + if (err) + return err; + + err = devlink_fmsg_pair_nest_end(fmsg); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_u64_pair_put); + +int devlink_fmsg_string_pair_put(struct devlink_fmsg *fmsg, const char *name, + const char *value) +{ + int err; + + err = devlink_fmsg_pair_nest_start(fmsg, name); + if (err) + return err; + + err = devlink_fmsg_string_put(fmsg, value); + if (err) + return err; + + err = devlink_fmsg_pair_nest_end(fmsg); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_string_pair_put); + +int devlink_fmsg_binary_pair_put(struct devlink_fmsg *fmsg, const char *name, + const void *value, u16 value_len) +{ + int err; + + err = devlink_fmsg_pair_nest_start(fmsg, name); + if (err) + return err; + + err = devlink_fmsg_binary_put(fmsg, value, value_len); + if (err) + return err; + + err = devlink_fmsg_pair_nest_end(fmsg); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_binary_pair_put); + +static int +devlink_fmsg_item_fill_type(struct devlink_fmsg_item *msg, struct sk_buff *skb) +{ + switch (msg->nla_type) { + case NLA_FLAG: + case NLA_U8: + case NLA_U32: + case NLA_U64: + case NLA_NUL_STRING: + case NLA_BINARY: + return nla_put_u8(skb, DEVLINK_ATTR_FMSG_OBJ_VALUE_TYPE, + msg->nla_type); + default: + return -EINVAL; + } +} + +static int +devlink_fmsg_item_fill_data(struct devlink_fmsg_item *msg, struct sk_buff *skb) +{ + int attrtype = DEVLINK_ATTR_FMSG_OBJ_VALUE_DATA; + u8 tmp; + + switch (msg->nla_type) { + case NLA_FLAG: + /* Always provide flag data, regardless of its value */ + tmp = *(bool *) msg->value; + + return nla_put_u8(skb, attrtype, tmp); + case NLA_U8: + return nla_put_u8(skb, attrtype, *(u8 *) msg->value); + case NLA_U32: + return nla_put_u32(skb, attrtype, *(u32 *) msg->value); + case NLA_U64: + return nla_put_u64_64bit(skb, attrtype, *(u64 *) msg->value, + DEVLINK_ATTR_PAD); + case NLA_NUL_STRING: + return nla_put_string(skb, attrtype, (char *) &msg->value); + case NLA_BINARY: + return nla_put(skb, attrtype, msg->len, (void *) &msg->value); + default: + return -EINVAL; + } +} + +static int +devlink_fmsg_prepare_skb(struct devlink_fmsg *fmsg, struct sk_buff *skb, + int *start) +{ + struct devlink_fmsg_item *item; + struct nlattr *fmsg_nlattr; + int i = 0; + int err; + + fmsg_nlattr = nla_nest_start(skb, DEVLINK_ATTR_FMSG); + if (!fmsg_nlattr) + return -EMSGSIZE; + + list_for_each_entry(item, &fmsg->item_list, list) { + if (i < *start) { + i++; + continue; + } + + switch (item->attrtype) { + case DEVLINK_ATTR_FMSG_OBJ_NEST_START: + case DEVLINK_ATTR_FMSG_PAIR_NEST_START: + case DEVLINK_ATTR_FMSG_ARR_NEST_START: + case DEVLINK_ATTR_FMSG_NEST_END: + err = nla_put_flag(skb, item->attrtype); + break; + case DEVLINK_ATTR_FMSG_OBJ_VALUE_DATA: + err = devlink_fmsg_item_fill_type(item, skb); + if (err) + break; + err = devlink_fmsg_item_fill_data(item, skb); + break; + case DEVLINK_ATTR_FMSG_OBJ_NAME: + err = nla_put_string(skb, item->attrtype, + (char *) &item->value); + break; + default: + err = -EINVAL; + break; + } + if (!err) + *start = ++i; + else + break; + } + + nla_nest_end(skb, fmsg_nlattr); + return err; +} + +static int devlink_fmsg_snd(struct devlink_fmsg *fmsg, + struct genl_info *info, + enum devlink_command cmd, int flags) +{ + struct nlmsghdr *nlh; + struct sk_buff *skb; + bool last = false; + int index = 0; + void *hdr; + int err; + + while (!last) { + int tmp_index = index; + + skb = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb) + return -ENOMEM; + + hdr = genlmsg_put(skb, info->snd_portid, info->snd_seq, + &devlink_nl_family, flags | NLM_F_MULTI, cmd); + if (!hdr) { + err = -EMSGSIZE; + goto nla_put_failure; + } + + err = devlink_fmsg_prepare_skb(fmsg, skb, &index); + if (!err) + last = true; + else if (err != -EMSGSIZE || tmp_index == index) + goto nla_put_failure; + + genlmsg_end(skb, hdr); + err = genlmsg_reply(skb, info); + if (err) + return err; + } + + skb = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb) + return -ENOMEM; + nlh = nlmsg_put(skb, info->snd_portid, info->snd_seq, + NLMSG_DONE, 0, flags | NLM_F_MULTI); + if (!nlh) { + err = -EMSGSIZE; + goto nla_put_failure; + } + + return genlmsg_reply(skb, info); + +nla_put_failure: + nlmsg_free(skb); + return err; +} + +struct devlink_health_reporter { + struct list_head list; + void *priv; + const struct devlink_health_reporter_ops *ops; + struct devlink *devlink; + struct devlink_fmsg *dump_fmsg; + struct mutex dump_lock; /* lock parallel read/write from dump buffers */ + u64 graceful_period; + bool auto_recover; + u8 health_state; + u64 dump_ts; + u64 error_count; + u64 recovery_count; + u64 last_recovery_ts; +}; + +void * +devlink_health_reporter_priv(struct devlink_health_reporter *reporter) +{ + return reporter->priv; +} +EXPORT_SYMBOL_GPL(devlink_health_reporter_priv); + +static struct devlink_health_reporter * +devlink_health_reporter_find_by_name(struct devlink *devlink, + const char *reporter_name) +{ + struct devlink_health_reporter *reporter; + + list_for_each_entry(reporter, &devlink->reporter_list, list) + if (!strcmp(reporter->ops->name, reporter_name)) + return reporter; + return NULL; +} + +/** + * devlink_health_reporter_create - create devlink health reporter + * + * @devlink: devlink + * @ops: ops + * @graceful_period: to avoid recovery loops, in msecs + * @auto_recover: auto recover when error occurs + * @priv: priv + */ +struct devlink_health_reporter * +devlink_health_reporter_create(struct devlink *devlink, + const struct devlink_health_reporter_ops *ops, + u64 graceful_period, bool auto_recover, + void *priv) +{ + struct devlink_health_reporter *reporter; + + mutex_lock(&devlink->lock); + if (devlink_health_reporter_find_by_name(devlink, ops->name)) { + reporter = ERR_PTR(-EEXIST); + goto unlock; + } + + if (WARN_ON(auto_recover && !ops->recover) || + WARN_ON(graceful_period && !ops->recover)) { + reporter = ERR_PTR(-EINVAL); + goto unlock; + } + + reporter = kzalloc(sizeof(*reporter), GFP_KERNEL); + if (!reporter) { + reporter = ERR_PTR(-ENOMEM); + goto unlock; + } + + reporter->priv = priv; + reporter->ops = ops; + reporter->devlink = devlink; + reporter->graceful_period = graceful_period; + reporter->auto_recover = auto_recover; + mutex_init(&reporter->dump_lock); + list_add_tail(&reporter->list, &devlink->reporter_list); +unlock: + mutex_unlock(&devlink->lock); + return reporter; +} +EXPORT_SYMBOL_GPL(devlink_health_reporter_create); + +/** + * devlink_health_reporter_destroy - destroy devlink health reporter + * + * @reporter: devlink health reporter to destroy + */ +void +devlink_health_reporter_destroy(struct devlink_health_reporter *reporter) +{ + mutex_lock(&reporter->devlink->lock); + list_del(&reporter->list); + mutex_unlock(&reporter->devlink->lock); + if (reporter->dump_fmsg) + devlink_fmsg_free(reporter->dump_fmsg); + kfree(reporter); +} +EXPORT_SYMBOL_GPL(devlink_health_reporter_destroy); + +void +devlink_health_reporter_state_update(struct devlink_health_reporter *reporter, + enum devlink_health_reporter_state state) +{ + if (WARN_ON(state != DEVLINK_HEALTH_REPORTER_STATE_HEALTHY && + state != DEVLINK_HEALTH_REPORTER_STATE_ERROR)) + return; + + if (reporter->health_state == state) + return; + + reporter->health_state = state; + trace_devlink_health_reporter_state_update(reporter->devlink, + reporter->ops->name, state); +} +EXPORT_SYMBOL_GPL(devlink_health_reporter_state_update); + +static int +devlink_health_reporter_recover(struct devlink_health_reporter *reporter, + void *priv_ctx) +{ + int err; + + if (!reporter->ops->recover) + return -EOPNOTSUPP; + + err = reporter->ops->recover(reporter, priv_ctx); + if (err) + return err; + + reporter->recovery_count++; + reporter->health_state = DEVLINK_HEALTH_REPORTER_STATE_HEALTHY; + reporter->last_recovery_ts = jiffies; + + return 0; +} + +static void +devlink_health_dump_clear(struct devlink_health_reporter *reporter) +{ + if (!reporter->dump_fmsg) + return; + devlink_fmsg_free(reporter->dump_fmsg); + reporter->dump_fmsg = NULL; +} + +static int devlink_health_do_dump(struct devlink_health_reporter *reporter, + void *priv_ctx) +{ + int err; + + if (!reporter->ops->dump) + return 0; + + if (reporter->dump_fmsg) + return 0; + + reporter->dump_fmsg = devlink_fmsg_alloc(); + if (!reporter->dump_fmsg) { + err = -ENOMEM; + return err; + } + + err = devlink_fmsg_obj_nest_start(reporter->dump_fmsg); + if (err) + goto dump_err; + + err = reporter->ops->dump(reporter, reporter->dump_fmsg, + priv_ctx); + if (err) + goto dump_err; + + err = devlink_fmsg_obj_nest_end(reporter->dump_fmsg); + if (err) + goto dump_err; + + reporter->dump_ts = jiffies; + + return 0; + +dump_err: + devlink_health_dump_clear(reporter); + return err; +} + +int devlink_health_report(struct devlink_health_reporter *reporter, + const char *msg, void *priv_ctx) +{ + enum devlink_health_reporter_state prev_health_state; + struct devlink *devlink = reporter->devlink; + + /* write a log message of the current error */ + WARN_ON(!msg); + trace_devlink_health_report(devlink, reporter->ops->name, msg); + reporter->error_count++; + prev_health_state = reporter->health_state; + reporter->health_state = DEVLINK_HEALTH_REPORTER_STATE_ERROR; + + /* abort if the previous error wasn't recovered */ + if (reporter->auto_recover && + (prev_health_state != DEVLINK_HEALTH_REPORTER_STATE_HEALTHY || + jiffies - reporter->last_recovery_ts < + msecs_to_jiffies(reporter->graceful_period))) { + trace_devlink_health_recover_aborted(devlink, + reporter->ops->name, + reporter->health_state, + jiffies - + reporter->last_recovery_ts); + return -ECANCELED; + } + + reporter->health_state = DEVLINK_HEALTH_REPORTER_STATE_ERROR; + + mutex_lock(&reporter->dump_lock); + /* store current dump of current error, for later analysis */ + devlink_health_do_dump(reporter, priv_ctx); + mutex_unlock(&reporter->dump_lock); + + if (reporter->auto_recover) + return devlink_health_reporter_recover(reporter, priv_ctx); + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_health_report); + +static struct devlink_health_reporter * +devlink_health_reporter_get_from_info(struct devlink *devlink, + struct genl_info *info) +{ + char *reporter_name; + + if (!info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_NAME]) + return NULL; + + reporter_name = + nla_data(info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_NAME]); + return devlink_health_reporter_find_by_name(devlink, reporter_name); +} + +static int +devlink_nl_health_reporter_fill(struct sk_buff *msg, + struct devlink *devlink, + struct devlink_health_reporter *reporter, + enum devlink_command cmd, u32 portid, + u32 seq, int flags) +{ + struct nlattr *reporter_attr; + void *hdr; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto genlmsg_cancel; + + reporter_attr = nla_nest_start(msg, DEVLINK_ATTR_HEALTH_REPORTER); + if (!reporter_attr) + goto genlmsg_cancel; + if (nla_put_string(msg, DEVLINK_ATTR_HEALTH_REPORTER_NAME, + reporter->ops->name)) + goto reporter_nest_cancel; + if (nla_put_u8(msg, DEVLINK_ATTR_HEALTH_REPORTER_STATE, + reporter->health_state)) + goto reporter_nest_cancel; + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_HEALTH_REPORTER_ERR_COUNT, + reporter->error_count, DEVLINK_ATTR_PAD)) + goto reporter_nest_cancel; + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_HEALTH_REPORTER_RECOVER_COUNT, + reporter->recovery_count, DEVLINK_ATTR_PAD)) + goto reporter_nest_cancel; + if (reporter->ops->recover && + nla_put_u64_64bit(msg, DEVLINK_ATTR_HEALTH_REPORTER_GRACEFUL_PERIOD, + reporter->graceful_period, + DEVLINK_ATTR_PAD)) + goto reporter_nest_cancel; + if (reporter->ops->recover && + nla_put_u8(msg, DEVLINK_ATTR_HEALTH_REPORTER_AUTO_RECOVER, + reporter->auto_recover)) + goto reporter_nest_cancel; + if (reporter->dump_fmsg && + nla_put_u64_64bit(msg, DEVLINK_ATTR_HEALTH_REPORTER_DUMP_TS, + jiffies_to_msecs(reporter->dump_ts), + DEVLINK_ATTR_PAD)) + goto reporter_nest_cancel; + + nla_nest_end(msg, reporter_attr); + genlmsg_end(msg, hdr); + return 0; + +reporter_nest_cancel: + nla_nest_end(msg, reporter_attr); +genlmsg_cancel: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int devlink_nl_cmd_health_reporter_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_health_reporter *reporter; + struct sk_buff *msg; + int err; + + reporter = devlink_health_reporter_get_from_info(devlink, info); + if (!reporter) + return -EINVAL; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_health_reporter_fill(msg, devlink, reporter, + DEVLINK_CMD_HEALTH_REPORTER_GET, + info->snd_portid, info->snd_seq, + 0); + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static int +devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink_health_reporter *reporter; + struct devlink *devlink; + int start = cb->args[0]; + int idx = 0; + int err; + + mutex_lock(&devlink_mutex); + list_for_each_entry(devlink, &devlink_list, list) { + if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) + continue; + mutex_lock(&devlink->lock); + list_for_each_entry(reporter, &devlink->reporter_list, + list) { + if (idx < start) { + idx++; + continue; + } + err = devlink_nl_health_reporter_fill(msg, devlink, + reporter, + DEVLINK_CMD_HEALTH_REPORTER_GET, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI); + if (err) { + mutex_unlock(&devlink->lock); + goto out; + } + idx++; + } + mutex_unlock(&devlink->lock); + } +out: mutex_unlock(&devlink_mutex); + + cb->args[0] = idx; + return msg->len; +} + +static int +devlink_nl_cmd_health_reporter_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_health_reporter *reporter; + + reporter = devlink_health_reporter_get_from_info(devlink, info); + if (!reporter) + return -EINVAL; + + if (!reporter->ops->recover && + (info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_GRACEFUL_PERIOD] || + info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_AUTO_RECOVER])) + return -EOPNOTSUPP; + + if (info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_GRACEFUL_PERIOD]) + reporter->graceful_period = + nla_get_u64(info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_GRACEFUL_PERIOD]); + + if (info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_AUTO_RECOVER]) + reporter->auto_recover = + nla_get_u8(info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_AUTO_RECOVER]); + + return 0; +} + +static int devlink_nl_cmd_health_reporter_recover_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_health_reporter *reporter; + + reporter = devlink_health_reporter_get_from_info(devlink, info); + if (!reporter) + return -EINVAL; + + return devlink_health_reporter_recover(reporter, NULL); +} + +static int devlink_nl_cmd_health_reporter_diagnose_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_health_reporter *reporter; + struct devlink_fmsg *fmsg; + int err; + + reporter = devlink_health_reporter_get_from_info(devlink, info); + if (!reporter) + return -EINVAL; + + if (!reporter->ops->diagnose) + return -EOPNOTSUPP; + + fmsg = devlink_fmsg_alloc(); + if (!fmsg) + return -ENOMEM; + + err = devlink_fmsg_obj_nest_start(fmsg); + if (err) + goto out; + + err = reporter->ops->diagnose(reporter, fmsg); + if (err) + goto out; + + err = devlink_fmsg_obj_nest_end(fmsg); + if (err) + goto out; + + err = devlink_fmsg_snd(fmsg, info, + DEVLINK_CMD_HEALTH_REPORTER_DIAGNOSE, 0); + +out: + devlink_fmsg_free(fmsg); + return err; +} + +static int devlink_nl_cmd_health_reporter_dump_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_health_reporter *reporter; + int err; + + reporter = devlink_health_reporter_get_from_info(devlink, info); + if (!reporter) + return -EINVAL; + + if (!reporter->ops->dump) + return -EOPNOTSUPP; + + mutex_lock(&reporter->dump_lock); + err = devlink_health_do_dump(reporter, NULL); + if (err) + goto out; + + err = devlink_fmsg_snd(reporter->dump_fmsg, info, + DEVLINK_CMD_HEALTH_REPORTER_DUMP_GET, 0); + out: + mutex_unlock(&reporter->dump_lock); + return err; +} + +static int +devlink_nl_cmd_health_reporter_dump_clear_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_health_reporter *reporter; + + reporter = devlink_health_reporter_get_from_info(devlink, info); + if (!reporter) + return -EINVAL; + + if (!reporter->ops->dump) + return -EOPNOTSUPP; + + mutex_lock(&reporter->dump_lock); + devlink_health_dump_clear(reporter); + mutex_unlock(&reporter->dump_lock); return 0; } @@ -3622,6 +4911,11 @@ static const struct nla_policy devlink_nl_policy[DEVLINK_ATTR_MAX + 1] = { [DEVLINK_ATTR_PARAM_VALUE_CMODE] = { .type = NLA_U8 }, [DEVLINK_ATTR_REGION_NAME] = { .type = NLA_NUL_STRING }, [DEVLINK_ATTR_REGION_SNAPSHOT_ID] = { .type = NLA_U32 }, + [DEVLINK_ATTR_HEALTH_REPORTER_NAME] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_HEALTH_REPORTER_GRACEFUL_PERIOD] = { .type = NLA_U64 }, + [DEVLINK_ATTR_HEALTH_REPORTER_AUTO_RECOVER] = { .type = NLA_U8 }, + [DEVLINK_ATTR_FLASH_UPDATE_FILE_NAME] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_FLASH_UPDATE_COMPONENT] = { .type = NLA_NUL_STRING }, }; static const struct genl_ops devlink_nl_ops[] = { @@ -3821,6 +5115,21 @@ static const struct genl_ops devlink_nl_ops[] = { .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK, }, { + .cmd = DEVLINK_CMD_PORT_PARAM_GET, + .doit = devlink_nl_cmd_port_param_get_doit, + .dumpit = devlink_nl_cmd_port_param_get_dumpit, + .policy = devlink_nl_policy, + .internal_flags = DEVLINK_NL_FLAG_NEED_PORT, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_PORT_PARAM_SET, + .doit = devlink_nl_cmd_port_param_set_doit, + .policy = devlink_nl_policy, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_PORT, + }, + { .cmd = DEVLINK_CMD_REGION_GET, .doit = devlink_nl_cmd_region_get_doit, .dumpit = devlink_nl_cmd_region_get_dumpit, @@ -3842,6 +5151,66 @@ static const struct genl_ops devlink_nl_ops[] = { .flags = GENL_ADMIN_PERM, .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK, }, + { + .cmd = DEVLINK_CMD_INFO_GET, + .doit = devlink_nl_cmd_info_get_doit, + .dumpit = devlink_nl_cmd_info_get_dumpit, + .policy = devlink_nl_policy, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_HEALTH_REPORTER_GET, + .doit = devlink_nl_cmd_health_reporter_get_doit, + .dumpit = devlink_nl_cmd_health_reporter_get_dumpit, + .policy = devlink_nl_policy, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_HEALTH_REPORTER_SET, + .doit = devlink_nl_cmd_health_reporter_set_doit, + .policy = devlink_nl_policy, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK, + }, + { + .cmd = DEVLINK_CMD_HEALTH_REPORTER_RECOVER, + .doit = devlink_nl_cmd_health_reporter_recover_doit, + .policy = devlink_nl_policy, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK, + }, + { + .cmd = DEVLINK_CMD_HEALTH_REPORTER_DIAGNOSE, + .doit = devlink_nl_cmd_health_reporter_diagnose_doit, + .policy = devlink_nl_policy, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK, + }, + { + .cmd = DEVLINK_CMD_HEALTH_REPORTER_DUMP_GET, + .doit = devlink_nl_cmd_health_reporter_dump_get_doit, + .policy = devlink_nl_policy, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK | + DEVLINK_NL_FLAG_NO_LOCK, + }, + { + .cmd = DEVLINK_CMD_HEALTH_REPORTER_DUMP_CLEAR, + .doit = devlink_nl_cmd_health_reporter_dump_clear_doit, + .policy = devlink_nl_policy, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK | + DEVLINK_NL_FLAG_NO_LOCK, + }, + { + .cmd = DEVLINK_CMD_FLASH_UPDATE, + .doit = devlink_nl_cmd_flash_update, + .policy = devlink_nl_policy, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK, + }, }; static struct genl_family devlink_nl_family __ro_after_init = { @@ -3871,6 +5240,9 @@ struct devlink *devlink_alloc(const struct devlink_ops *ops, size_t priv_size) { struct devlink *devlink; + if (WARN_ON(!ops)) + return NULL; + devlink = kzalloc(sizeof(*devlink) + priv_size, GFP_KERNEL); if (!devlink) return NULL; @@ -3882,6 +5254,7 @@ struct devlink *devlink_alloc(const struct devlink_ops *ops, size_t priv_size) INIT_LIST_HEAD(&devlink->resource_list); INIT_LIST_HEAD(&devlink->param_list); INIT_LIST_HEAD(&devlink->region_list); + INIT_LIST_HEAD(&devlink->reporter_list); mutex_init(&devlink->lock); return devlink; } @@ -3891,6 +5264,7 @@ EXPORT_SYMBOL_GPL(devlink_alloc); * devlink_register - Register devlink instance * * @devlink: devlink + * @dev: parent device */ int devlink_register(struct devlink *devlink, struct device *dev) { @@ -3924,6 +5298,14 @@ EXPORT_SYMBOL_GPL(devlink_unregister); */ void devlink_free(struct devlink *devlink) { + WARN_ON(!list_empty(&devlink->reporter_list)); + WARN_ON(!list_empty(&devlink->region_list)); + WARN_ON(!list_empty(&devlink->param_list)); + WARN_ON(!list_empty(&devlink->resource_list)); + WARN_ON(!list_empty(&devlink->dpipe_table_list)); + WARN_ON(!list_empty(&devlink->sb_list)); + WARN_ON(!list_empty(&devlink->port_list)); + kfree(devlink); } EXPORT_SYMBOL_GPL(devlink_free); @@ -3933,7 +5315,7 @@ EXPORT_SYMBOL_GPL(devlink_free); * * @devlink: devlink * @devlink_port: devlink port - * @port_index + * @port_index: driver-specific numerical identifier of the port * * Register devlink port with provided port index. User can use * any indexing, even hw-related one. devlink_port structure @@ -3954,6 +5336,7 @@ int devlink_port_register(struct devlink *devlink, devlink_port->index = port_index; devlink_port->registered = true; list_add_tail(&devlink_port->list, &devlink->port_list); + INIT_LIST_HEAD(&devlink_port->param_list); mutex_unlock(&devlink->lock); devlink_port_notify(devlink_port, DEVLINK_CMD_PORT_NEW); return 0; @@ -4262,13 +5645,10 @@ EXPORT_SYMBOL_GPL(devlink_dpipe_table_unregister); * * @devlink: devlink * @resource_name: resource's name - * @top_hierarchy: top hierarchy - * @reload_required: reload is required for new configuration to - * apply * @resource_size: resource's size * @resource_id: resource's id - * @parent_reosurce_id: resource's parent id - * @size params: size parameters + * @parent_resource_id: resource's parent id + * @size_params: size parameters */ int devlink_resource_register(struct devlink *devlink, const char *resource_name, @@ -4471,18 +5851,23 @@ out: } EXPORT_SYMBOL_GPL(devlink_resource_occ_get_unregister); -/** - * devlink_params_register - register configuration parameters - * - * @devlink: devlink - * @params: configuration parameters array - * @params_count: number of parameters provided - * - * Register the configuration parameters supported by the driver. - */ -int devlink_params_register(struct devlink *devlink, - const struct devlink_param *params, - size_t params_count) +static int devlink_param_verify(const struct devlink_param *param) +{ + if (!param || !param->name || !param->supported_cmodes) + return -EINVAL; + if (param->generic) + return devlink_param_generic_verify(param); + else + return devlink_param_driver_verify(param); +} + +static int __devlink_params_register(struct devlink *devlink, + unsigned int port_index, + struct list_head *param_list, + const struct devlink_param *params, + size_t params_count, + enum devlink_command reg_cmd, + enum devlink_command unreg_cmd) { const struct devlink_param *param = params; int i; @@ -4490,20 +5875,12 @@ int devlink_params_register(struct devlink *devlink, mutex_lock(&devlink->lock); for (i = 0; i < params_count; i++, param++) { - if (!param || !param->name || !param->supported_cmodes) { - err = -EINVAL; + err = devlink_param_verify(param); + if (err) goto rollback; - } - if (param->generic) { - err = devlink_param_generic_verify(param); - if (err) - goto rollback; - } else { - err = devlink_param_driver_verify(param); - if (err) - goto rollback; - } - err = devlink_param_register_one(devlink, param); + + err = devlink_param_register_one(devlink, port_index, + param_list, param, reg_cmd); if (err) goto rollback; } @@ -4515,11 +5892,48 @@ rollback: if (!i) goto unlock; for (param--; i > 0; i--, param--) - devlink_param_unregister_one(devlink, param); + devlink_param_unregister_one(devlink, port_index, param_list, + param, unreg_cmd); unlock: mutex_unlock(&devlink->lock); return err; } + +static void __devlink_params_unregister(struct devlink *devlink, + unsigned int port_index, + struct list_head *param_list, + const struct devlink_param *params, + size_t params_count, + enum devlink_command cmd) +{ + const struct devlink_param *param = params; + int i; + + mutex_lock(&devlink->lock); + for (i = 0; i < params_count; i++, param++) + devlink_param_unregister_one(devlink, 0, param_list, param, + cmd); + mutex_unlock(&devlink->lock); +} + +/** + * devlink_params_register - register configuration parameters + * + * @devlink: devlink + * @params: configuration parameters array + * @params_count: number of parameters provided + * + * Register the configuration parameters supported by the driver. + */ +int devlink_params_register(struct devlink *devlink, + const struct devlink_param *params, + size_t params_count) +{ + return __devlink_params_register(devlink, 0, &devlink->param_list, + params, params_count, + DEVLINK_CMD_PARAM_NEW, + DEVLINK_CMD_PARAM_DEL); +} EXPORT_SYMBOL_GPL(devlink_params_register); /** @@ -4532,36 +5946,103 @@ void devlink_params_unregister(struct devlink *devlink, const struct devlink_param *params, size_t params_count) { - const struct devlink_param *param = params; - int i; - - mutex_lock(&devlink->lock); - for (i = 0; i < params_count; i++, param++) - devlink_param_unregister_one(devlink, param); - mutex_unlock(&devlink->lock); + return __devlink_params_unregister(devlink, 0, &devlink->param_list, + params, params_count, + DEVLINK_CMD_PARAM_DEL); } EXPORT_SYMBOL_GPL(devlink_params_unregister); /** - * devlink_param_driverinit_value_get - get configuration parameter - * value for driver initializing + * devlink_params_publish - publish configuration parameters * * @devlink: devlink - * @param_id: parameter ID - * @init_val: value of parameter in driverinit configuration mode * - * This function should be used by the driver to get driverinit - * configuration for initialization after reload command. + * Publish previously registered configuration parameters. */ -int devlink_param_driverinit_value_get(struct devlink *devlink, u32 param_id, - union devlink_param_value *init_val) +void devlink_params_publish(struct devlink *devlink) { struct devlink_param_item *param_item; - if (!devlink->ops || !devlink->ops->reload) - return -EOPNOTSUPP; + list_for_each_entry(param_item, &devlink->param_list, list) { + if (param_item->published) + continue; + param_item->published = true; + devlink_param_notify(devlink, 0, param_item, + DEVLINK_CMD_PARAM_NEW); + } +} +EXPORT_SYMBOL_GPL(devlink_params_publish); - param_item = devlink_param_find_by_id(&devlink->param_list, param_id); +/** + * devlink_params_unpublish - unpublish configuration parameters + * + * @devlink: devlink + * + * Unpublish previously registered configuration parameters. + */ +void devlink_params_unpublish(struct devlink *devlink) +{ + struct devlink_param_item *param_item; + + list_for_each_entry(param_item, &devlink->param_list, list) { + if (!param_item->published) + continue; + param_item->published = false; + devlink_param_notify(devlink, 0, param_item, + DEVLINK_CMD_PARAM_DEL); + } +} +EXPORT_SYMBOL_GPL(devlink_params_unpublish); + +/** + * devlink_port_params_register - register port configuration parameters + * + * @devlink_port: devlink port + * @params: configuration parameters array + * @params_count: number of parameters provided + * + * Register the configuration parameters supported by the port. + */ +int devlink_port_params_register(struct devlink_port *devlink_port, + const struct devlink_param *params, + size_t params_count) +{ + return __devlink_params_register(devlink_port->devlink, + devlink_port->index, + &devlink_port->param_list, params, + params_count, + DEVLINK_CMD_PORT_PARAM_NEW, + DEVLINK_CMD_PORT_PARAM_DEL); +} +EXPORT_SYMBOL_GPL(devlink_port_params_register); + +/** + * devlink_port_params_unregister - unregister port configuration + * parameters + * + * @devlink_port: devlink port + * @params: configuration parameters array + * @params_count: number of parameters provided + */ +void devlink_port_params_unregister(struct devlink_port *devlink_port, + const struct devlink_param *params, + size_t params_count) +{ + return __devlink_params_unregister(devlink_port->devlink, + devlink_port->index, + &devlink_port->param_list, + params, params_count, + DEVLINK_CMD_PORT_PARAM_DEL); +} +EXPORT_SYMBOL_GPL(devlink_port_params_unregister); + +static int +__devlink_param_driverinit_value_get(struct list_head *param_list, u32 param_id, + union devlink_param_value *init_val) +{ + struct devlink_param_item *param_item; + + param_item = devlink_param_find_by_id(param_list, param_id); if (!param_item) return -EINVAL; @@ -4577,6 +6058,54 @@ int devlink_param_driverinit_value_get(struct devlink *devlink, u32 param_id, return 0; } + +static int +__devlink_param_driverinit_value_set(struct devlink *devlink, + unsigned int port_index, + struct list_head *param_list, u32 param_id, + union devlink_param_value init_val, + enum devlink_command cmd) +{ + struct devlink_param_item *param_item; + + param_item = devlink_param_find_by_id(param_list, param_id); + if (!param_item) + return -EINVAL; + + if (!devlink_param_cmode_is_supported(param_item->param, + DEVLINK_PARAM_CMODE_DRIVERINIT)) + return -EOPNOTSUPP; + + if (param_item->param->type == DEVLINK_PARAM_TYPE_STRING) + strcpy(param_item->driverinit_value.vstr, init_val.vstr); + else + param_item->driverinit_value = init_val; + param_item->driverinit_value_valid = true; + + devlink_param_notify(devlink, port_index, param_item, cmd); + return 0; +} + +/** + * devlink_param_driverinit_value_get - get configuration parameter + * value for driver initializing + * + * @devlink: devlink + * @param_id: parameter ID + * @init_val: value of parameter in driverinit configuration mode + * + * This function should be used by the driver to get driverinit + * configuration for initialization after reload command. + */ +int devlink_param_driverinit_value_get(struct devlink *devlink, u32 param_id, + union devlink_param_value *init_val) +{ + if (!devlink->ops->reload) + return -EOPNOTSUPP; + + return __devlink_param_driverinit_value_get(&devlink->param_list, + param_id, init_val); +} EXPORT_SYMBOL_GPL(devlink_param_driverinit_value_get); /** @@ -4594,26 +6123,61 @@ EXPORT_SYMBOL_GPL(devlink_param_driverinit_value_get); int devlink_param_driverinit_value_set(struct devlink *devlink, u32 param_id, union devlink_param_value init_val) { - struct devlink_param_item *param_item; + return __devlink_param_driverinit_value_set(devlink, 0, + &devlink->param_list, + param_id, init_val, + DEVLINK_CMD_PARAM_NEW); +} +EXPORT_SYMBOL_GPL(devlink_param_driverinit_value_set); - param_item = devlink_param_find_by_id(&devlink->param_list, param_id); - if (!param_item) - return -EINVAL; +/** + * devlink_port_param_driverinit_value_get - get configuration parameter + * value for driver initializing + * + * @devlink_port: devlink_port + * @param_id: parameter ID + * @init_val: value of parameter in driverinit configuration mode + * + * This function should be used by the driver to get driverinit + * configuration for initialization after reload command. + */ +int devlink_port_param_driverinit_value_get(struct devlink_port *devlink_port, + u32 param_id, + union devlink_param_value *init_val) +{ + struct devlink *devlink = devlink_port->devlink; - if (!devlink_param_cmode_is_supported(param_item->param, - DEVLINK_PARAM_CMODE_DRIVERINIT)) + if (!devlink->ops->reload) return -EOPNOTSUPP; - if (param_item->param->type == DEVLINK_PARAM_TYPE_STRING) - strcpy(param_item->driverinit_value.vstr, init_val.vstr); - else - param_item->driverinit_value = init_val; - param_item->driverinit_value_valid = true; + return __devlink_param_driverinit_value_get(&devlink_port->param_list, + param_id, init_val); +} +EXPORT_SYMBOL_GPL(devlink_port_param_driverinit_value_get); - devlink_param_notify(devlink, param_item, DEVLINK_CMD_PARAM_NEW); - return 0; +/** + * devlink_port_param_driverinit_value_set - set value of configuration + * parameter for driverinit + * configuration mode + * + * @devlink_port: devlink_port + * @param_id: parameter ID + * @init_val: value of parameter to set for driverinit configuration mode + * + * This function should be used by the driver to set driverinit + * configuration mode default value. + */ +int devlink_port_param_driverinit_value_set(struct devlink_port *devlink_port, + u32 param_id, + union devlink_param_value init_val) +{ + return __devlink_param_driverinit_value_set(devlink_port->devlink, + devlink_port->index, + &devlink_port->param_list, + param_id, init_val, + DEVLINK_CMD_PORT_PARAM_NEW); } -EXPORT_SYMBOL_GPL(devlink_param_driverinit_value_set); +EXPORT_SYMBOL_GPL(devlink_port_param_driverinit_value_set); /** * devlink_param_value_changed - notify devlink on a parameter's value @@ -4626,7 +6190,6 @@ EXPORT_SYMBOL_GPL(devlink_param_driverinit_value_set); * This function should be used by the driver to notify devlink on value * change, excluding driverinit configuration mode. * For driverinit configuration mode driver should use the function - * devlink_param_driverinit_value_set() instead. */ void devlink_param_value_changed(struct devlink *devlink, u32 param_id) { @@ -4635,11 +6198,38 @@ void devlink_param_value_changed(struct devlink *devlink, u32 param_id) param_item = devlink_param_find_by_id(&devlink->param_list, param_id); WARN_ON(!param_item); - devlink_param_notify(devlink, param_item, DEVLINK_CMD_PARAM_NEW); + devlink_param_notify(devlink, 0, param_item, DEVLINK_CMD_PARAM_NEW); } EXPORT_SYMBOL_GPL(devlink_param_value_changed); /** + * devlink_port_param_value_changed - notify devlink on a parameter's value + * change. Should be called by the driver + * right after the change. + * + * @devlink_port: devlink_port + * @param_id: parameter ID + * + * This function should be used by the driver to notify devlink on value + * change, excluding driverinit configuration mode. + * For driverinit configuration mode driver should use the function + * devlink_port_param_driverinit_value_set() instead. + */ +void devlink_port_param_value_changed(struct devlink_port *devlink_port, + u32 param_id) +{ + struct devlink_param_item *param_item; + + param_item = devlink_param_find_by_id(&devlink_port->param_list, + param_id); + WARN_ON(!param_item); + + devlink_param_notify(devlink_port->devlink, devlink_port->index, + param_item, DEVLINK_CMD_PORT_PARAM_NEW); +} +EXPORT_SYMBOL_GPL(devlink_port_param_value_changed); + +/** * devlink_param_value_str_fill - Safely fill-up the string preventing * from overflow of the preallocated buffer * @@ -4755,7 +6345,7 @@ EXPORT_SYMBOL_GPL(devlink_region_shapshot_id_get); * Multiple snapshots can be created on a region. * The @snapshot_id should be obtained using the getter function. * - * @devlink_region: devlink region of the snapshot + * @region: devlink region of the snapshot * @data_len: size of snapshot data * @data: snapshot data * @snapshot_id: snapshot id to be created @@ -4808,20 +6398,93 @@ unlock: } EXPORT_SYMBOL_GPL(devlink_region_snapshot_create); -static int __init devlink_module_init(void) +static void __devlink_compat_running_version(struct devlink *devlink, + char *buf, size_t len) { - return genl_register_family(&devlink_nl_family); + const struct nlattr *nlattr; + struct devlink_info_req req; + struct sk_buff *msg; + int rem, err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + req.msg = msg; + err = devlink->ops->info_get(devlink, &req, NULL); + if (err) + goto free_msg; + + nla_for_each_attr(nlattr, (void *)msg->data, msg->len, rem) { + const struct nlattr *kv; + int rem_kv; + + if (nla_type(nlattr) != DEVLINK_ATTR_INFO_VERSION_RUNNING) + continue; + + nla_for_each_nested(kv, nlattr, rem_kv) { + if (nla_type(kv) != DEVLINK_ATTR_INFO_VERSION_VALUE) + continue; + + strlcat(buf, nla_data(kv), len); + strlcat(buf, " ", len); + } + } +free_msg: + nlmsg_free(msg); +} + +void devlink_compat_running_version(struct net_device *dev, + char *buf, size_t len) +{ + struct devlink *devlink; + + dev_hold(dev); + rtnl_unlock(); + + mutex_lock(&devlink_mutex); + devlink = netdev_to_devlink(dev); + if (!devlink || !devlink->ops->info_get) + goto unlock_list; + + mutex_lock(&devlink->lock); + __devlink_compat_running_version(devlink, buf, len); + mutex_unlock(&devlink->lock); +unlock_list: + mutex_unlock(&devlink_mutex); + + rtnl_lock(); + dev_put(dev); } -static void __exit devlink_module_exit(void) +int devlink_compat_flash_update(struct net_device *dev, const char *file_name) { - genl_unregister_family(&devlink_nl_family); + struct devlink *devlink; + int ret = -EOPNOTSUPP; + + dev_hold(dev); + rtnl_unlock(); + + mutex_lock(&devlink_mutex); + devlink = netdev_to_devlink(dev); + if (!devlink || !devlink->ops->flash_update) + goto unlock_list; + + mutex_lock(&devlink->lock); + ret = devlink->ops->flash_update(devlink, file_name, NULL, NULL); + mutex_unlock(&devlink->lock); +unlock_list: + mutex_unlock(&devlink_mutex); + + rtnl_lock(); + dev_put(dev); + + return ret; } -module_init(devlink_module_init); -module_exit(devlink_module_exit); +static int __init devlink_init(void) +{ + return genl_register_family(&devlink_nl_family); +} -MODULE_LICENSE("GPL v2"); -MODULE_AUTHOR("Jiri Pirko <jiri@mellanox.com>"); -MODULE_DESCRIPTION("Network physical device Netlink interface"); -MODULE_ALIAS_GENL_FAMILY(DEVLINK_GENL_NAME); +subsys_initcall(devlink_init); diff --git a/net/core/dst.c b/net/core/dst.c index 81ccf20e2826..a263309df115 100644 --- a/net/core/dst.c +++ b/net/core/dst.c @@ -98,8 +98,12 @@ void *dst_alloc(struct dst_ops *ops, struct net_device *dev, struct dst_entry *dst; if (ops->gc && dst_entries_get_fast(ops) > ops->gc_thresh) { - if (ops->gc(ops)) + if (ops->gc(ops)) { + printk_ratelimited(KERN_NOTICE "Route cache is full: " + "consider increasing sysctl " + "net.ipv[4|6].route.max_size.\n"); return NULL; + } } dst = kmem_cache_alloc(ops->kmem_cachep, GFP_ATOMIC); diff --git a/net/core/ethtool.c b/net/core/ethtool.c index 158264f7cfaf..b1eb32419732 100644 --- a/net/core/ethtool.c +++ b/net/core/ethtool.c @@ -27,7 +27,9 @@ #include <linux/rtnetlink.h> #include <linux/sched/signal.h> #include <linux/net.h> +#include <net/devlink.h> #include <net/xdp_sock.h> +#include <net/flow_offload.h> /* * Some useful ethtool_ops methods that're device independent. @@ -803,6 +805,10 @@ static noinline_for_stack int ethtool_get_drvinfo(struct net_device *dev, if (ops->get_eeprom_len) info.eedump_len = ops->get_eeprom_len(dev); + if (!info.fw_version[0]) + devlink_compat_running_version(dev, info.fw_version, + sizeof(info.fw_version)); + if (copy_to_user(useraddr, &info, sizeof(info))) return -EFAULT; return 0; @@ -1348,12 +1354,9 @@ static int ethtool_get_regs(struct net_device *dev, char __user *useraddr) if (regs.len > reglen) regs.len = reglen; - regbuf = NULL; - if (reglen) { - regbuf = vzalloc(reglen); - if (!regbuf) - return -ENOMEM; - } + regbuf = vzalloc(reglen); + if (!regbuf) + return -ENOMEM; ops->get_regs(dev, ®s, regbuf); @@ -1714,7 +1717,7 @@ static noinline_for_stack int ethtool_set_channels(struct net_device *dev, static int ethtool_get_pauseparam(struct net_device *dev, void __user *useraddr) { - struct ethtool_pauseparam pauseparam = { ETHTOOL_GPAUSEPARAM }; + struct ethtool_pauseparam pauseparam = { .cmd = ETHTOOL_GPAUSEPARAM }; if (!dev->ethtool_ops->get_pauseparam) return -EOPNOTSUPP; @@ -2033,11 +2036,10 @@ static noinline_for_stack int ethtool_flash_device(struct net_device *dev, if (copy_from_user(&efl, useraddr, sizeof(efl))) return -EFAULT; + efl.data[ETHTOOL_FLASH_MAX_FILENAME - 1] = 0; if (!dev->ethtool_ops->flash_device) - return -EOPNOTSUPP; - - efl.data[ETHTOOL_FLASH_MAX_FILENAME - 1] = 0; + return devlink_compat_flash_update(dev, efl.data); return dev->ethtool_ops->flash_device(dev, &efl); } @@ -2317,9 +2319,10 @@ static int ethtool_set_tunable(struct net_device *dev, void __user *useraddr) return ret; } -static int ethtool_get_per_queue_coalesce(struct net_device *dev, - void __user *useraddr, - struct ethtool_per_queue_op *per_queue_opt) +static noinline_for_stack int +ethtool_get_per_queue_coalesce(struct net_device *dev, + void __user *useraddr, + struct ethtool_per_queue_op *per_queue_opt) { u32 bit; int ret; @@ -2347,9 +2350,10 @@ static int ethtool_get_per_queue_coalesce(struct net_device *dev, return 0; } -static int ethtool_set_per_queue_coalesce(struct net_device *dev, - void __user *useraddr, - struct ethtool_per_queue_op *per_queue_opt) +static noinline_for_stack int +ethtool_set_per_queue_coalesce(struct net_device *dev, + void __user *useraddr, + struct ethtool_per_queue_op *per_queue_opt) { u32 bit; int i, ret = 0; @@ -2403,7 +2407,7 @@ roll_back: return ret; } -static int ethtool_set_per_queue(struct net_device *dev, +static int noinline_for_stack ethtool_set_per_queue(struct net_device *dev, void __user *useraddr, u32 sub_cmd) { struct ethtool_per_queue_op per_queue_opt; @@ -2501,7 +2505,7 @@ static int set_phy_tunable(struct net_device *dev, void __user *useraddr) static int ethtool_get_fecparam(struct net_device *dev, void __user *useraddr) { - struct ethtool_fecparam fecparam = { ETHTOOL_GFECPARAM }; + struct ethtool_fecparam fecparam = { .cmd = ETHTOOL_GFECPARAM }; int rc; if (!dev->ethtool_ops->get_fecparam) @@ -2816,3 +2820,241 @@ int dev_ethtool(struct net *net, struct ifreq *ifr) return rc; } + +struct ethtool_rx_flow_key { + struct flow_dissector_key_basic basic; + union { + struct flow_dissector_key_ipv4_addrs ipv4; + struct flow_dissector_key_ipv6_addrs ipv6; + }; + struct flow_dissector_key_ports tp; + struct flow_dissector_key_ip ip; + struct flow_dissector_key_vlan vlan; + struct flow_dissector_key_eth_addrs eth_addrs; +} __aligned(BITS_PER_LONG / 8); /* Ensure that we can do comparisons as longs. */ + +struct ethtool_rx_flow_match { + struct flow_dissector dissector; + struct ethtool_rx_flow_key key; + struct ethtool_rx_flow_key mask; +}; + +struct ethtool_rx_flow_rule * +ethtool_rx_flow_rule_create(const struct ethtool_rx_flow_spec_input *input) +{ + const struct ethtool_rx_flow_spec *fs = input->fs; + static struct in6_addr zero_addr = {}; + struct ethtool_rx_flow_match *match; + struct ethtool_rx_flow_rule *flow; + struct flow_action_entry *act; + + flow = kzalloc(sizeof(struct ethtool_rx_flow_rule) + + sizeof(struct ethtool_rx_flow_match), GFP_KERNEL); + if (!flow) + return ERR_PTR(-ENOMEM); + + /* ethtool_rx supports only one single action per rule. */ + flow->rule = flow_rule_alloc(1); + if (!flow->rule) { + kfree(flow); + return ERR_PTR(-ENOMEM); + } + + match = (struct ethtool_rx_flow_match *)flow->priv; + flow->rule->match.dissector = &match->dissector; + flow->rule->match.mask = &match->mask; + flow->rule->match.key = &match->key; + + match->mask.basic.n_proto = htons(0xffff); + + switch (fs->flow_type & ~(FLOW_EXT | FLOW_MAC_EXT | FLOW_RSS)) { + case TCP_V4_FLOW: + case UDP_V4_FLOW: { + const struct ethtool_tcpip4_spec *v4_spec, *v4_m_spec; + + match->key.basic.n_proto = htons(ETH_P_IP); + + v4_spec = &fs->h_u.tcp_ip4_spec; + v4_m_spec = &fs->m_u.tcp_ip4_spec; + + if (v4_m_spec->ip4src) { + match->key.ipv4.src = v4_spec->ip4src; + match->mask.ipv4.src = v4_m_spec->ip4src; + } + if (v4_m_spec->ip4dst) { + match->key.ipv4.dst = v4_spec->ip4dst; + match->mask.ipv4.dst = v4_m_spec->ip4dst; + } + if (v4_m_spec->ip4src || + v4_m_spec->ip4dst) { + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_IPV4_ADDRS); + match->dissector.offset[FLOW_DISSECTOR_KEY_IPV4_ADDRS] = + offsetof(struct ethtool_rx_flow_key, ipv4); + } + if (v4_m_spec->psrc) { + match->key.tp.src = v4_spec->psrc; + match->mask.tp.src = v4_m_spec->psrc; + } + if (v4_m_spec->pdst) { + match->key.tp.dst = v4_spec->pdst; + match->mask.tp.dst = v4_m_spec->pdst; + } + if (v4_m_spec->psrc || + v4_m_spec->pdst) { + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_PORTS); + match->dissector.offset[FLOW_DISSECTOR_KEY_PORTS] = + offsetof(struct ethtool_rx_flow_key, tp); + } + if (v4_m_spec->tos) { + match->key.ip.tos = v4_spec->tos; + match->mask.ip.tos = v4_m_spec->tos; + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_IP); + match->dissector.offset[FLOW_DISSECTOR_KEY_IP] = + offsetof(struct ethtool_rx_flow_key, ip); + } + } + break; + case TCP_V6_FLOW: + case UDP_V6_FLOW: { + const struct ethtool_tcpip6_spec *v6_spec, *v6_m_spec; + + match->key.basic.n_proto = htons(ETH_P_IPV6); + + v6_spec = &fs->h_u.tcp_ip6_spec; + v6_m_spec = &fs->m_u.tcp_ip6_spec; + if (memcmp(v6_m_spec->ip6src, &zero_addr, sizeof(zero_addr))) { + memcpy(&match->key.ipv6.src, v6_spec->ip6src, + sizeof(match->key.ipv6.src)); + memcpy(&match->mask.ipv6.src, v6_m_spec->ip6src, + sizeof(match->mask.ipv6.src)); + } + if (memcmp(v6_m_spec->ip6dst, &zero_addr, sizeof(zero_addr))) { + memcpy(&match->key.ipv6.dst, v6_spec->ip6dst, + sizeof(match->key.ipv6.dst)); + memcpy(&match->mask.ipv6.dst, v6_m_spec->ip6dst, + sizeof(match->mask.ipv6.dst)); + } + if (memcmp(v6_m_spec->ip6src, &zero_addr, sizeof(zero_addr)) || + memcmp(v6_m_spec->ip6src, &zero_addr, sizeof(zero_addr))) { + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_IPV6_ADDRS); + match->dissector.offset[FLOW_DISSECTOR_KEY_IPV6_ADDRS] = + offsetof(struct ethtool_rx_flow_key, ipv6); + } + if (v6_m_spec->psrc) { + match->key.tp.src = v6_spec->psrc; + match->mask.tp.src = v6_m_spec->psrc; + } + if (v6_m_spec->pdst) { + match->key.tp.dst = v6_spec->pdst; + match->mask.tp.dst = v6_m_spec->pdst; + } + if (v6_m_spec->psrc || + v6_m_spec->pdst) { + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_PORTS); + match->dissector.offset[FLOW_DISSECTOR_KEY_PORTS] = + offsetof(struct ethtool_rx_flow_key, tp); + } + if (v6_m_spec->tclass) { + match->key.ip.tos = v6_spec->tclass; + match->mask.ip.tos = v6_m_spec->tclass; + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_IP); + match->dissector.offset[FLOW_DISSECTOR_KEY_IP] = + offsetof(struct ethtool_rx_flow_key, ip); + } + } + break; + default: + ethtool_rx_flow_rule_destroy(flow); + return ERR_PTR(-EINVAL); + } + + switch (fs->flow_type & ~(FLOW_EXT | FLOW_MAC_EXT | FLOW_RSS)) { + case TCP_V4_FLOW: + case TCP_V6_FLOW: + match->key.basic.ip_proto = IPPROTO_TCP; + break; + case UDP_V4_FLOW: + case UDP_V6_FLOW: + match->key.basic.ip_proto = IPPROTO_UDP; + break; + } + match->mask.basic.ip_proto = 0xff; + + match->dissector.used_keys |= BIT(FLOW_DISSECTOR_KEY_BASIC); + match->dissector.offset[FLOW_DISSECTOR_KEY_BASIC] = + offsetof(struct ethtool_rx_flow_key, basic); + + if (fs->flow_type & FLOW_EXT) { + const struct ethtool_flow_ext *ext_h_spec = &fs->h_ext; + const struct ethtool_flow_ext *ext_m_spec = &fs->m_ext; + + if (ext_m_spec->vlan_etype && + ext_m_spec->vlan_tci) { + match->key.vlan.vlan_tpid = ext_h_spec->vlan_etype; + match->mask.vlan.vlan_tpid = ext_m_spec->vlan_etype; + + match->key.vlan.vlan_id = + ntohs(ext_h_spec->vlan_tci) & 0x0fff; + match->mask.vlan.vlan_id = + ntohs(ext_m_spec->vlan_tci) & 0x0fff; + + match->key.vlan.vlan_priority = + (ntohs(ext_h_spec->vlan_tci) & 0xe000) >> 13; + match->mask.vlan.vlan_priority = + (ntohs(ext_m_spec->vlan_tci) & 0xe000) >> 13; + + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_VLAN); + match->dissector.offset[FLOW_DISSECTOR_KEY_VLAN] = + offsetof(struct ethtool_rx_flow_key, vlan); + } + } + if (fs->flow_type & FLOW_MAC_EXT) { + const struct ethtool_flow_ext *ext_h_spec = &fs->h_ext; + const struct ethtool_flow_ext *ext_m_spec = &fs->m_ext; + + memcpy(match->key.eth_addrs.dst, ext_h_spec->h_dest, + ETH_ALEN); + memcpy(match->mask.eth_addrs.dst, ext_m_spec->h_dest, + ETH_ALEN); + + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_ETH_ADDRS); + match->dissector.offset[FLOW_DISSECTOR_KEY_ETH_ADDRS] = + offsetof(struct ethtool_rx_flow_key, eth_addrs); + } + + act = &flow->rule->action.entries[0]; + switch (fs->ring_cookie) { + case RX_CLS_FLOW_DISC: + act->id = FLOW_ACTION_DROP; + break; + case RX_CLS_FLOW_WAKE: + act->id = FLOW_ACTION_WAKE; + break; + default: + act->id = FLOW_ACTION_QUEUE; + if (fs->flow_type & FLOW_RSS) + act->queue.ctx = input->rss_ctx; + + act->queue.vf = ethtool_get_flow_spec_ring_vf(fs->ring_cookie); + act->queue.index = ethtool_get_flow_spec_ring(fs->ring_cookie); + break; + } + + return flow; +} +EXPORT_SYMBOL(ethtool_rx_flow_rule_create); + +void ethtool_rx_flow_rule_destroy(struct ethtool_rx_flow_rule *flow) +{ + kfree(flow->rule); + kfree(flow); +} +EXPORT_SYMBOL(ethtool_rx_flow_rule_destroy); diff --git a/net/core/filter.c b/net/core/filter.c index 7559d6835ecb..f274620945ff 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -73,6 +73,7 @@ #include <linux/seg6_local.h> #include <net/seg6.h> #include <net/seg6_local.h> +#include <net/lwtunnel.h> /** * sk_filter_trim_cap - run a packet through a socket filter @@ -1793,6 +1794,20 @@ static const struct bpf_func_proto bpf_skb_pull_data_proto = { .arg2_type = ARG_ANYTHING, }; +BPF_CALL_1(bpf_sk_fullsock, struct sock *, sk) +{ + sk = sk_to_full_sk(sk); + + return sk_fullsock(sk) ? (unsigned long)sk : (unsigned long)NULL; +} + +static const struct bpf_func_proto bpf_sk_fullsock_proto = { + .func = bpf_sk_fullsock, + .gpl_only = false, + .ret_type = RET_PTR_TO_SOCKET_OR_NULL, + .arg1_type = ARG_PTR_TO_SOCK_COMMON, +}; + static inline int sk_skb_try_make_writable(struct sk_buff *skb, unsigned int write_len) { @@ -2789,8 +2804,7 @@ static int bpf_skb_proto_4_to_6(struct sk_buff *skb) u32 off = skb_mac_header_len(skb); int ret; - /* SCTP uses GSO_BY_FRAGS, thus cannot adjust it. */ - if (skb_is_gso(skb) && unlikely(skb_is_gso_sctp(skb))) + if (skb_is_gso(skb) && !skb_is_gso_tcp(skb)) return -ENOTSUPP; ret = skb_cow(skb, len_diff); @@ -2831,8 +2845,7 @@ static int bpf_skb_proto_6_to_4(struct sk_buff *skb) u32 off = skb_mac_header_len(skb); int ret; - /* SCTP uses GSO_BY_FRAGS, thus cannot adjust it. */ - if (skb_is_gso(skb) && unlikely(skb_is_gso_sctp(skb))) + if (skb_is_gso(skb) && !skb_is_gso_tcp(skb)) return -ENOTSUPP; ret = skb_unclone(skb, GFP_ATOMIC); @@ -2957,8 +2970,7 @@ static int bpf_skb_net_grow(struct sk_buff *skb, u32 len_diff) u32 off = skb_mac_header_len(skb) + bpf_skb_net_base_len(skb); int ret; - /* SCTP uses GSO_BY_FRAGS, thus cannot adjust it. */ - if (skb_is_gso(skb) && unlikely(skb_is_gso_sctp(skb))) + if (skb_is_gso(skb) && !skb_is_gso_tcp(skb)) return -ENOTSUPP; ret = skb_cow(skb, len_diff); @@ -2987,8 +2999,7 @@ static int bpf_skb_net_shrink(struct sk_buff *skb, u32 len_diff) u32 off = skb_mac_header_len(skb) + bpf_skb_net_base_len(skb); int ret; - /* SCTP uses GSO_BY_FRAGS, thus cannot adjust it. */ - if (skb_is_gso(skb) && unlikely(skb_is_gso_sctp(skb))) + if (skb_is_gso(skb) && !skb_is_gso_tcp(skb)) return -ENOTSUPP; ret = skb_unclone(skb, GFP_ATOMIC); @@ -4112,10 +4123,12 @@ BPF_CALL_5(bpf_setsockopt, struct bpf_sock_ops_kern *, bpf_sock, /* Only some socketops are supported */ switch (optname) { case SO_RCVBUF: + val = min_t(u32, val, sysctl_rmem_max); sk->sk_userlocks |= SOCK_RCVBUF_LOCK; sk->sk_rcvbuf = max_t(int, val * 2, SOCK_MIN_RCVBUF); break; case SO_SNDBUF: + val = min_t(u32, val, sysctl_wmem_max); sk->sk_userlocks |= SOCK_SNDBUF_LOCK; sk->sk_sndbuf = max_t(int, val * 2, SOCK_MIN_SNDBUF); break; @@ -4801,7 +4814,15 @@ static int bpf_push_seg6_encap(struct sk_buff *skb, u32 type, void *hdr, u32 len } #endif /* CONFIG_IPV6_SEG6_BPF */ -BPF_CALL_4(bpf_lwt_push_encap, struct sk_buff *, skb, u32, type, void *, hdr, +#if IS_ENABLED(CONFIG_LWTUNNEL_BPF) +static int bpf_push_ip_encap(struct sk_buff *skb, void *hdr, u32 len, + bool ingress) +{ + return bpf_lwt_push_ip_encap(skb, hdr, len, ingress); +} +#endif + +BPF_CALL_4(bpf_lwt_in_push_encap, struct sk_buff *, skb, u32, type, void *, hdr, u32, len) { switch (type) { @@ -4810,13 +4831,40 @@ BPF_CALL_4(bpf_lwt_push_encap, struct sk_buff *, skb, u32, type, void *, hdr, case BPF_LWT_ENCAP_SEG6_INLINE: return bpf_push_seg6_encap(skb, type, hdr, len); #endif +#if IS_ENABLED(CONFIG_LWTUNNEL_BPF) + case BPF_LWT_ENCAP_IP: + return bpf_push_ip_encap(skb, hdr, len, true /* ingress */); +#endif + default: + return -EINVAL; + } +} + +BPF_CALL_4(bpf_lwt_xmit_push_encap, struct sk_buff *, skb, u32, type, + void *, hdr, u32, len) +{ + switch (type) { +#if IS_ENABLED(CONFIG_LWTUNNEL_BPF) + case BPF_LWT_ENCAP_IP: + return bpf_push_ip_encap(skb, hdr, len, false /* egress */); +#endif default: return -EINVAL; } } -static const struct bpf_func_proto bpf_lwt_push_encap_proto = { - .func = bpf_lwt_push_encap, +static const struct bpf_func_proto bpf_lwt_in_push_encap_proto = { + .func = bpf_lwt_in_push_encap, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_MEM, + .arg4_type = ARG_CONST_SIZE +}; + +static const struct bpf_func_proto bpf_lwt_xmit_push_encap_proto = { + .func = bpf_lwt_xmit_push_encap, .gpl_only = false, .ret_type = RET_INTEGER, .arg1_type = ARG_PTR_TO_CTX, @@ -5016,6 +5064,54 @@ static const struct bpf_func_proto bpf_lwt_seg6_adjust_srh_proto = { }; #endif /* CONFIG_IPV6_SEG6_BPF */ +#define CONVERT_COMMON_TCP_SOCK_FIELDS(md_type, CONVERT) \ +do { \ + switch (si->off) { \ + case offsetof(md_type, snd_cwnd): \ + CONVERT(snd_cwnd); break; \ + case offsetof(md_type, srtt_us): \ + CONVERT(srtt_us); break; \ + case offsetof(md_type, snd_ssthresh): \ + CONVERT(snd_ssthresh); break; \ + case offsetof(md_type, rcv_nxt): \ + CONVERT(rcv_nxt); break; \ + case offsetof(md_type, snd_nxt): \ + CONVERT(snd_nxt); break; \ + case offsetof(md_type, snd_una): \ + CONVERT(snd_una); break; \ + case offsetof(md_type, mss_cache): \ + CONVERT(mss_cache); break; \ + case offsetof(md_type, ecn_flags): \ + CONVERT(ecn_flags); break; \ + case offsetof(md_type, rate_delivered): \ + CONVERT(rate_delivered); break; \ + case offsetof(md_type, rate_interval_us): \ + CONVERT(rate_interval_us); break; \ + case offsetof(md_type, packets_out): \ + CONVERT(packets_out); break; \ + case offsetof(md_type, retrans_out): \ + CONVERT(retrans_out); break; \ + case offsetof(md_type, total_retrans): \ + CONVERT(total_retrans); break; \ + case offsetof(md_type, segs_in): \ + CONVERT(segs_in); break; \ + case offsetof(md_type, data_segs_in): \ + CONVERT(data_segs_in); break; \ + case offsetof(md_type, segs_out): \ + CONVERT(segs_out); break; \ + case offsetof(md_type, data_segs_out): \ + CONVERT(data_segs_out); break; \ + case offsetof(md_type, lost_out): \ + CONVERT(lost_out); break; \ + case offsetof(md_type, sacked_out): \ + CONVERT(sacked_out); break; \ + case offsetof(md_type, bytes_received): \ + CONVERT(bytes_received); break; \ + case offsetof(md_type, bytes_acked): \ + CONVERT(bytes_acked); break; \ + } \ +} while (0) + #ifdef CONFIG_INET static struct sock *sk_lookup(struct net *net, struct bpf_sock_tuple *tuple, int dif, int sdif, u8 family, u8 proto) @@ -5253,6 +5349,105 @@ static const struct bpf_func_proto bpf_sock_addr_sk_lookup_udp_proto = { .arg5_type = ARG_ANYTHING, }; +bool bpf_tcp_sock_is_valid_access(int off, int size, enum bpf_access_type type, + struct bpf_insn_access_aux *info) +{ + if (off < 0 || off >= offsetofend(struct bpf_tcp_sock, bytes_acked)) + return false; + + if (off % size != 0) + return false; + + switch (off) { + case offsetof(struct bpf_tcp_sock, bytes_received): + case offsetof(struct bpf_tcp_sock, bytes_acked): + return size == sizeof(__u64); + default: + return size == sizeof(__u32); + } +} + +u32 bpf_tcp_sock_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, u32 *target_size) +{ + struct bpf_insn *insn = insn_buf; + +#define BPF_TCP_SOCK_GET_COMMON(FIELD) \ + do { \ + BUILD_BUG_ON(FIELD_SIZEOF(struct tcp_sock, FIELD) > \ + FIELD_SIZEOF(struct bpf_tcp_sock, FIELD)); \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct tcp_sock, FIELD),\ + si->dst_reg, si->src_reg, \ + offsetof(struct tcp_sock, FIELD)); \ + } while (0) + + CONVERT_COMMON_TCP_SOCK_FIELDS(struct bpf_tcp_sock, + BPF_TCP_SOCK_GET_COMMON); + + if (insn > insn_buf) + return insn - insn_buf; + + switch (si->off) { + case offsetof(struct bpf_tcp_sock, rtt_min): + BUILD_BUG_ON(FIELD_SIZEOF(struct tcp_sock, rtt_min) != + sizeof(struct minmax)); + BUILD_BUG_ON(sizeof(struct minmax) < + sizeof(struct minmax_sample)); + + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + offsetof(struct tcp_sock, rtt_min) + + offsetof(struct minmax_sample, v)); + break; + } + + return insn - insn_buf; +} + +BPF_CALL_1(bpf_tcp_sock, struct sock *, sk) +{ + sk = sk_to_full_sk(sk); + + if (sk_fullsock(sk) && sk->sk_protocol == IPPROTO_TCP) + return (unsigned long)sk; + + return (unsigned long)NULL; +} + +static const struct bpf_func_proto bpf_tcp_sock_proto = { + .func = bpf_tcp_sock, + .gpl_only = false, + .ret_type = RET_PTR_TO_TCP_SOCK_OR_NULL, + .arg1_type = ARG_PTR_TO_SOCK_COMMON, +}; + +BPF_CALL_1(bpf_skb_ecn_set_ce, struct sk_buff *, skb) +{ + unsigned int iphdr_len; + + if (skb->protocol == cpu_to_be16(ETH_P_IP)) + iphdr_len = sizeof(struct iphdr); + else if (skb->protocol == cpu_to_be16(ETH_P_IPV6)) + iphdr_len = sizeof(struct ipv6hdr); + else + return 0; + + if (skb_headlen(skb) < iphdr_len) + return 0; + + if (skb_cloned(skb) && !skb_clone_writable(skb, iphdr_len)) + return 0; + + return INET_ECN_set_ce(skb); +} + +static const struct bpf_func_proto bpf_skb_ecn_set_ce_proto = { + .func = bpf_skb_ecn_set_ce, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; #endif /* CONFIG_INET */ bool bpf_helper_changes_pkt_data(void *func) @@ -5282,7 +5477,8 @@ bool bpf_helper_changes_pkt_data(void *func) func == bpf_lwt_seg6_adjust_srh || func == bpf_lwt_seg6_action || #endif - func == bpf_lwt_push_encap) + func == bpf_lwt_in_push_encap || + func == bpf_lwt_xmit_push_encap) return true; return false; @@ -5314,10 +5510,20 @@ bpf_base_func_proto(enum bpf_func_id func_id) return &bpf_tail_call_proto; case BPF_FUNC_ktime_get_ns: return &bpf_ktime_get_ns_proto; + default: + break; + } + + if (!capable(CAP_SYS_ADMIN)) + return NULL; + + switch (func_id) { + case BPF_FUNC_spin_lock: + return &bpf_spin_lock_proto; + case BPF_FUNC_spin_unlock: + return &bpf_spin_unlock_proto; case BPF_FUNC_trace_printk: - if (capable(CAP_SYS_ADMIN)) - return bpf_get_trace_printk_proto(); - /* else, fall through */ + return bpf_get_trace_printk_proto(); default: return NULL; } @@ -5396,6 +5602,14 @@ cg_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) switch (func_id) { case BPF_FUNC_get_local_storage: return &bpf_get_local_storage_proto; + case BPF_FUNC_sk_fullsock: + return &bpf_sk_fullsock_proto; +#ifdef CONFIG_INET + case BPF_FUNC_tcp_sock: + return &bpf_tcp_sock_proto; + case BPF_FUNC_skb_ecn_set_ce: + return &bpf_skb_ecn_set_ce_proto; +#endif default: return sk_filter_func_proto(func_id, prog); } @@ -5467,6 +5681,8 @@ tc_cls_act_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) return &bpf_get_socket_uid_proto; case BPF_FUNC_fib_lookup: return &bpf_skb_fib_lookup_proto; + case BPF_FUNC_sk_fullsock: + return &bpf_sk_fullsock_proto; #ifdef CONFIG_XFRM case BPF_FUNC_skb_get_xfrm_state: return &bpf_skb_get_xfrm_state_proto; @@ -5484,6 +5700,8 @@ tc_cls_act_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) return &bpf_sk_lookup_udp_proto; case BPF_FUNC_sk_release: return &bpf_sk_release_proto; + case BPF_FUNC_tcp_sock: + return &bpf_tcp_sock_proto; #endif default: return bpf_base_func_proto(func_id); @@ -5660,7 +5878,7 @@ lwt_in_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) { switch (func_id) { case BPF_FUNC_lwt_push_encap: - return &bpf_lwt_push_encap_proto; + return &bpf_lwt_in_push_encap_proto; default: return lwt_out_func_proto(func_id, prog); } @@ -5696,6 +5914,8 @@ lwt_xmit_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) return &bpf_l4_csum_replace_proto; case BPF_FUNC_set_hash_invalid: return &bpf_set_hash_invalid_proto; + case BPF_FUNC_lwt_push_encap: + return &bpf_lwt_xmit_push_encap_proto; default: return lwt_out_func_proto(func_id, prog); } @@ -5754,6 +5974,11 @@ static bool bpf_skb_is_valid_access(int off, int size, enum bpf_access_type type if (size != sizeof(__u64)) return false; break; + case offsetof(struct __sk_buff, sk): + if (type == BPF_WRITE || size != sizeof(__u64)) + return false; + info->reg_type = PTR_TO_SOCK_COMMON_OR_NULL; + break; default: /* Only narrow read access allowed for now. */ if (type == BPF_WRITE) { @@ -5925,31 +6150,44 @@ full_access: return true; } -static bool __sock_filter_check_size(int off, int size, +bool bpf_sock_common_is_valid_access(int off, int size, + enum bpf_access_type type, struct bpf_insn_access_aux *info) { - const int size_default = sizeof(__u32); - switch (off) { - case bpf_ctx_range(struct bpf_sock, src_ip4): - case bpf_ctx_range_till(struct bpf_sock, src_ip6[0], src_ip6[3]): - bpf_ctx_record_field_size(info, size_default); - return bpf_ctx_narrow_access_ok(off, size, size_default); + case bpf_ctx_range_till(struct bpf_sock, type, priority): + return false; + default: + return bpf_sock_is_valid_access(off, size, type, info); } - - return size == size_default; } bool bpf_sock_is_valid_access(int off, int size, enum bpf_access_type type, struct bpf_insn_access_aux *info) { + const int size_default = sizeof(__u32); + if (off < 0 || off >= sizeof(struct bpf_sock)) return false; if (off % size != 0) return false; - if (!__sock_filter_check_size(off, size, info)) - return false; - return true; + + switch (off) { + case offsetof(struct bpf_sock, state): + case offsetof(struct bpf_sock, family): + case offsetof(struct bpf_sock, type): + case offsetof(struct bpf_sock, protocol): + case offsetof(struct bpf_sock, dst_port): + case offsetof(struct bpf_sock, src_port): + case bpf_ctx_range(struct bpf_sock, src_ip4): + case bpf_ctx_range_till(struct bpf_sock, src_ip6[0], src_ip6[3]): + case bpf_ctx_range(struct bpf_sock, dst_ip4): + case bpf_ctx_range_till(struct bpf_sock, dst_ip6[0], dst_ip6[3]): + bpf_ctx_record_field_size(info, size_default); + return bpf_ctx_narrow_access_ok(off, size, size_default); + } + + return size == size_default; } static bool sock_filter_is_valid_access(int off, int size, @@ -6065,6 +6303,7 @@ static bool tc_cls_act_is_valid_access(int off, int size, case bpf_ctx_range(struct __sk_buff, tc_classid): case bpf_ctx_range_till(struct __sk_buff, cb[0], cb[4]): case bpf_ctx_range(struct __sk_buff, tstamp): + case bpf_ctx_range(struct __sk_buff, queue_mapping): break; default: return false; @@ -6469,9 +6708,18 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type, break; case offsetof(struct __sk_buff, queue_mapping): - *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg, - bpf_target_off(struct sk_buff, queue_mapping, 2, - target_size)); + if (type == BPF_WRITE) { + *insn++ = BPF_JMP_IMM(BPF_JGE, si->src_reg, NO_QUEUE_MAPPING, 1); + *insn++ = BPF_STX_MEM(BPF_H, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, + queue_mapping, + 2, target_size)); + } else { + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, + queue_mapping, + 2, target_size)); + } break; case offsetof(struct __sk_buff, vlan_present): @@ -6708,6 +6956,27 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type, target_size)); break; + case offsetof(struct __sk_buff, gso_segs): + /* si->dst_reg = skb_shinfo(SKB); */ +#ifdef NET_SKBUFF_DATA_USES_OFFSET + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, head), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, head)); + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, end), + BPF_REG_AX, si->src_reg, + offsetof(struct sk_buff, end)); + *insn++ = BPF_ALU64_REG(BPF_ADD, si->dst_reg, BPF_REG_AX); +#else + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, end), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, end)); +#endif + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct skb_shared_info, gso_segs), + si->dst_reg, si->dst_reg, + bpf_target_off(struct skb_shared_info, + gso_segs, 2, + target_size)); + break; case offsetof(struct __sk_buff, wire_len): BUILD_BUG_ON(FIELD_SIZEOF(struct qdisc_skb_cb, pkt_len) != 4); @@ -6717,6 +6986,13 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type, off += offsetof(struct qdisc_skb_cb, pkt_len); *target_size = 4; *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, off); + break; + + case offsetof(struct __sk_buff, sk): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, sk)); + break; } return insn - insn_buf; @@ -6765,24 +7041,32 @@ u32 bpf_sock_convert_ctx_access(enum bpf_access_type type, break; case offsetof(struct bpf_sock, family): - BUILD_BUG_ON(FIELD_SIZEOF(struct sock, sk_family) != 2); - - *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg, - offsetof(struct sock, sk_family)); + *insn++ = BPF_LDX_MEM( + BPF_FIELD_SIZEOF(struct sock_common, skc_family), + si->dst_reg, si->src_reg, + bpf_target_off(struct sock_common, + skc_family, + FIELD_SIZEOF(struct sock_common, + skc_family), + target_size)); break; case offsetof(struct bpf_sock, type): + BUILD_BUG_ON(HWEIGHT32(SK_FL_TYPE_MASK) != BITS_PER_BYTE * 2); *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, offsetof(struct sock, __sk_flags_offset)); *insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, SK_FL_TYPE_MASK); *insn++ = BPF_ALU32_IMM(BPF_RSH, si->dst_reg, SK_FL_TYPE_SHIFT); + *target_size = 2; break; case offsetof(struct bpf_sock, protocol): + BUILD_BUG_ON(HWEIGHT32(SK_FL_PROTO_MASK) != BITS_PER_BYTE); *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, offsetof(struct sock, __sk_flags_offset)); *insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, SK_FL_PROTO_MASK); *insn++ = BPF_ALU32_IMM(BPF_RSH, si->dst_reg, SK_FL_PROTO_SHIFT); + *target_size = 1; break; case offsetof(struct bpf_sock, src_ip4): @@ -6794,6 +7078,15 @@ u32 bpf_sock_convert_ctx_access(enum bpf_access_type type, target_size)); break; + case offsetof(struct bpf_sock, dst_ip4): + *insn++ = BPF_LDX_MEM( + BPF_SIZE(si->code), si->dst_reg, si->src_reg, + bpf_target_off(struct sock_common, skc_daddr, + FIELD_SIZEOF(struct sock_common, + skc_daddr), + target_size)); + break; + case bpf_ctx_range_till(struct bpf_sock, src_ip6[0], src_ip6[3]): #if IS_ENABLED(CONFIG_IPV6) off = si->off; @@ -6812,6 +7105,23 @@ u32 bpf_sock_convert_ctx_access(enum bpf_access_type type, #endif break; + case bpf_ctx_range_till(struct bpf_sock, dst_ip6[0], dst_ip6[3]): +#if IS_ENABLED(CONFIG_IPV6) + off = si->off; + off -= offsetof(struct bpf_sock, dst_ip6[0]); + *insn++ = BPF_LDX_MEM( + BPF_SIZE(si->code), si->dst_reg, si->src_reg, + bpf_target_off(struct sock_common, + skc_v6_daddr.s6_addr32[0], + FIELD_SIZEOF(struct sock_common, + skc_v6_daddr.s6_addr32[0]), + target_size) + off); +#else + *insn++ = BPF_MOV32_IMM(si->dst_reg, 0); + *target_size = 4; +#endif + break; + case offsetof(struct bpf_sock, src_port): *insn++ = BPF_LDX_MEM( BPF_FIELD_SIZEOF(struct sock_common, skc_num), @@ -6821,6 +7131,26 @@ u32 bpf_sock_convert_ctx_access(enum bpf_access_type type, skc_num), target_size)); break; + + case offsetof(struct bpf_sock, dst_port): + *insn++ = BPF_LDX_MEM( + BPF_FIELD_SIZEOF(struct sock_common, skc_dport), + si->dst_reg, si->src_reg, + bpf_target_off(struct sock_common, skc_dport, + FIELD_SIZEOF(struct sock_common, + skc_dport), + target_size)); + break; + + case offsetof(struct bpf_sock, state): + *insn++ = BPF_LDX_MEM( + BPF_FIELD_SIZEOF(struct sock_common, skc_state), + si->dst_reg, si->src_reg, + bpf_target_off(struct sock_common, skc_state, + FIELD_SIZEOF(struct sock_common, + skc_state), + target_size)); + break; } return insn - insn_buf; @@ -7068,6 +7398,85 @@ static u32 sock_ops_convert_ctx_access(enum bpf_access_type type, struct bpf_insn *insn = insn_buf; int off; +/* Helper macro for adding read access to tcp_sock or sock fields. */ +#define SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ) \ + do { \ + BUILD_BUG_ON(FIELD_SIZEOF(OBJ, OBJ_FIELD) > \ + FIELD_SIZEOF(struct bpf_sock_ops, BPF_FIELD)); \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ + struct bpf_sock_ops_kern, \ + is_fullsock), \ + si->dst_reg, si->src_reg, \ + offsetof(struct bpf_sock_ops_kern, \ + is_fullsock)); \ + *insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 2); \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ + struct bpf_sock_ops_kern, sk),\ + si->dst_reg, si->src_reg, \ + offsetof(struct bpf_sock_ops_kern, sk));\ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(OBJ, \ + OBJ_FIELD), \ + si->dst_reg, si->dst_reg, \ + offsetof(OBJ, OBJ_FIELD)); \ + } while (0) + +#define SOCK_OPS_GET_TCP_SOCK_FIELD(FIELD) \ + SOCK_OPS_GET_FIELD(FIELD, FIELD, struct tcp_sock) + +/* Helper macro for adding write access to tcp_sock or sock fields. + * The macro is called with two registers, dst_reg which contains a pointer + * to ctx (context) and src_reg which contains the value that should be + * stored. However, we need an additional register since we cannot overwrite + * dst_reg because it may be used later in the program. + * Instead we "borrow" one of the other register. We first save its value + * into a new (temp) field in bpf_sock_ops_kern, use it, and then restore + * it at the end of the macro. + */ +#define SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ) \ + do { \ + int reg = BPF_REG_9; \ + BUILD_BUG_ON(FIELD_SIZEOF(OBJ, OBJ_FIELD) > \ + FIELD_SIZEOF(struct bpf_sock_ops, BPF_FIELD)); \ + if (si->dst_reg == reg || si->src_reg == reg) \ + reg--; \ + if (si->dst_reg == reg || si->src_reg == reg) \ + reg--; \ + *insn++ = BPF_STX_MEM(BPF_DW, si->dst_reg, reg, \ + offsetof(struct bpf_sock_ops_kern, \ + temp)); \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ + struct bpf_sock_ops_kern, \ + is_fullsock), \ + reg, si->dst_reg, \ + offsetof(struct bpf_sock_ops_kern, \ + is_fullsock)); \ + *insn++ = BPF_JMP_IMM(BPF_JEQ, reg, 0, 2); \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ + struct bpf_sock_ops_kern, sk),\ + reg, si->dst_reg, \ + offsetof(struct bpf_sock_ops_kern, sk));\ + *insn++ = BPF_STX_MEM(BPF_FIELD_SIZEOF(OBJ, OBJ_FIELD), \ + reg, si->src_reg, \ + offsetof(OBJ, OBJ_FIELD)); \ + *insn++ = BPF_LDX_MEM(BPF_DW, reg, si->dst_reg, \ + offsetof(struct bpf_sock_ops_kern, \ + temp)); \ + } while (0) + +#define SOCK_OPS_GET_OR_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ, TYPE) \ + do { \ + if (TYPE == BPF_WRITE) \ + SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ); \ + else \ + SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ); \ + } while (0) + + CONVERT_COMMON_TCP_SOCK_FIELDS(struct bpf_sock_ops, + SOCK_OPS_GET_TCP_SOCK_FIELD); + + if (insn > insn_buf) + return insn - insn_buf; + switch (si->off) { case offsetof(struct bpf_sock_ops, op) ... offsetof(struct bpf_sock_ops, replylong[3]): @@ -7225,175 +7634,15 @@ static u32 sock_ops_convert_ctx_access(enum bpf_access_type type, FIELD_SIZEOF(struct minmax_sample, t)); break; -/* Helper macro for adding read access to tcp_sock or sock fields. */ -#define SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ) \ - do { \ - BUILD_BUG_ON(FIELD_SIZEOF(OBJ, OBJ_FIELD) > \ - FIELD_SIZEOF(struct bpf_sock_ops, BPF_FIELD)); \ - *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ - struct bpf_sock_ops_kern, \ - is_fullsock), \ - si->dst_reg, si->src_reg, \ - offsetof(struct bpf_sock_ops_kern, \ - is_fullsock)); \ - *insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 2); \ - *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ - struct bpf_sock_ops_kern, sk),\ - si->dst_reg, si->src_reg, \ - offsetof(struct bpf_sock_ops_kern, sk));\ - *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(OBJ, \ - OBJ_FIELD), \ - si->dst_reg, si->dst_reg, \ - offsetof(OBJ, OBJ_FIELD)); \ - } while (0) - -/* Helper macro for adding write access to tcp_sock or sock fields. - * The macro is called with two registers, dst_reg which contains a pointer - * to ctx (context) and src_reg which contains the value that should be - * stored. However, we need an additional register since we cannot overwrite - * dst_reg because it may be used later in the program. - * Instead we "borrow" one of the other register. We first save its value - * into a new (temp) field in bpf_sock_ops_kern, use it, and then restore - * it at the end of the macro. - */ -#define SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ) \ - do { \ - int reg = BPF_REG_9; \ - BUILD_BUG_ON(FIELD_SIZEOF(OBJ, OBJ_FIELD) > \ - FIELD_SIZEOF(struct bpf_sock_ops, BPF_FIELD)); \ - if (si->dst_reg == reg || si->src_reg == reg) \ - reg--; \ - if (si->dst_reg == reg || si->src_reg == reg) \ - reg--; \ - *insn++ = BPF_STX_MEM(BPF_DW, si->dst_reg, reg, \ - offsetof(struct bpf_sock_ops_kern, \ - temp)); \ - *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ - struct bpf_sock_ops_kern, \ - is_fullsock), \ - reg, si->dst_reg, \ - offsetof(struct bpf_sock_ops_kern, \ - is_fullsock)); \ - *insn++ = BPF_JMP_IMM(BPF_JEQ, reg, 0, 2); \ - *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ - struct bpf_sock_ops_kern, sk),\ - reg, si->dst_reg, \ - offsetof(struct bpf_sock_ops_kern, sk));\ - *insn++ = BPF_STX_MEM(BPF_FIELD_SIZEOF(OBJ, OBJ_FIELD), \ - reg, si->src_reg, \ - offsetof(OBJ, OBJ_FIELD)); \ - *insn++ = BPF_LDX_MEM(BPF_DW, reg, si->dst_reg, \ - offsetof(struct bpf_sock_ops_kern, \ - temp)); \ - } while (0) - -#define SOCK_OPS_GET_OR_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ, TYPE) \ - do { \ - if (TYPE == BPF_WRITE) \ - SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ); \ - else \ - SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ); \ - } while (0) - - case offsetof(struct bpf_sock_ops, snd_cwnd): - SOCK_OPS_GET_FIELD(snd_cwnd, snd_cwnd, struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, srtt_us): - SOCK_OPS_GET_FIELD(srtt_us, srtt_us, struct tcp_sock); - break; - case offsetof(struct bpf_sock_ops, bpf_sock_ops_cb_flags): SOCK_OPS_GET_FIELD(bpf_sock_ops_cb_flags, bpf_sock_ops_cb_flags, struct tcp_sock); break; - case offsetof(struct bpf_sock_ops, snd_ssthresh): - SOCK_OPS_GET_FIELD(snd_ssthresh, snd_ssthresh, struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, rcv_nxt): - SOCK_OPS_GET_FIELD(rcv_nxt, rcv_nxt, struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, snd_nxt): - SOCK_OPS_GET_FIELD(snd_nxt, snd_nxt, struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, snd_una): - SOCK_OPS_GET_FIELD(snd_una, snd_una, struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, mss_cache): - SOCK_OPS_GET_FIELD(mss_cache, mss_cache, struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, ecn_flags): - SOCK_OPS_GET_FIELD(ecn_flags, ecn_flags, struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, rate_delivered): - SOCK_OPS_GET_FIELD(rate_delivered, rate_delivered, - struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, rate_interval_us): - SOCK_OPS_GET_FIELD(rate_interval_us, rate_interval_us, - struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, packets_out): - SOCK_OPS_GET_FIELD(packets_out, packets_out, struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, retrans_out): - SOCK_OPS_GET_FIELD(retrans_out, retrans_out, struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, total_retrans): - SOCK_OPS_GET_FIELD(total_retrans, total_retrans, - struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, segs_in): - SOCK_OPS_GET_FIELD(segs_in, segs_in, struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, data_segs_in): - SOCK_OPS_GET_FIELD(data_segs_in, data_segs_in, struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, segs_out): - SOCK_OPS_GET_FIELD(segs_out, segs_out, struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, data_segs_out): - SOCK_OPS_GET_FIELD(data_segs_out, data_segs_out, - struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, lost_out): - SOCK_OPS_GET_FIELD(lost_out, lost_out, struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, sacked_out): - SOCK_OPS_GET_FIELD(sacked_out, sacked_out, struct tcp_sock); - break; - case offsetof(struct bpf_sock_ops, sk_txhash): SOCK_OPS_GET_OR_SET_FIELD(sk_txhash, sk_txhash, struct sock, type); break; - - case offsetof(struct bpf_sock_ops, bytes_received): - SOCK_OPS_GET_FIELD(bytes_received, bytes_received, - struct tcp_sock); - break; - - case offsetof(struct bpf_sock_ops, bytes_acked): - SOCK_OPS_GET_FIELD(bytes_acked, bytes_acked, struct tcp_sock); - break; - } return insn - insn_buf; } @@ -7698,6 +7947,7 @@ const struct bpf_verifier_ops flow_dissector_verifier_ops = { }; const struct bpf_prog_ops flow_dissector_prog_ops = { + .test_run = bpf_prog_test_run_flow_dissector, }; int sk_detach_filter(struct sock *sk) diff --git a/net/core/flow_dissector.c b/net/core/flow_dissector.c index 9f2840510e63..bb1a54747d64 100644 --- a/net/core/flow_dissector.c +++ b/net/core/flow_dissector.c @@ -683,6 +683,46 @@ static void __skb_flow_bpf_to_target(const struct bpf_flow_keys *flow_keys, } } +bool __skb_flow_bpf_dissect(struct bpf_prog *prog, + const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + struct bpf_flow_keys *flow_keys) +{ + struct bpf_skb_data_end cb_saved; + struct bpf_skb_data_end *cb; + u32 result; + + /* Note that even though the const qualifier is discarded + * throughout the execution of the BPF program, all changes(the + * control block) are reverted after the BPF program returns. + * Therefore, __skb_flow_dissect does not alter the skb. + */ + + cb = (struct bpf_skb_data_end *)skb->cb; + + /* Save Control Block */ + memcpy(&cb_saved, cb, sizeof(cb_saved)); + memset(cb, 0, sizeof(*cb)); + + /* Pass parameters to the BPF program */ + memset(flow_keys, 0, sizeof(*flow_keys)); + cb->qdisc_cb.flow_keys = flow_keys; + flow_keys->nhoff = skb_network_offset(skb); + flow_keys->thoff = flow_keys->nhoff; + + bpf_compute_data_pointers((struct sk_buff *)skb); + result = BPF_PROG_RUN(prog, skb); + + /* Restore state */ + memcpy(cb, &cb_saved, sizeof(cb_saved)); + + flow_keys->nhoff = clamp_t(u16, flow_keys->nhoff, 0, skb->len); + flow_keys->thoff = clamp_t(u16, flow_keys->thoff, + flow_keys->nhoff, skb->len); + + return result == BPF_OK; +} + /** * __skb_flow_dissect - extract the flow_keys struct and return it * @skb: sk_buff to extract the flow from, can be NULL if the rest are specified @@ -714,7 +754,6 @@ bool __skb_flow_dissect(const struct sk_buff *skb, struct flow_dissector_key_vlan *key_vlan; enum flow_dissect_ret fdret; enum flow_dissector_key_id dissector_vlan = FLOW_DISSECTOR_KEY_MAX; - struct bpf_prog *attached = NULL; int num_hdrs = 0; u8 ip_proto = 0; bool ret; @@ -754,53 +793,30 @@ bool __skb_flow_dissect(const struct sk_buff *skb, FLOW_DISSECTOR_KEY_BASIC, target_container); - rcu_read_lock(); if (skb) { + struct bpf_flow_keys flow_keys; + struct bpf_prog *attached = NULL; + + rcu_read_lock(); + if (skb->dev) attached = rcu_dereference(dev_net(skb->dev)->flow_dissector_prog); else if (skb->sk) attached = rcu_dereference(sock_net(skb->sk)->flow_dissector_prog); else WARN_ON_ONCE(1); - } - if (attached) { - /* Note that even though the const qualifier is discarded - * throughout the execution of the BPF program, all changes(the - * control block) are reverted after the BPF program returns. - * Therefore, __skb_flow_dissect does not alter the skb. - */ - struct bpf_flow_keys flow_keys = {}; - struct bpf_skb_data_end cb_saved; - struct bpf_skb_data_end *cb; - u32 result; - - cb = (struct bpf_skb_data_end *)skb->cb; - - /* Save Control Block */ - memcpy(&cb_saved, cb, sizeof(cb_saved)); - memset(cb, 0, sizeof(cb_saved)); - /* Pass parameters to the BPF program */ - cb->qdisc_cb.flow_keys = &flow_keys; - flow_keys.nhoff = nhoff; - flow_keys.thoff = nhoff; - - bpf_compute_data_pointers((struct sk_buff *)skb); - result = BPF_PROG_RUN(attached, skb); - - /* Restore state */ - memcpy(cb, &cb_saved, sizeof(cb_saved)); - - flow_keys.nhoff = clamp_t(u16, flow_keys.nhoff, 0, skb->len); - flow_keys.thoff = clamp_t(u16, flow_keys.thoff, - flow_keys.nhoff, skb->len); - - __skb_flow_bpf_to_target(&flow_keys, flow_dissector, - target_container); + if (attached) { + ret = __skb_flow_bpf_dissect(attached, skb, + flow_dissector, + &flow_keys); + __skb_flow_bpf_to_target(&flow_keys, flow_dissector, + target_container); + rcu_read_unlock(); + return ret; + } rcu_read_unlock(); - return result == BPF_OK; } - rcu_read_unlock(); if (dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_ETH_ADDRS)) { diff --git a/net/core/flow_offload.c b/net/core/flow_offload.c new file mode 100644 index 000000000000..c3a00eac4804 --- /dev/null +++ b/net/core/flow_offload.c @@ -0,0 +1,153 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#include <linux/kernel.h> +#include <linux/slab.h> +#include <net/flow_offload.h> + +struct flow_rule *flow_rule_alloc(unsigned int num_actions) +{ + struct flow_rule *rule; + + rule = kzalloc(sizeof(struct flow_rule) + + sizeof(struct flow_action_entry) * num_actions, + GFP_KERNEL); + if (!rule) + return NULL; + + rule->action.num_entries = num_actions; + + return rule; +} +EXPORT_SYMBOL(flow_rule_alloc); + +#define FLOW_DISSECTOR_MATCH(__rule, __type, __out) \ + const struct flow_match *__m = &(__rule)->match; \ + struct flow_dissector *__d = (__m)->dissector; \ + \ + (__out)->key = skb_flow_dissector_target(__d, __type, (__m)->key); \ + (__out)->mask = skb_flow_dissector_target(__d, __type, (__m)->mask); \ + +void flow_rule_match_basic(const struct flow_rule *rule, + struct flow_match_basic *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_BASIC, out); +} +EXPORT_SYMBOL(flow_rule_match_basic); + +void flow_rule_match_control(const struct flow_rule *rule, + struct flow_match_control *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_CONTROL, out); +} +EXPORT_SYMBOL(flow_rule_match_control); + +void flow_rule_match_eth_addrs(const struct flow_rule *rule, + struct flow_match_eth_addrs *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ETH_ADDRS, out); +} +EXPORT_SYMBOL(flow_rule_match_eth_addrs); + +void flow_rule_match_vlan(const struct flow_rule *rule, + struct flow_match_vlan *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_VLAN, out); +} +EXPORT_SYMBOL(flow_rule_match_vlan); + +void flow_rule_match_ipv4_addrs(const struct flow_rule *rule, + struct flow_match_ipv4_addrs *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_IPV4_ADDRS, out); +} +EXPORT_SYMBOL(flow_rule_match_ipv4_addrs); + +void flow_rule_match_ipv6_addrs(const struct flow_rule *rule, + struct flow_match_ipv6_addrs *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_IPV6_ADDRS, out); +} +EXPORT_SYMBOL(flow_rule_match_ipv6_addrs); + +void flow_rule_match_ip(const struct flow_rule *rule, + struct flow_match_ip *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_IP, out); +} +EXPORT_SYMBOL(flow_rule_match_ip); + +void flow_rule_match_ports(const struct flow_rule *rule, + struct flow_match_ports *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_PORTS, out); +} +EXPORT_SYMBOL(flow_rule_match_ports); + +void flow_rule_match_tcp(const struct flow_rule *rule, + struct flow_match_tcp *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_TCP, out); +} +EXPORT_SYMBOL(flow_rule_match_tcp); + +void flow_rule_match_icmp(const struct flow_rule *rule, + struct flow_match_icmp *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ICMP, out); +} +EXPORT_SYMBOL(flow_rule_match_icmp); + +void flow_rule_match_mpls(const struct flow_rule *rule, + struct flow_match_mpls *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_MPLS, out); +} +EXPORT_SYMBOL(flow_rule_match_mpls); + +void flow_rule_match_enc_control(const struct flow_rule *rule, + struct flow_match_control *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_CONTROL, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_control); + +void flow_rule_match_enc_ipv4_addrs(const struct flow_rule *rule, + struct flow_match_ipv4_addrs *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_IPV4_ADDRS, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_ipv4_addrs); + +void flow_rule_match_enc_ipv6_addrs(const struct flow_rule *rule, + struct flow_match_ipv6_addrs *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_IPV6_ADDRS, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_ipv6_addrs); + +void flow_rule_match_enc_ip(const struct flow_rule *rule, + struct flow_match_ip *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_IP, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_ip); + +void flow_rule_match_enc_ports(const struct flow_rule *rule, + struct flow_match_ports *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_PORTS, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_ports); + +void flow_rule_match_enc_keyid(const struct flow_rule *rule, + struct flow_match_enc_keyid *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_KEYID, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_keyid); + +void flow_rule_match_enc_opts(const struct flow_rule *rule, + struct flow_match_enc_opts *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_OPTS, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_opts); diff --git a/net/core/gen_stats.c b/net/core/gen_stats.c index 9bf1b9ad1780..ac679f74ba47 100644 --- a/net/core/gen_stats.c +++ b/net/core/gen_stats.c @@ -291,7 +291,6 @@ __gnet_stats_copy_queue_cpu(struct gnet_stats_queue *qstats, for_each_possible_cpu(i) { const struct gnet_stats_queue *qcpu = per_cpu_ptr(q, i); - qstats->qlen = 0; qstats->backlog += qcpu->backlog; qstats->drops += qcpu->drops; qstats->requeues += qcpu->requeues; @@ -307,7 +306,6 @@ void __gnet_stats_copy_queue(struct gnet_stats_queue *qstats, if (cpu) { __gnet_stats_copy_queue_cpu(qstats, cpu); } else { - qstats->qlen = q->qlen; qstats->backlog = q->backlog; qstats->drops = q->drops; qstats->requeues = q->requeues; diff --git a/net/core/gro_cells.c b/net/core/gro_cells.c index acf45ddbe924..e095fb871d91 100644 --- a/net/core/gro_cells.c +++ b/net/core/gro_cells.c @@ -13,22 +13,36 @@ int gro_cells_receive(struct gro_cells *gcells, struct sk_buff *skb) { struct net_device *dev = skb->dev; struct gro_cell *cell; + int res; - if (!gcells->cells || skb_cloned(skb) || netif_elide_gro(dev)) - return netif_rx(skb); + rcu_read_lock(); + if (unlikely(!(dev->flags & IFF_UP))) + goto drop; + + if (!gcells->cells || skb_cloned(skb) || netif_elide_gro(dev)) { + res = netif_rx(skb); + goto unlock; + } cell = this_cpu_ptr(gcells->cells); if (skb_queue_len(&cell->napi_skbs) > netdev_max_backlog) { +drop: atomic_long_inc(&dev->rx_dropped); kfree_skb(skb); - return NET_RX_DROP; + res = NET_RX_DROP; + goto unlock; } __skb_queue_tail(&cell->napi_skbs, skb); if (skb_queue_len(&cell->napi_skbs) == 1) napi_schedule(&cell->napi); - return NET_RX_SUCCESS; + + res = NET_RX_SUCCESS; + +unlock: + rcu_read_unlock(); + return res; } EXPORT_SYMBOL(gro_cells_receive); diff --git a/net/core/lwt_bpf.c b/net/core/lwt_bpf.c index a648568c5e8f..126d31ff5ee3 100644 --- a/net/core/lwt_bpf.c +++ b/net/core/lwt_bpf.c @@ -16,6 +16,8 @@ #include <linux/types.h> #include <linux/bpf.h> #include <net/lwtunnel.h> +#include <net/gre.h> +#include <net/ip6_route.h> struct bpf_lwt_prog { struct bpf_prog *prog; @@ -55,6 +57,7 @@ static int run_lwt_bpf(struct sk_buff *skb, struct bpf_lwt_prog *lwt, switch (ret) { case BPF_OK: + case BPF_LWT_REROUTE: break; case BPF_REDIRECT: @@ -87,6 +90,30 @@ static int run_lwt_bpf(struct sk_buff *skb, struct bpf_lwt_prog *lwt, return ret; } +static int bpf_lwt_input_reroute(struct sk_buff *skb) +{ + int err = -EINVAL; + + if (skb->protocol == htons(ETH_P_IP)) { + struct iphdr *iph = ip_hdr(skb); + + err = ip_route_input_noref(skb, iph->daddr, iph->saddr, + iph->tos, skb_dst(skb)->dev); + } else if (skb->protocol == htons(ETH_P_IPV6)) { + err = ipv6_stub->ipv6_route_input(skb); + } else { + err = -EAFNOSUPPORT; + } + + if (err) + goto err; + return dst_input(skb); + +err: + kfree_skb(skb); + return err; +} + static int bpf_input(struct sk_buff *skb) { struct dst_entry *dst = skb_dst(skb); @@ -98,11 +125,11 @@ static int bpf_input(struct sk_buff *skb) ret = run_lwt_bpf(skb, &bpf->in, dst, NO_REDIRECT); if (ret < 0) return ret; + if (ret == BPF_LWT_REROUTE) + return bpf_lwt_input_reroute(skb); } if (unlikely(!dst->lwtstate->orig_input)) { - pr_warn_once("orig_input not set on dst for prog %s\n", - bpf->out.name); kfree_skb(skb); return -EINVAL; } @@ -147,6 +174,102 @@ static int xmit_check_hhlen(struct sk_buff *skb) return 0; } +static int bpf_lwt_xmit_reroute(struct sk_buff *skb) +{ + struct net_device *l3mdev = l3mdev_master_dev_rcu(skb_dst(skb)->dev); + int oif = l3mdev ? l3mdev->ifindex : 0; + struct dst_entry *dst = NULL; + int err = -EAFNOSUPPORT; + struct sock *sk; + struct net *net; + bool ipv4; + + if (skb->protocol == htons(ETH_P_IP)) + ipv4 = true; + else if (skb->protocol == htons(ETH_P_IPV6)) + ipv4 = false; + else + goto err; + + sk = sk_to_full_sk(skb->sk); + if (sk) { + if (sk->sk_bound_dev_if) + oif = sk->sk_bound_dev_if; + net = sock_net(sk); + } else { + net = dev_net(skb_dst(skb)->dev); + } + + if (ipv4) { + struct iphdr *iph = ip_hdr(skb); + struct flowi4 fl4 = {}; + struct rtable *rt; + + fl4.flowi4_oif = oif; + fl4.flowi4_mark = skb->mark; + fl4.flowi4_uid = sock_net_uid(net, sk); + fl4.flowi4_tos = RT_TOS(iph->tos); + fl4.flowi4_flags = FLOWI_FLAG_ANYSRC; + fl4.flowi4_proto = iph->protocol; + fl4.daddr = iph->daddr; + fl4.saddr = iph->saddr; + + rt = ip_route_output_key(net, &fl4); + if (IS_ERR(rt)) { + err = PTR_ERR(rt); + goto err; + } + dst = &rt->dst; + } else { + struct ipv6hdr *iph6 = ipv6_hdr(skb); + struct flowi6 fl6 = {}; + + fl6.flowi6_oif = oif; + fl6.flowi6_mark = skb->mark; + fl6.flowi6_uid = sock_net_uid(net, sk); + fl6.flowlabel = ip6_flowinfo(iph6); + fl6.flowi6_proto = iph6->nexthdr; + fl6.daddr = iph6->daddr; + fl6.saddr = iph6->saddr; + + err = ipv6_stub->ipv6_dst_lookup(net, skb->sk, &dst, &fl6); + if (unlikely(err)) + goto err; + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + goto err; + } + } + if (unlikely(dst->error)) { + err = dst->error; + dst_release(dst); + goto err; + } + + /* Although skb header was reserved in bpf_lwt_push_ip_encap(), it + * was done for the previous dst, so we are doing it here again, in + * case the new dst needs much more space. The call below is a noop + * if there is enough header space in skb. + */ + err = skb_cow_head(skb, LL_RESERVED_SPACE(dst->dev)); + if (unlikely(err)) + goto err; + + skb_dst_drop(skb); + skb_dst_set(skb, dst); + + err = dst_output(dev_net(skb_dst(skb)->dev), skb->sk, skb); + if (unlikely(err)) + return err; + + /* ip[6]_finish_output2 understand LWTUNNEL_XMIT_DONE */ + return LWTUNNEL_XMIT_DONE; + +err: + kfree_skb(skb); + return err; +} + static int bpf_xmit(struct sk_buff *skb) { struct dst_entry *dst = skb_dst(skb); @@ -154,11 +277,20 @@ static int bpf_xmit(struct sk_buff *skb) bpf = bpf_lwt_lwtunnel(dst->lwtstate); if (bpf->xmit.prog) { + __be16 proto = skb->protocol; int ret; ret = run_lwt_bpf(skb, &bpf->xmit, dst, CAN_REDIRECT); switch (ret) { case BPF_OK: + /* If the header changed, e.g. via bpf_lwt_push_encap, + * BPF_LWT_REROUTE below should have been used if the + * protocol was also changed. + */ + if (skb->protocol != proto) { + kfree_skb(skb); + return -EINVAL; + } /* If the header was expanded, headroom might be too * small for L2 header to come, expand as needed. */ @@ -169,6 +301,8 @@ static int bpf_xmit(struct sk_buff *skb) return LWTUNNEL_XMIT_CONTINUE; case BPF_REDIRECT: return LWTUNNEL_XMIT_DONE; + case BPF_LWT_REROUTE: + return bpf_lwt_xmit_reroute(skb); default: return ret; } @@ -390,6 +524,135 @@ static const struct lwtunnel_encap_ops bpf_encap_ops = { .owner = THIS_MODULE, }; +static int handle_gso_type(struct sk_buff *skb, unsigned int gso_type, + int encap_len) +{ + struct skb_shared_info *shinfo = skb_shinfo(skb); + + gso_type |= SKB_GSO_DODGY; + shinfo->gso_type |= gso_type; + skb_decrease_gso_size(shinfo, encap_len); + shinfo->gso_segs = 0; + return 0; +} + +static int handle_gso_encap(struct sk_buff *skb, bool ipv4, int encap_len) +{ + int next_hdr_offset; + void *next_hdr; + __u8 protocol; + + /* SCTP and UDP_L4 gso need more nuanced handling than what + * handle_gso_type() does above: skb_decrease_gso_size() is not enough. + * So at the moment only TCP GSO packets are let through. + */ + if (!(skb_shinfo(skb)->gso_type & (SKB_GSO_TCPV4 | SKB_GSO_TCPV6))) + return -ENOTSUPP; + + if (ipv4) { + protocol = ip_hdr(skb)->protocol; + next_hdr_offset = sizeof(struct iphdr); + next_hdr = skb_network_header(skb) + next_hdr_offset; + } else { + protocol = ipv6_hdr(skb)->nexthdr; + next_hdr_offset = sizeof(struct ipv6hdr); + next_hdr = skb_network_header(skb) + next_hdr_offset; + } + + switch (protocol) { + case IPPROTO_GRE: + next_hdr_offset += sizeof(struct gre_base_hdr); + if (next_hdr_offset > encap_len) + return -EINVAL; + + if (((struct gre_base_hdr *)next_hdr)->flags & GRE_CSUM) + return handle_gso_type(skb, SKB_GSO_GRE_CSUM, + encap_len); + return handle_gso_type(skb, SKB_GSO_GRE, encap_len); + + case IPPROTO_UDP: + next_hdr_offset += sizeof(struct udphdr); + if (next_hdr_offset > encap_len) + return -EINVAL; + + if (((struct udphdr *)next_hdr)->check) + return handle_gso_type(skb, SKB_GSO_UDP_TUNNEL_CSUM, + encap_len); + return handle_gso_type(skb, SKB_GSO_UDP_TUNNEL, encap_len); + + case IPPROTO_IP: + case IPPROTO_IPV6: + if (ipv4) + return handle_gso_type(skb, SKB_GSO_IPXIP4, encap_len); + else + return handle_gso_type(skb, SKB_GSO_IPXIP6, encap_len); + + default: + return -EPROTONOSUPPORT; + } +} + +int bpf_lwt_push_ip_encap(struct sk_buff *skb, void *hdr, u32 len, bool ingress) +{ + struct iphdr *iph; + bool ipv4; + int err; + + if (unlikely(len < sizeof(struct iphdr) || len > LWT_BPF_MAX_HEADROOM)) + return -EINVAL; + + /* validate protocol and length */ + iph = (struct iphdr *)hdr; + if (iph->version == 4) { + ipv4 = true; + if (unlikely(len < iph->ihl * 4)) + return -EINVAL; + } else if (iph->version == 6) { + ipv4 = false; + if (unlikely(len < sizeof(struct ipv6hdr))) + return -EINVAL; + } else { + return -EINVAL; + } + + if (ingress) + err = skb_cow_head(skb, len + skb->mac_len); + else + err = skb_cow_head(skb, + len + LL_RESERVED_SPACE(skb_dst(skb)->dev)); + if (unlikely(err)) + return err; + + /* push the encap headers and fix pointers */ + skb_reset_inner_headers(skb); + skb_reset_inner_mac_header(skb); /* mac header is not yet set */ + skb_set_inner_protocol(skb, skb->protocol); + skb->encapsulation = 1; + skb_push(skb, len); + if (ingress) + skb_postpush_rcsum(skb, iph, len); + skb_reset_network_header(skb); + memcpy(skb_network_header(skb), hdr, len); + bpf_compute_data_pointers(skb); + skb_clear_hash(skb); + + if (ipv4) { + skb->protocol = htons(ETH_P_IP); + iph = ip_hdr(skb); + + if (!iph->check) + iph->check = ip_fast_csum((unsigned char *)iph, + iph->ihl); + } else { + skb->protocol = htons(ETH_P_IPV6); + } + + if (skb_is_gso(skb)) + return handle_gso_encap(skb, ipv4, len); + + return 0; +} + static int __init bpf_lwt_init(void) { return lwtunnel_encap_add_ops(&bpf_encap_ops, LWTUNNEL_ENCAP_BPF); diff --git a/net/core/lwtunnel.c b/net/core/lwtunnel.c index 0b171756453c..19b557bd294b 100644 --- a/net/core/lwtunnel.c +++ b/net/core/lwtunnel.c @@ -122,18 +122,18 @@ int lwtunnel_build_state(u16 encap_type, ret = -EOPNOTSUPP; rcu_read_lock(); ops = rcu_dereference(lwtun_encaps[encap_type]); - if (likely(ops && ops->build_state && try_module_get(ops->owner))) { + if (likely(ops && ops->build_state && try_module_get(ops->owner))) found = true; + rcu_read_unlock(); + + if (found) { ret = ops->build_state(encap, family, cfg, lws, extack); if (ret) module_put(ops->owner); - } - rcu_read_unlock(); - - /* don't rely on -EOPNOTSUPP to detect match as build_state - * handlers could return it - */ - if (!found) { + } else { + /* don't rely on -EOPNOTSUPP to detect match as build_state + * handlers could return it + */ NL_SET_ERR_MSG_ATTR(extack, encap, "LWT encapsulation type not supported"); } diff --git a/net/core/neighbour.c b/net/core/neighbour.c index 4230400b9a30..30f6fd8f68e0 100644 --- a/net/core/neighbour.c +++ b/net/core/neighbour.c @@ -42,6 +42,8 @@ #include <linux/inetdevice.h> #include <net/addrconf.h> +#include <trace/events/neigh.h> + #define DEBUG #define NEIGH_DEBUG 1 #define neigh_dbg(level, fmt, ...) \ @@ -102,6 +104,7 @@ static void neigh_cleanup_and_release(struct neighbour *neigh) if (neigh->parms->neigh_cleanup) neigh->parms->neigh_cleanup(neigh); + trace_neigh_cleanup_and_release(neigh, 0); __neigh_notify(neigh, RTM_DELNEIGH, 0, 0); call_netevent_notifiers(NETEVENT_NEIGH_UPDATE, neigh); neigh_release(neigh); @@ -1095,6 +1098,8 @@ out: if (notify) neigh_update_notify(neigh, 0); + trace_neigh_timer_handler(neigh, 0); + neigh_release(neigh); } @@ -1165,6 +1170,7 @@ out_unlock_bh: else write_unlock(&neigh->lock); local_bh_enable(); + trace_neigh_event_send_done(neigh, rc); return rc; out_dead: @@ -1172,6 +1178,7 @@ out_dead: goto out_unlock_bh; write_unlock_bh(&neigh->lock); kfree_skb(skb); + trace_neigh_event_send_dead(neigh, 1); return 1; } EXPORT_SYMBOL(__neigh_event_send); @@ -1227,6 +1234,8 @@ static int __neigh_update(struct neighbour *neigh, const u8 *lladdr, struct net_device *dev; int update_isrouter = 0; + trace_neigh_update(neigh, lladdr, new, flags, nlmsg_pid); + write_lock_bh(&neigh->lock); dev = neigh->dev; @@ -1393,6 +1402,8 @@ out: if (notify) neigh_update_notify(neigh, nlmsg_pid); + trace_neigh_update_done(neigh, err); + return err; } diff --git a/net/core/net-sysfs.c b/net/core/net-sysfs.c index ff9fd2bb4ce4..4ff661f6f989 100644 --- a/net/core/net-sysfs.c +++ b/net/core/net-sysfs.c @@ -12,7 +12,6 @@ #include <linux/capability.h> #include <linux/kernel.h> #include <linux/netdevice.h> -#include <net/switchdev.h> #include <linux/if_arp.h> #include <linux/slab.h> #include <linux/sched/signal.h> @@ -501,16 +500,11 @@ static ssize_t phys_switch_id_show(struct device *dev, return restart_syscall(); if (dev_isalive(netdev)) { - struct switchdev_attr attr = { - .orig_dev = netdev, - .id = SWITCHDEV_ATTR_ID_PORT_PARENT_ID, - .flags = SWITCHDEV_F_NO_RECURSE, - }; + struct netdev_phys_item_id ppid = { }; - ret = switchdev_port_attr_get(netdev, &attr); + ret = dev_get_port_parent_id(netdev, &ppid, false); if (!ret) - ret = sprintf(buf, "%*phN\n", attr.u.ppid.id_len, - attr.u.ppid.id); + ret = sprintf(buf, "%*phN\n", ppid.id_len, ppid.id); } rtnl_unlock(); @@ -1342,8 +1336,7 @@ static ssize_t xps_rxqs_show(struct netdev_queue *queue, char *buf) if (tc < 0) return -EINVAL; } - mask = kcalloc(BITS_TO_LONGS(dev->num_rx_queues), sizeof(long), - GFP_KERNEL); + mask = bitmap_zalloc(dev->num_rx_queues, GFP_KERNEL); if (!mask) return -ENOMEM; @@ -1372,7 +1365,7 @@ out_no_maps: rcu_read_unlock(); len = bitmap_print_to_pagebuf(false, buf, mask, dev->num_rx_queues); - kfree(mask); + bitmap_free(mask); return len < PAGE_SIZE ? len : -EINVAL; } @@ -1388,8 +1381,7 @@ static ssize_t xps_rxqs_store(struct netdev_queue *queue, const char *buf, if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) return -EPERM; - mask = kcalloc(BITS_TO_LONGS(dev->num_rx_queues), sizeof(long), - GFP_KERNEL); + mask = bitmap_zalloc(dev->num_rx_queues, GFP_KERNEL); if (!mask) return -ENOMEM; @@ -1397,7 +1389,7 @@ static ssize_t xps_rxqs_store(struct netdev_queue *queue, const char *buf, err = bitmap_parse(buf, len, mask, dev->num_rx_queues); if (err) { - kfree(mask); + bitmap_free(mask); return err; } @@ -1405,7 +1397,7 @@ static ssize_t xps_rxqs_store(struct netdev_queue *queue, const char *buf, err = __netif_set_xps_queue(dev, mask, index, true); cpus_read_unlock(); - kfree(mask); + bitmap_free(mask); return err ? : len; } @@ -1547,6 +1539,9 @@ static int register_queue_kobjects(struct net_device *dev) error: netdev_queue_update_kobjects(dev, txq, 0); net_rx_queue_update_kobjects(dev, rxq, 0); +#ifdef CONFIG_SYSFS + kset_unregister(dev->queues_kset); +#endif return error; } diff --git a/net/core/net-traces.c b/net/core/net-traces.c index 419af6dfe29f..470b179d599e 100644 --- a/net/core/net-traces.c +++ b/net/core/net-traces.c @@ -43,6 +43,14 @@ EXPORT_TRACEPOINT_SYMBOL_GPL(fdb_delete); EXPORT_TRACEPOINT_SYMBOL_GPL(br_fdb_update); #endif +#include <trace/events/neigh.h> +EXPORT_TRACEPOINT_SYMBOL_GPL(neigh_update); +EXPORT_TRACEPOINT_SYMBOL_GPL(neigh_update_done); +EXPORT_TRACEPOINT_SYMBOL_GPL(neigh_timer_handler); +EXPORT_TRACEPOINT_SYMBOL_GPL(neigh_event_send_done); +EXPORT_TRACEPOINT_SYMBOL_GPL(neigh_event_send_dead); +EXPORT_TRACEPOINT_SYMBOL_GPL(neigh_cleanup_and_release); + EXPORT_TRACEPOINT_SYMBOL_GPL(kfree_skb); EXPORT_TRACEPOINT_SYMBOL_GPL(napi_poll); diff --git a/net/core/net_namespace.c b/net/core/net_namespace.c index b02fb19df2cc..17f36317363d 100644 --- a/net/core/net_namespace.c +++ b/net/core/net_namespace.c @@ -778,6 +778,41 @@ nla_put_failure: return -EMSGSIZE; } +static int rtnl_net_valid_getid_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + int i, err; + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse(nlh, sizeof(struct rtgenmsg), tb, NETNSA_MAX, + rtnl_net_policy, extack); + + err = nlmsg_parse_strict(nlh, sizeof(struct rtgenmsg), tb, NETNSA_MAX, + rtnl_net_policy, extack); + if (err) + return err; + + for (i = 0; i <= NETNSA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case NETNSA_PID: + case NETNSA_FD: + case NETNSA_NSID: + case NETNSA_TARGET_NSID: + break; + default: + NL_SET_ERR_MSG(extack, "Unsupported attribute in peer netns getid request"); + return -EINVAL; + } + } + + return 0; +} + static int rtnl_net_getid(struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { @@ -793,8 +828,7 @@ static int rtnl_net_getid(struct sk_buff *skb, struct nlmsghdr *nlh, struct sk_buff *msg; int err; - err = nlmsg_parse(nlh, sizeof(struct rtgenmsg), tb, NETNSA_MAX, - rtnl_net_policy, extack); + err = rtnl_net_valid_getid_req(skb, nlh, tb, extack); if (err < 0) return err; if (tb[NETNSA_PID]) { diff --git a/net/core/page_pool.c b/net/core/page_pool.c index 43a932cb609b..5b2252c6d49b 100644 --- a/net/core/page_pool.c +++ b/net/core/page_pool.c @@ -136,17 +136,19 @@ static struct page *__page_pool_alloc_pages_slow(struct page_pool *pool, if (!(pool->p.flags & PP_FLAG_DMA_MAP)) goto skip_dma_map; - /* Setup DMA mapping: use page->private for DMA-addr + /* Setup DMA mapping: use 'struct page' area for storing DMA-addr + * since dma_addr_t can be either 32 or 64 bits and does not always fit + * into page private data (i.e 32bit cpu with 64bit DMA caps) * This mapping is kept for lifetime of page, until leaving pool. */ - dma = dma_map_page(pool->p.dev, page, 0, - (PAGE_SIZE << pool->p.order), - pool->p.dma_dir); + dma = dma_map_page_attrs(pool->p.dev, page, 0, + (PAGE_SIZE << pool->p.order), + pool->p.dma_dir, DMA_ATTR_SKIP_CPU_SYNC); if (dma_mapping_error(pool->p.dev, dma)) { put_page(page); return NULL; } - set_page_private(page, dma); /* page->private = dma; */ + page->dma_addr = dma; skip_dma_map: /* When page just alloc'ed is should/must have refcnt 1. */ @@ -175,13 +177,17 @@ EXPORT_SYMBOL(page_pool_alloc_pages); static void __page_pool_clean_page(struct page_pool *pool, struct page *page) { + dma_addr_t dma; + if (!(pool->p.flags & PP_FLAG_DMA_MAP)) return; + dma = page->dma_addr; /* DMA unmap */ - dma_unmap_page(pool->p.dev, page_private(page), - PAGE_SIZE << pool->p.order, pool->p.dma_dir); - set_page_private(page, 0); + dma_unmap_page_attrs(pool->p.dev, dma, + PAGE_SIZE << pool->p.order, pool->p.dma_dir, + DMA_ATTR_SKIP_CPU_SYNC); + page->dma_addr = 0; } /* Return a page to the page allocator, cleaning up our state */ diff --git a/net/core/pktgen.c b/net/core/pktgen.c index 6ac919847ce6..f3f5a78cd062 100644 --- a/net/core/pktgen.c +++ b/net/core/pktgen.c @@ -158,6 +158,7 @@ #include <linux/etherdevice.h> #include <linux/kthread.h> #include <linux/prefetch.h> +#include <linux/mmzone.h> #include <net/net_namespace.h> #include <net/checksum.h> #include <net/ipv6.h> @@ -3625,7 +3626,7 @@ static int pktgen_add_device(struct pktgen_thread *t, const char *ifname) pkt_dev->svlan_cfi = 0; pkt_dev->svlan_id = 0xffff; pkt_dev->burst = 1; - pkt_dev->node = -1; + pkt_dev->node = NUMA_NO_NODE; err = pktgen_setup_dev(t->net, pkt_dev, ifname); if (err) diff --git a/net/core/rtnetlink.c b/net/core/rtnetlink.c index 5ea1bed08ede..a51cab95ba64 100644 --- a/net/core/rtnetlink.c +++ b/net/core/rtnetlink.c @@ -46,7 +46,6 @@ #include <linux/inet.h> #include <linux/netdevice.h> -#include <net/switchdev.h> #include <net/ip.h> #include <net/protocol.h> #include <net/arp.h> @@ -1146,22 +1145,17 @@ static int rtnl_phys_port_name_fill(struct sk_buff *skb, struct net_device *dev) static int rtnl_phys_switch_id_fill(struct sk_buff *skb, struct net_device *dev) { + struct netdev_phys_item_id ppid = { }; int err; - struct switchdev_attr attr = { - .orig_dev = dev, - .id = SWITCHDEV_ATTR_ID_PORT_PARENT_ID, - .flags = SWITCHDEV_F_NO_RECURSE, - }; - err = switchdev_port_attr_get(dev, &attr); + err = dev_get_port_parent_id(dev, &ppid, false); if (err) { if (err == -EOPNOTSUPP) return 0; return err; } - if (nla_put(skb, IFLA_PHYS_SWITCH_ID, attr.u.ppid.id_len, - attr.u.ppid.id)) + if (nla_put(skb, IFLA_PHYS_SWITCH_ID, ppid.id_len, ppid.id)) return -EMSGSIZE; return 0; @@ -3242,6 +3236,53 @@ static int rtnl_newlink(struct sk_buff *skb, struct nlmsghdr *nlh, return ret; } +static int rtnl_valid_getlink_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct ifinfomsg *ifm; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG(extack, "Invalid header for get link"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse(nlh, sizeof(*ifm), tb, IFLA_MAX, ifla_policy, + extack); + + ifm = nlmsg_data(nlh); + if (ifm->__ifi_pad || ifm->ifi_type || ifm->ifi_flags || + ifm->ifi_change) { + NL_SET_ERR_MSG(extack, "Invalid values in header for get link request"); + return -EINVAL; + } + + err = nlmsg_parse_strict(nlh, sizeof(*ifm), tb, IFLA_MAX, ifla_policy, + extack); + if (err) + return err; + + for (i = 0; i <= IFLA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case IFLA_IFNAME: + case IFLA_EXT_MASK: + case IFLA_TARGET_NETNSID: + break; + default: + NL_SET_ERR_MSG(extack, "Unsupported attribute in get link request"); + return -EINVAL; + } + } + + return 0; +} + static int rtnl_getlink(struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { @@ -3256,7 +3297,7 @@ static int rtnl_getlink(struct sk_buff *skb, struct nlmsghdr *nlh, int err; u32 ext_filter_mask = 0; - err = nlmsg_parse(nlh, sizeof(*ifm), tb, IFLA_MAX, ifla_policy, extack); + err = rtnl_valid_getlink_req(skb, nlh, tb, extack); if (err < 0) return err; @@ -3639,7 +3680,7 @@ static int rtnl_fdb_add(struct sk_buff *skb, struct nlmsghdr *nlh, const struct net_device_ops *ops = br_dev->netdev_ops; err = ops->ndo_fdb_add(ndm, tb, dev, addr, vid, - nlh->nlmsg_flags); + nlh->nlmsg_flags, extack); if (err) goto out; else @@ -3651,7 +3692,8 @@ static int rtnl_fdb_add(struct sk_buff *skb, struct nlmsghdr *nlh, if (dev->netdev_ops->ndo_fdb_add) err = dev->netdev_ops->ndo_fdb_add(ndm, tb, dev, addr, vid, - nlh->nlmsg_flags); + nlh->nlmsg_flags, + extack); else err = ndo_dflt_fdb_add(ndm, tb, dev, addr, vid, nlh->nlmsg_flags); @@ -4901,6 +4943,40 @@ static size_t if_nlmsg_stats_size(const struct net_device *dev, return size; } +static int rtnl_valid_stats_req(const struct nlmsghdr *nlh, bool strict_check, + bool is_dump, struct netlink_ext_ack *extack) +{ + struct if_stats_msg *ifsm; + + if (nlh->nlmsg_len < sizeof(*ifsm)) { + NL_SET_ERR_MSG(extack, "Invalid header for stats dump"); + return -EINVAL; + } + + if (!strict_check) + return 0; + + ifsm = nlmsg_data(nlh); + + /* only requests using strict checks can pass data to influence + * the dump. The legacy exception is filter_mask. + */ + if (ifsm->pad1 || ifsm->pad2 || (is_dump && ifsm->ifindex)) { + NL_SET_ERR_MSG(extack, "Invalid values in header for stats dump request"); + return -EINVAL; + } + if (nlmsg_attrlen(nlh, sizeof(*ifsm))) { + NL_SET_ERR_MSG(extack, "Invalid attributes after stats header"); + return -EINVAL; + } + if (ifsm->filter_mask >= IFLA_STATS_FILTER_BIT(IFLA_STATS_MAX + 1)) { + NL_SET_ERR_MSG(extack, "Invalid stats requested through filter mask"); + return -EINVAL; + } + + return 0; +} + static int rtnl_stats_get(struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { @@ -4912,8 +4988,10 @@ static int rtnl_stats_get(struct sk_buff *skb, struct nlmsghdr *nlh, u32 filter_mask; int err; - if (nlmsg_len(nlh) < sizeof(*ifsm)) - return -EINVAL; + err = rtnl_valid_stats_req(nlh, netlink_strict_get_check(skb), + false, extack); + if (err) + return err; ifsm = nlmsg_data(nlh); if (ifsm->ifindex > 0) @@ -4965,27 +5043,11 @@ static int rtnl_stats_dump(struct sk_buff *skb, struct netlink_callback *cb) cb->seq = net->dev_base_seq; - if (nlmsg_len(cb->nlh) < sizeof(*ifsm)) { - NL_SET_ERR_MSG(extack, "Invalid header for stats dump"); - return -EINVAL; - } + err = rtnl_valid_stats_req(cb->nlh, cb->strict_check, true, extack); + if (err) + return err; ifsm = nlmsg_data(cb->nlh); - - /* only requests using strict checks can pass data to influence - * the dump. The legacy exception is filter_mask. - */ - if (cb->strict_check) { - if (ifsm->pad1 || ifsm->pad2 || ifsm->ifindex) { - NL_SET_ERR_MSG(extack, "Invalid values in header for stats dump request"); - return -EINVAL; - } - if (nlmsg_attrlen(cb->nlh, sizeof(*ifsm))) { - NL_SET_ERR_MSG(extack, "Invalid attributes after stats header"); - return -EINVAL; - } - } - filter_mask = ifsm->filter_mask; if (!filter_mask) { NL_SET_ERR_MSG(extack, "Filter mask must be set for stats dump"); diff --git a/net/core/scm.c b/net/core/scm.c index b1ff8a441748..52ef219cf6df 100644 --- a/net/core/scm.c +++ b/net/core/scm.c @@ -29,6 +29,7 @@ #include <linux/pid.h> #include <linux/nsproxy.h> #include <linux/slab.h> +#include <linux/errqueue.h> #include <linux/uaccess.h> @@ -252,6 +253,32 @@ out: } EXPORT_SYMBOL(put_cmsg); +void put_cmsg_scm_timestamping64(struct msghdr *msg, struct scm_timestamping_internal *tss_internal) +{ + struct scm_timestamping64 tss; + int i; + + for (i = 0; i < ARRAY_SIZE(tss.ts); i++) { + tss.ts[i].tv_sec = tss_internal->ts[i].tv_sec; + tss.ts[i].tv_nsec = tss_internal->ts[i].tv_nsec; + } + + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMPING_NEW, sizeof(tss), &tss); +} +EXPORT_SYMBOL(put_cmsg_scm_timestamping64); + +void put_cmsg_scm_timestamping(struct msghdr *msg, struct scm_timestamping_internal *tss_internal) +{ + struct scm_timestamping tss; + int i; + + for (i = 0; i < ARRAY_SIZE(tss.ts); i++) + tss.ts[i] = timespec64_to_timespec(tss_internal->ts[i]); + + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMPING_OLD, sizeof(tss), &tss); +} +EXPORT_SYMBOL(put_cmsg_scm_timestamping); + void scm_detach_fds(struct msghdr *msg, struct scm_cookie *scm) { struct cmsghdr __user *cm diff --git a/net/core/skbuff.c b/net/core/skbuff.c index 26d848484912..2415d9cb9b89 100644 --- a/net/core/skbuff.c +++ b/net/core/skbuff.c @@ -356,6 +356,8 @@ static void *__netdev_alloc_frag(unsigned int fragsz, gfp_t gfp_mask) */ void *netdev_alloc_frag(unsigned int fragsz) { + fragsz = SKB_DATA_ALIGN(fragsz); + return __netdev_alloc_frag(fragsz, GFP_ATOMIC); } EXPORT_SYMBOL(netdev_alloc_frag); @@ -369,6 +371,8 @@ static void *__napi_alloc_frag(unsigned int fragsz, gfp_t gfp_mask) void *napi_alloc_frag(unsigned int fragsz) { + fragsz = SKB_DATA_ALIGN(fragsz); + return __napi_alloc_frag(fragsz, GFP_ATOMIC); } EXPORT_SYMBOL(napi_alloc_frag); diff --git a/net/core/skmsg.c b/net/core/skmsg.c index d6d5c20d7044..cc94d921476c 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -78,11 +78,9 @@ int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src, { int i = src->sg.start; struct scatterlist *sge = sk_msg_elem(src, i); + struct scatterlist *sgd = NULL; u32 sge_len, sge_off; - if (sk_msg_full(dst)) - return -ENOSPC; - while (off) { if (sge->length > off) break; @@ -94,16 +92,27 @@ int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src, } while (len) { - if (sk_msg_full(dst)) - return -ENOSPC; - sge_len = sge->length - off; - sge_off = sge->offset + off; if (sge_len > len) sge_len = len; + + if (dst->sg.end) + sgd = sk_msg_elem(dst, dst->sg.end - 1); + + if (sgd && + (sg_page(sge) == sg_page(sgd)) && + (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) { + sgd->length += sge_len; + dst->sg.size += sge_len; + } else if (!sk_msg_full(dst)) { + sge_off = sge->offset + off; + sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off); + } else { + return -ENOSPC; + } + off = 0; len -= sge_len; - sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off); sk_mem_charge(sk, sge_len); sk_msg_iter_var_next(i); if (i == src->sg.end && len) @@ -545,8 +554,8 @@ static void sk_psock_destroy_deferred(struct work_struct *gc) struct sk_psock *psock = container_of(gc, struct sk_psock, gc); /* No sk_callback_lock since already detached. */ - if (psock->parser.enabled) - strp_done(&psock->parser.strp); + strp_stop(&psock->parser.strp); + strp_done(&psock->parser.strp); cancel_work_sync(&psock->work); diff --git a/net/core/sock.c b/net/core/sock.c index 6aa2e7e0b4fb..782343bb925b 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -335,14 +335,68 @@ int __sk_backlog_rcv(struct sock *sk, struct sk_buff *skb) } EXPORT_SYMBOL(__sk_backlog_rcv); -static int sock_set_timeout(long *timeo_p, char __user *optval, int optlen) +static int sock_get_timeout(long timeo, void *optval, bool old_timeval) { - struct timeval tv; + struct __kernel_sock_timeval tv; + int size; - if (optlen < sizeof(tv)) - return -EINVAL; - if (copy_from_user(&tv, optval, sizeof(tv))) - return -EFAULT; + if (timeo == MAX_SCHEDULE_TIMEOUT) { + tv.tv_sec = 0; + tv.tv_usec = 0; + } else { + tv.tv_sec = timeo / HZ; + tv.tv_usec = ((timeo % HZ) * USEC_PER_SEC) / HZ; + } + + if (in_compat_syscall() && !COMPAT_USE_64BIT_TIME) { + struct old_timeval32 tv32 = { tv.tv_sec, tv.tv_usec }; + *(struct old_timeval32 *)optval = tv32; + return sizeof(tv32); + } + + if (old_timeval) { + struct __kernel_old_timeval old_tv; + old_tv.tv_sec = tv.tv_sec; + old_tv.tv_usec = tv.tv_usec; + *(struct __kernel_old_timeval *)optval = old_tv; + size = sizeof(old_tv); + } else { + *(struct __kernel_sock_timeval *)optval = tv; + size = sizeof(tv); + } + + return size; +} + +static int sock_set_timeout(long *timeo_p, char __user *optval, int optlen, bool old_timeval) +{ + struct __kernel_sock_timeval tv; + + if (in_compat_syscall() && !COMPAT_USE_64BIT_TIME) { + struct old_timeval32 tv32; + + if (optlen < sizeof(tv32)) + return -EINVAL; + + if (copy_from_user(&tv32, optval, sizeof(tv32))) + return -EFAULT; + tv.tv_sec = tv32.tv_sec; + tv.tv_usec = tv32.tv_usec; + } else if (old_timeval) { + struct __kernel_old_timeval old_tv; + + if (optlen < sizeof(old_tv)) + return -EINVAL; + if (copy_from_user(&old_tv, optval, sizeof(old_tv))) + return -EFAULT; + tv.tv_sec = old_tv.tv_sec; + tv.tv_usec = old_tv.tv_usec; + } else { + if (optlen < sizeof(tv)) + return -EINVAL; + if (copy_from_user(&tv, optval, sizeof(tv))) + return -EFAULT; + } if (tv.tv_usec < 0 || tv.tv_usec >= USEC_PER_SEC) return -EDOM; @@ -360,8 +414,8 @@ static int sock_set_timeout(long *timeo_p, char __user *optval, int optlen) *timeo_p = MAX_SCHEDULE_TIMEOUT; if (tv.tv_sec == 0 && tv.tv_usec == 0) return 0; - if (tv.tv_sec < (MAX_SCHEDULE_TIMEOUT/HZ - 1)) - *timeo_p = tv.tv_sec * HZ + DIV_ROUND_UP(tv.tv_usec, USEC_PER_SEC / HZ); + if (tv.tv_sec < (MAX_SCHEDULE_TIMEOUT / HZ - 1)) + *timeo_p = tv.tv_sec * HZ + DIV_ROUND_UP((unsigned long)tv.tv_usec, USEC_PER_SEC / HZ); return 0; } @@ -520,14 +574,11 @@ struct dst_entry *sk_dst_check(struct sock *sk, u32 cookie) } EXPORT_SYMBOL(sk_dst_check); -static int sock_setbindtodevice(struct sock *sk, char __user *optval, - int optlen) +static int sock_setbindtodevice_locked(struct sock *sk, int ifindex) { int ret = -ENOPROTOOPT; #ifdef CONFIG_NETDEVICES struct net *net = sock_net(sk); - char devname[IFNAMSIZ]; - int index; /* Sorry... */ ret = -EPERM; @@ -535,6 +586,32 @@ static int sock_setbindtodevice(struct sock *sk, char __user *optval, goto out; ret = -EINVAL; + if (ifindex < 0) + goto out; + + sk->sk_bound_dev_if = ifindex; + if (sk->sk_prot->rehash) + sk->sk_prot->rehash(sk); + sk_dst_reset(sk); + + ret = 0; + +out: +#endif + + return ret; +} + +static int sock_setbindtodevice(struct sock *sk, char __user *optval, + int optlen) +{ + int ret = -ENOPROTOOPT; +#ifdef CONFIG_NETDEVICES + struct net *net = sock_net(sk); + char devname[IFNAMSIZ]; + int index; + + ret = -EINVAL; if (optlen < 0) goto out; @@ -566,14 +643,9 @@ static int sock_setbindtodevice(struct sock *sk, char __user *optval, } lock_sock(sk); - sk->sk_bound_dev_if = index; - if (sk->sk_prot->rehash) - sk->sk_prot->rehash(sk); - sk_dst_reset(sk); + ret = sock_setbindtodevice_locked(sk, index); release_sock(sk); - ret = 0; - out: #endif @@ -713,6 +785,10 @@ int sock_setsockopt(struct socket *sock, int level, int optname, */ val = min_t(u32, val, sysctl_wmem_max); set_sndbuf: + /* Ensure val * 2 fits into an int, to prevent max_t() + * from treating it as a negative value. + */ + val = min_t(int, val, INT_MAX / 2); sk->sk_userlocks |= SOCK_SNDBUF_LOCK; sk->sk_sndbuf = max_t(int, val * 2, SOCK_MIN_SNDBUF); /* Wake up sending tasks if we upped the value. */ @@ -724,6 +800,12 @@ set_sndbuf: ret = -EPERM; break; } + + /* No negative values (to prevent underflow, as val will be + * multiplied by 2). + */ + if (val < 0) + val = 0; goto set_sndbuf; case SO_RCVBUF: @@ -734,6 +816,10 @@ set_sndbuf: */ val = min_t(u32, val, sysctl_rmem_max); set_rcvbuf: + /* Ensure val * 2 fits into an int, to prevent max_t() + * from treating it as a negative value. + */ + val = min_t(int, val, INT_MAX / 2); sk->sk_userlocks |= SOCK_RCVBUF_LOCK; /* * We double it on the way in to account for @@ -758,6 +844,12 @@ set_rcvbuf: ret = -EPERM; break; } + + /* No negative values (to prevent underflow, as val will be + * multiplied by 2). + */ + if (val < 0) + val = 0; goto set_rcvbuf; case SO_KEEPALIVE: @@ -815,10 +907,17 @@ set_rcvbuf: clear_bit(SOCK_PASSCRED, &sock->flags); break; - case SO_TIMESTAMP: - case SO_TIMESTAMPNS: + case SO_TIMESTAMP_OLD: + case SO_TIMESTAMP_NEW: + case SO_TIMESTAMPNS_OLD: + case SO_TIMESTAMPNS_NEW: if (valbool) { - if (optname == SO_TIMESTAMP) + if (optname == SO_TIMESTAMP_NEW || optname == SO_TIMESTAMPNS_NEW) + sock_set_flag(sk, SOCK_TSTAMP_NEW); + else + sock_reset_flag(sk, SOCK_TSTAMP_NEW); + + if (optname == SO_TIMESTAMP_OLD || optname == SO_TIMESTAMP_NEW) sock_reset_flag(sk, SOCK_RCVTSTAMPNS); else sock_set_flag(sk, SOCK_RCVTSTAMPNS); @@ -827,10 +926,14 @@ set_rcvbuf: } else { sock_reset_flag(sk, SOCK_RCVTSTAMP); sock_reset_flag(sk, SOCK_RCVTSTAMPNS); + sock_reset_flag(sk, SOCK_TSTAMP_NEW); } break; - case SO_TIMESTAMPING: + case SO_TIMESTAMPING_NEW: + sock_set_flag(sk, SOCK_TSTAMP_NEW); + /* fall through */ + case SO_TIMESTAMPING_OLD: if (val & ~SOF_TIMESTAMPING_MASK) { ret = -EINVAL; break; @@ -861,9 +964,13 @@ set_rcvbuf: if (val & SOF_TIMESTAMPING_RX_SOFTWARE) sock_enable_timestamp(sk, SOCK_TIMESTAMPING_RX_SOFTWARE); - else + else { + if (optname == SO_TIMESTAMPING_NEW) + sock_reset_flag(sk, SOCK_TSTAMP_NEW); + sock_disable_timestamp(sk, (1UL << SOCK_TIMESTAMPING_RX_SOFTWARE)); + } break; case SO_RCVLOWAT: @@ -875,12 +982,14 @@ set_rcvbuf: sk->sk_rcvlowat = val ? : 1; break; - case SO_RCVTIMEO: - ret = sock_set_timeout(&sk->sk_rcvtimeo, optval, optlen); + case SO_RCVTIMEO_OLD: + case SO_RCVTIMEO_NEW: + ret = sock_set_timeout(&sk->sk_rcvtimeo, optval, optlen, optname == SO_RCVTIMEO_OLD); break; - case SO_SNDTIMEO: - ret = sock_set_timeout(&sk->sk_sndtimeo, optval, optlen); + case SO_SNDTIMEO_OLD: + case SO_SNDTIMEO_NEW: + ret = sock_set_timeout(&sk->sk_sndtimeo, optval, optlen, optname == SO_SNDTIMEO_OLD); break; case SO_ATTACH_FILTER: @@ -999,15 +1108,23 @@ set_rcvbuf: #endif case SO_MAX_PACING_RATE: - if (val != ~0U) + { + unsigned long ulval = (val == ~0U) ? ~0UL : val; + + if (sizeof(ulval) != sizeof(val) && + optlen >= sizeof(ulval) && + get_user(ulval, (unsigned long __user *)optval)) { + ret = -EFAULT; + break; + } + if (ulval != ~0UL) cmpxchg(&sk->sk_pacing_status, SK_PACING_NONE, SK_PACING_NEEDED); - sk->sk_max_pacing_rate = (val == ~0U) ? ~0UL : val; - sk->sk_pacing_rate = min(sk->sk_pacing_rate, - sk->sk_max_pacing_rate); + sk->sk_max_pacing_rate = ulval; + sk->sk_pacing_rate = min(sk->sk_pacing_rate, ulval); break; - + } case SO_INCOMING_CPU: sk->sk_incoming_cpu = val; break; @@ -1055,6 +1172,10 @@ set_rcvbuf: } break; + case SO_BINDTOIFINDEX: + ret = sock_setbindtodevice_locked(sk, val); + break; + default: ret = -ENOPROTOOPT; break; @@ -1098,8 +1219,11 @@ int sock_getsockopt(struct socket *sock, int level, int optname, union { int val; u64 val64; + unsigned long ulval; struct linger ling; - struct timeval tm; + struct old_timeval32 tm32; + struct __kernel_old_timeval tm; + struct __kernel_sock_timeval stm; struct sock_txtime txtime; } v; @@ -1186,39 +1310,36 @@ int sock_getsockopt(struct socket *sock, int level, int optname, sock_warn_obsolete_bsdism("getsockopt"); break; - case SO_TIMESTAMP: + case SO_TIMESTAMP_OLD: v.val = sock_flag(sk, SOCK_RCVTSTAMP) && + !sock_flag(sk, SOCK_TSTAMP_NEW) && !sock_flag(sk, SOCK_RCVTSTAMPNS); break; - case SO_TIMESTAMPNS: - v.val = sock_flag(sk, SOCK_RCVTSTAMPNS); + case SO_TIMESTAMPNS_OLD: + v.val = sock_flag(sk, SOCK_RCVTSTAMPNS) && !sock_flag(sk, SOCK_TSTAMP_NEW); break; - case SO_TIMESTAMPING: + case SO_TIMESTAMP_NEW: + v.val = sock_flag(sk, SOCK_RCVTSTAMP) && sock_flag(sk, SOCK_TSTAMP_NEW); + break; + + case SO_TIMESTAMPNS_NEW: + v.val = sock_flag(sk, SOCK_RCVTSTAMPNS) && sock_flag(sk, SOCK_TSTAMP_NEW); + break; + + case SO_TIMESTAMPING_OLD: v.val = sk->sk_tsflags; break; - case SO_RCVTIMEO: - lv = sizeof(struct timeval); - if (sk->sk_rcvtimeo == MAX_SCHEDULE_TIMEOUT) { - v.tm.tv_sec = 0; - v.tm.tv_usec = 0; - } else { - v.tm.tv_sec = sk->sk_rcvtimeo / HZ; - v.tm.tv_usec = ((sk->sk_rcvtimeo % HZ) * USEC_PER_SEC) / HZ; - } + case SO_RCVTIMEO_OLD: + case SO_RCVTIMEO_NEW: + lv = sock_get_timeout(sk->sk_rcvtimeo, &v, SO_RCVTIMEO_OLD == optname); break; - case SO_SNDTIMEO: - lv = sizeof(struct timeval); - if (sk->sk_sndtimeo == MAX_SCHEDULE_TIMEOUT) { - v.tm.tv_sec = 0; - v.tm.tv_usec = 0; - } else { - v.tm.tv_sec = sk->sk_sndtimeo / HZ; - v.tm.tv_usec = ((sk->sk_sndtimeo % HZ) * USEC_PER_SEC) / HZ; - } + case SO_SNDTIMEO_OLD: + case SO_SNDTIMEO_NEW: + lv = sock_get_timeout(sk->sk_sndtimeo, &v, SO_SNDTIMEO_OLD == optname); break; case SO_RCVLOWAT: @@ -1344,8 +1465,13 @@ int sock_getsockopt(struct socket *sock, int level, int optname, #endif case SO_MAX_PACING_RATE: - /* 32bit version */ - v.val = min_t(unsigned long, sk->sk_max_pacing_rate, ~0U); + if (sizeof(v.ulval) != sizeof(v.val) && len >= sizeof(v.ulval)) { + lv = sizeof(v.ulval); + v.ulval = sk->sk_max_pacing_rate; + } else { + /* 32bit version */ + v.val = min_t(unsigned long, sk->sk_max_pacing_rate, ~0U); + } break; case SO_INCOMING_CPU: @@ -1399,6 +1525,10 @@ int sock_getsockopt(struct socket *sock, int level, int optname, SOF_TXTIME_REPORT_ERRORS : 0; break; + case SO_BINDTOIFINDEX: + v.val = sk->sk_bound_dev_if; + break; + default: /* We implement the SO_SNDLOWAT etc to not be settable * (1003.1g 7). @@ -1726,7 +1856,6 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority) newsk->sk_err_soft = 0; newsk->sk_priority = 0; newsk->sk_incoming_cpu = raw_smp_processor_id(); - atomic64_set(&newsk->sk_cookie, 0); if (likely(newsk->sk_net_refcnt)) sock_inuse_add(sock_net(newsk), 1); @@ -1750,7 +1879,7 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority) */ sk_refcnt_debug_inc(newsk); sk_set_socket(newsk, NULL); - newsk->sk_wq = NULL; + RCU_INIT_POINTER(newsk->sk_wq, NULL); if (newsk->sk_prot->sockets_allocated) sk_sockets_allocated_inc(newsk); @@ -2122,7 +2251,7 @@ int __sock_cmsg_send(struct sock *sk, struct msghdr *msg, struct cmsghdr *cmsg, return -EINVAL; sockc->mark = *(u32 *)CMSG_DATA(cmsg); break; - case SO_TIMESTAMPING: + case SO_TIMESTAMPING_OLD: if (cmsg->cmsg_len != CMSG_LEN(sizeof(u32))) return -EINVAL; @@ -2380,7 +2509,7 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind) } if (sk_has_memory_pressure(sk)) { - int alloc; + u64 alloc; if (!sk_under_memory_pressure(sk)) return 1; @@ -2713,11 +2842,11 @@ void sock_init_data(struct socket *sock, struct sock *sk) if (sock) { sk->sk_type = sock->type; - sk->sk_wq = sock->wq; + RCU_INIT_POINTER(sk->sk_wq, sock->wq); sock->sk = sk; sk->sk_uid = SOCK_INODE(sock)->i_uid; } else { - sk->sk_wq = NULL; + RCU_INIT_POINTER(sk->sk_wq, NULL); sk->sk_uid = make_kuid(sock_net(sk)->user_ns, 0); } diff --git a/net/core/sysctl_net_core.c b/net/core/sysctl_net_core.c index d67ec17f2cc8..84bf2861f45f 100644 --- a/net/core/sysctl_net_core.c +++ b/net/core/sysctl_net_core.c @@ -36,6 +36,15 @@ static int net_msg_warn; /* Unused, but still a sysctl */ int sysctl_fb_tunnels_only_for_init_net __read_mostly = 0; EXPORT_SYMBOL(sysctl_fb_tunnels_only_for_init_net); +/* 0 - Keep current behavior: + * IPv4: inherit all current settings from init_net + * IPv6: reset all settings to default + * 1 - Both inherit all current settings from init_net + * 2 - Both reset all settings to default + */ +int sysctl_devconf_inherit_init_net __read_mostly; +EXPORT_SYMBOL(sysctl_devconf_inherit_init_net); + #ifdef CONFIG_RPS static int rps_sock_flow_sysctl(struct ctl_table *table, int write, void __user *buffer, size_t *lenp, loff_t *ppos) @@ -544,6 +553,15 @@ static struct ctl_table net_core_table[] = { .extra1 = &zero, .extra2 = &one, }, + { + .procname = "devconf_inherit_init_net", + .data = &sysctl_devconf_inherit_init_net, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &zero, + .extra2 = &two, + }, { } }; diff --git a/net/dccp/ccid.h b/net/dccp/ccid.h index 6eb837a47b5c..baaaeb2b2c42 100644 --- a/net/dccp/ccid.h +++ b/net/dccp/ccid.h @@ -202,7 +202,7 @@ static inline void ccid_hc_tx_packet_recv(struct ccid *ccid, struct sock *sk, static inline int ccid_hc_tx_parse_options(struct ccid *ccid, struct sock *sk, u8 pkt, u8 opt, u8 *val, u8 len) { - if (ccid->ccid_ops->ccid_hc_tx_parse_options == NULL) + if (!ccid || !ccid->ccid_ops->ccid_hc_tx_parse_options) return 0; return ccid->ccid_ops->ccid_hc_tx_parse_options(sk, pkt, opt, val, len); } @@ -214,7 +214,7 @@ static inline int ccid_hc_tx_parse_options(struct ccid *ccid, struct sock *sk, static inline int ccid_hc_rx_parse_options(struct ccid *ccid, struct sock *sk, u8 pkt, u8 opt, u8 *val, u8 len) { - if (ccid->ccid_ops->ccid_hc_rx_parse_options == NULL) + if (!ccid || !ccid->ccid_ops->ccid_hc_rx_parse_options) return 0; return ccid->ccid_ops->ccid_hc_rx_parse_options(sk, pkt, opt, val, len); } diff --git a/net/dccp/input.c b/net/dccp/input.c index 85d6c879383d..8d03707abdac 100644 --- a/net/dccp/input.c +++ b/net/dccp/input.c @@ -480,7 +480,7 @@ static int dccp_rcv_request_sent_state_process(struct sock *sk, sk_wake_async(sk, SOCK_WAKE_IO, POLL_OUT); } - if (sk->sk_write_pending || icsk->icsk_ack.pingpong || + if (sk->sk_write_pending || inet_csk_in_pingpong_mode(sk) || icsk->icsk_accept_queue.rskq_defer_accept) { /* Save one ACK. Data will be ready after * several ticks, if write_pending is set. diff --git a/net/dccp/timer.c b/net/dccp/timer.c index 1501a20a94ca..74e138495d67 100644 --- a/net/dccp/timer.c +++ b/net/dccp/timer.c @@ -199,7 +199,7 @@ static void dccp_delack_timer(struct timer_list *t) icsk->icsk_ack.pending &= ~ICSK_ACK_TIMER; if (inet_csk_ack_scheduled(sk)) { - if (!icsk->icsk_ack.pingpong) { + if (!inet_csk_in_pingpong_mode(sk)) { /* Delayed ACK missed: inflate ATO. */ icsk->icsk_ack.ato = min(icsk->icsk_ack.ato << 1, icsk->icsk_rto); @@ -207,7 +207,7 @@ static void dccp_delack_timer(struct timer_list *t) /* Delayed ACK missed: leave pingpong mode and * deflate ATO. */ - icsk->icsk_ack.pingpong = 0; + inet_csk_exit_pingpong_mode(sk); icsk->icsk_ack.ato = TCP_ATO_MIN; } dccp_send_ack(sk); diff --git a/net/decnet/dn_dev.c b/net/decnet/dn_dev.c index d0b3e69c6b39..0962f9201baa 100644 --- a/net/decnet/dn_dev.c +++ b/net/decnet/dn_dev.c @@ -56,7 +56,7 @@ #include <net/dn_neigh.h> #include <net/dn_fib.h> -#define DN_IFREQ_SIZE (sizeof(struct ifreq) - sizeof(struct sockaddr) + sizeof(struct sockaddr_dn)) +#define DN_IFREQ_SIZE (offsetof(struct ifreq, ifr_ifru) + sizeof(struct sockaddr_dn)) static char dn_rt_all_end_mcast[ETH_ALEN] = {0xAB,0x00,0x00,0x04,0x00,0x00}; static char dn_rt_all_rt_mcast[ETH_ALEN] = {0xAB,0x00,0x00,0x03,0x00,0x00}; diff --git a/net/decnet/dn_fib.c b/net/decnet/dn_fib.c index f78fe58eafc8..6cd3737593a6 100644 --- a/net/decnet/dn_fib.c +++ b/net/decnet/dn_fib.c @@ -282,7 +282,7 @@ struct dn_fib_info *dn_fib_create_info(const struct rtmsg *r, struct nlattr *att (nhs = dn_fib_count_nhs(attrs[RTA_MULTIPATH])) == 0) goto err_inval; - fi = kzalloc(sizeof(*fi)+nhs*sizeof(struct dn_fib_nh), GFP_KERNEL); + fi = kzalloc(struct_size(fi, fib_nh, nhs), GFP_KERNEL); err = -ENOBUFS; if (fi == NULL) goto failure; diff --git a/net/dsa/Kconfig b/net/dsa/Kconfig index 91e52973ee13..fab49132345f 100644 --- a/net/dsa/Kconfig +++ b/net/dsa/Kconfig @@ -6,7 +6,7 @@ config HAVE_NET_DSA config NET_DSA tristate "Distributed Switch Architecture" - depends on HAVE_NET_DSA && MAY_USE_DEVLINK + depends on HAVE_NET_DSA depends on BRIDGE || BRIDGE=n select NET_SWITCHDEV select PHYLINK diff --git a/net/dsa/dsa.c b/net/dsa/dsa.c index aee909bcddc4..36de4f2a3366 100644 --- a/net/dsa/dsa.c +++ b/net/dsa/dsa.c @@ -57,6 +57,7 @@ const struct dsa_device_ops *dsa_device_ops[DSA_TAG_LAST] = { #endif #ifdef CONFIG_NET_DSA_TAG_KSZ9477 [DSA_TAG_PROTO_KSZ9477] = &ksz9477_netdev_ops, + [DSA_TAG_PROTO_KSZ9893] = &ksz9893_netdev_ops, #endif #ifdef CONFIG_NET_DSA_TAG_LAN9303 [DSA_TAG_PROTO_LAN9303] = &lan9303_netdev_ops, @@ -93,6 +94,7 @@ const char *dsa_tag_protocol_to_str(const struct dsa_device_ops *ops) #endif #ifdef CONFIG_NET_DSA_TAG_KSZ9477 [DSA_TAG_PROTO_KSZ9477] = "ksz9477", + [DSA_TAG_PROTO_KSZ9893] = "ksz9893", #endif #ifdef CONFIG_NET_DSA_TAG_LAN9303 [DSA_TAG_PROTO_LAN9303] = "lan9303", diff --git a/net/dsa/dsa2.c b/net/dsa/dsa2.c index a1917025e155..c00ee464afc7 100644 --- a/net/dsa/dsa2.c +++ b/net/dsa/dsa2.c @@ -612,8 +612,8 @@ static int dsa_switch_parse_ports_of(struct dsa_switch *ds, { struct device_node *ports, *port; struct dsa_port *dp; + int err = 0; u32 reg; - int err; ports = of_get_child_by_name(dn, "ports"); if (!ports) { @@ -624,19 +624,23 @@ static int dsa_switch_parse_ports_of(struct dsa_switch *ds, for_each_available_child_of_node(ports, port) { err = of_property_read_u32(port, "reg", ®); if (err) - return err; + goto out_put_node; - if (reg >= ds->num_ports) - return -EINVAL; + if (reg >= ds->num_ports) { + err = -EINVAL; + goto out_put_node; + } dp = &ds->ports[reg]; err = dsa_port_parse_of(dp, port); if (err) - return err; + goto out_put_node; } - return 0; +out_put_node: + of_node_put(ports); + return err; } static int dsa_switch_parse_member_of(struct dsa_switch *ds, @@ -767,11 +771,10 @@ static int dsa_switch_probe(struct dsa_switch *ds) struct dsa_switch *dsa_switch_alloc(struct device *dev, size_t n) { - size_t size = sizeof(struct dsa_switch) + n * sizeof(struct dsa_port); struct dsa_switch *ds; int i; - ds = devm_kzalloc(dev, size, GFP_KERNEL); + ds = devm_kzalloc(dev, struct_size(ds, ports, n), GFP_KERNEL); if (!ds) return NULL; diff --git a/net/dsa/dsa_priv.h b/net/dsa/dsa_priv.h index 026a05774bf7..093b7d145eb1 100644 --- a/net/dsa/dsa_priv.h +++ b/net/dsa/dsa_priv.h @@ -103,7 +103,8 @@ static inline void dsa_legacy_unregister(void) { } int dsa_legacy_fdb_add(struct ndmsg *ndm, struct nlattr *tb[], struct net_device *dev, const unsigned char *addr, u16 vid, - u16 flags); + u16 flags, + struct netlink_ext_ack *extack); int dsa_legacy_fdb_del(struct ndmsg *ndm, struct nlattr *tb[], struct net_device *dev, const unsigned char *addr, u16 vid); @@ -142,7 +143,7 @@ static inline struct net_device *dsa_master_find_slave(struct net_device *dev, int dsa_port_set_state(struct dsa_port *dp, u8 state, struct switchdev_trans *trans); int dsa_port_enable(struct dsa_port *dp, struct phy_device *phy); -void dsa_port_disable(struct dsa_port *dp, struct phy_device *phy); +void dsa_port_disable(struct dsa_port *dp); int dsa_port_bridge_join(struct dsa_port *dp, struct net_device *br); void dsa_port_bridge_leave(struct dsa_port *dp, struct net_device *br); int dsa_port_vlan_filtering(struct dsa_port *dp, bool vlan_filtering, @@ -159,6 +160,10 @@ int dsa_port_mdb_add(const struct dsa_port *dp, struct switchdev_trans *trans); int dsa_port_mdb_del(const struct dsa_port *dp, const struct switchdev_obj_port_mdb *mdb); +int dsa_port_pre_bridge_flags(const struct dsa_port *dp, unsigned long flags, + struct switchdev_trans *trans); +int dsa_port_bridge_flags(const struct dsa_port *dp, unsigned long flags, + struct switchdev_trans *trans); int dsa_port_vlan_add(struct dsa_port *dp, const struct switchdev_obj_port_vlan *vlan, struct switchdev_trans *trans); @@ -211,6 +216,7 @@ extern const struct dsa_device_ops gswip_netdev_ops; /* tag_ksz.c */ extern const struct dsa_device_ops ksz9477_netdev_ops; +extern const struct dsa_device_ops ksz9893_netdev_ops; /* tag_lan9303.c */ extern const struct dsa_device_ops lan9303_netdev_ops; diff --git a/net/dsa/master.c b/net/dsa/master.c index 71bb15f491c8..c58f33931be1 100644 --- a/net/dsa/master.c +++ b/net/dsa/master.c @@ -126,6 +126,17 @@ static void dsa_master_get_strings(struct net_device *dev, uint32_t stringset, } } +static int dsa_master_get_phys_port_name(struct net_device *dev, + char *name, size_t len) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + + if (snprintf(name, len, "p%d", cpu_dp->index) >= len) + return -EINVAL; + + return 0; +} + static int dsa_master_ethtool_setup(struct net_device *dev) { struct dsa_port *cpu_dp = dev->dsa_ptr; @@ -158,6 +169,38 @@ static void dsa_master_ethtool_teardown(struct net_device *dev) cpu_dp->orig_ethtool_ops = NULL; } +static int dsa_master_ndo_setup(struct net_device *dev) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + struct dsa_switch *ds = cpu_dp->ds; + struct net_device_ops *ops; + + if (dev->netdev_ops->ndo_get_phys_port_name) + return 0; + + ops = devm_kzalloc(ds->dev, sizeof(*ops), GFP_KERNEL); + if (!ops) + return -ENOMEM; + + cpu_dp->orig_ndo_ops = dev->netdev_ops; + if (cpu_dp->orig_ndo_ops) + memcpy(ops, cpu_dp->orig_ndo_ops, sizeof(*ops)); + + ops->ndo_get_phys_port_name = dsa_master_get_phys_port_name; + + dev->netdev_ops = ops; + + return 0; +} + +static void dsa_master_ndo_teardown(struct net_device *dev) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + + dev->netdev_ops = cpu_dp->orig_ndo_ops; + cpu_dp->orig_ndo_ops = NULL; +} + static ssize_t tagging_show(struct device *d, struct device_attribute *attr, char *buf) { @@ -205,6 +248,8 @@ static void dsa_master_reset_mtu(struct net_device *dev) rtnl_unlock(); } +static struct lock_class_key dsa_master_addr_list_lock_key; + int dsa_master_setup(struct net_device *dev, struct dsa_port *cpu_dp) { int ret; @@ -218,21 +263,34 @@ int dsa_master_setup(struct net_device *dev, struct dsa_port *cpu_dp) wmb(); dev->dsa_ptr = cpu_dp; + lockdep_set_class(&dev->addr_list_lock, + &dsa_master_addr_list_lock_key); ret = dsa_master_ethtool_setup(dev); if (ret) return ret; + ret = dsa_master_ndo_setup(dev); + if (ret) + goto out_err_ethtool_teardown; + ret = sysfs_create_group(&dev->dev.kobj, &dsa_group); if (ret) - dsa_master_ethtool_teardown(dev); + goto out_err_ndo_teardown; + + return ret; +out_err_ndo_teardown: + dsa_master_ndo_teardown(dev); +out_err_ethtool_teardown: + dsa_master_ethtool_teardown(dev); return ret; } void dsa_master_teardown(struct net_device *dev) { sysfs_remove_group(&dev->dev.kobj, &dsa_group); + dsa_master_ndo_teardown(dev); dsa_master_ethtool_teardown(dev); dsa_master_reset_mtu(dev); diff --git a/net/dsa/port.c b/net/dsa/port.c index 2d7e01b23572..caeef4c99dc0 100644 --- a/net/dsa/port.c +++ b/net/dsa/port.c @@ -69,7 +69,6 @@ static void dsa_port_set_state_now(struct dsa_port *dp, u8 state) int dsa_port_enable(struct dsa_port *dp, struct phy_device *phy) { - u8 stp_state = dp->bridge_dev ? BR_STATE_BLOCKING : BR_STATE_FORWARDING; struct dsa_switch *ds = dp->ds; int port = dp->index; int err; @@ -80,20 +79,22 @@ int dsa_port_enable(struct dsa_port *dp, struct phy_device *phy) return err; } - dsa_port_set_state_now(dp, stp_state); + if (!dp->bridge_dev) + dsa_port_set_state_now(dp, BR_STATE_FORWARDING); return 0; } -void dsa_port_disable(struct dsa_port *dp, struct phy_device *phy) +void dsa_port_disable(struct dsa_port *dp) { struct dsa_switch *ds = dp->ds; int port = dp->index; - dsa_port_set_state_now(dp, BR_STATE_DISABLED); + if (!dp->bridge_dev) + dsa_port_set_state_now(dp, BR_STATE_DISABLED); if (ds->ops->port_disable) - ds->ops->port_disable(ds, port, phy); + ds->ops->port_disable(ds, port); } int dsa_port_bridge_join(struct dsa_port *dp, struct net_device *br) @@ -105,16 +106,23 @@ int dsa_port_bridge_join(struct dsa_port *dp, struct net_device *br) }; int err; - /* Here the port is already bridged. Reflect the current configuration - * so that drivers can program their chips accordingly. + /* Set the flooding mode before joining the port in the switch */ + err = dsa_port_bridge_flags(dp, BR_FLOOD | BR_MCAST_FLOOD, NULL); + if (err) + return err; + + /* Here the interface is already bridged. Reflect the current + * configuration so that drivers can program their chips accordingly. */ dp->bridge_dev = br; err = dsa_port_notify(dp, DSA_NOTIFIER_BRIDGE_JOIN, &info); /* The bridging is rolled back on error */ - if (err) + if (err) { + dsa_port_bridge_flags(dp, 0, NULL); dp->bridge_dev = NULL; + } return err; } @@ -137,6 +145,9 @@ void dsa_port_bridge_leave(struct dsa_port *dp, struct net_device *br) if (err) pr_err("DSA: failed to notify DSA_NOTIFIER_BRIDGE_LEAVE\n"); + /* Port is leaving the bridge, disable flooding */ + dsa_port_bridge_flags(dp, 0, NULL); + /* Port left the bridge, put in BR_STATE_DISABLED by the bridge layer, * so allow it to be in BR_STATE_FORWARDING to be kept functional */ @@ -177,6 +188,35 @@ int dsa_port_ageing_time(struct dsa_port *dp, clock_t ageing_clock, return dsa_port_notify(dp, DSA_NOTIFIER_AGEING_TIME, &info); } +int dsa_port_pre_bridge_flags(const struct dsa_port *dp, unsigned long flags, + struct switchdev_trans *trans) +{ + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->port_egress_floods || + (flags & ~(BR_FLOOD | BR_MCAST_FLOOD))) + return -EINVAL; + + return 0; +} + +int dsa_port_bridge_flags(const struct dsa_port *dp, unsigned long flags, + struct switchdev_trans *trans) +{ + struct dsa_switch *ds = dp->ds; + int port = dp->index; + int err = 0; + + if (switchdev_trans_ph_prepare(trans)) + return 0; + + if (ds->ops->port_egress_floods) + err = ds->ops->port_egress_floods(ds, port, flags & BR_FLOOD, + flags & BR_MCAST_FLOOD); + + return err; +} + int dsa_port_fdb_add(struct dsa_port *dp, const unsigned char *addr, u16 vid) { @@ -252,7 +292,10 @@ int dsa_port_vlan_add(struct dsa_port *dp, .vlan = vlan, }; - if (br_vlan_enabled(dp->bridge_dev)) + /* Can be called from dsa_slave_port_obj_add() or + * dsa_slave_vlan_rx_add_vid() + */ + if (!dp->bridge_dev || br_vlan_enabled(dp->bridge_dev)) return dsa_port_notify(dp, DSA_NOTIFIER_VLAN_ADD, &info); return 0; @@ -267,10 +310,13 @@ int dsa_port_vlan_del(struct dsa_port *dp, .vlan = vlan, }; - if (netif_is_bridge_master(vlan->obj.orig_dev)) + if (vlan->obj.orig_dev && netif_is_bridge_master(vlan->obj.orig_dev)) return -EOPNOTSUPP; - if (br_vlan_enabled(dp->bridge_dev)) + /* Can be called from dsa_slave_port_obj_del() or + * dsa_slave_vlan_rx_kill_vid() + */ + if (!dp->bridge_dev || br_vlan_enabled(dp->bridge_dev)) return dsa_port_notify(dp, DSA_NOTIFIER_VLAN_DEL, &info); return 0; @@ -291,6 +337,7 @@ static struct phy_device *dsa_port_get_phy_device(struct dsa_port *dp) return ERR_PTR(-EPROBE_DEFER); } + of_node_put(phy_dn); return phydev; } diff --git a/net/dsa/slave.c b/net/dsa/slave.c index a3fcc1d01615..093eef6f2599 100644 --- a/net/dsa/slave.c +++ b/net/dsa/slave.c @@ -122,7 +122,7 @@ static int dsa_slave_close(struct net_device *dev) phylink_stop(dp->pl); - dsa_port_disable(dp, dev->phydev); + dsa_port_disable(dp); dev_mc_unsync(master, dev); dev_uc_unsync(master, dev); @@ -140,11 +140,14 @@ static int dsa_slave_close(struct net_device *dev) static void dsa_slave_change_rx_flags(struct net_device *dev, int change) { struct net_device *master = dsa_slave_to_master(dev); - - if (change & IFF_ALLMULTI) - dev_set_allmulti(master, dev->flags & IFF_ALLMULTI ? 1 : -1); - if (change & IFF_PROMISC) - dev_set_promiscuity(master, dev->flags & IFF_PROMISC ? 1 : -1); + if (dev->flags & IFF_UP) { + if (change & IFF_ALLMULTI) + dev_set_allmulti(master, + dev->flags & IFF_ALLMULTI ? 1 : -1); + if (change & IFF_PROMISC) + dev_set_promiscuity(master, + dev->flags & IFF_PROMISC ? 1 : -1); + } } static void dsa_slave_set_rx_mode(struct net_device *dev) @@ -292,6 +295,13 @@ static int dsa_slave_port_attr_set(struct net_device *dev, case SWITCHDEV_ATTR_ID_BRIDGE_AGEING_TIME: ret = dsa_port_ageing_time(dp, attr->u.ageing_time, trans); break; + case SWITCHDEV_ATTR_ID_PORT_PRE_BRIDGE_FLAGS: + ret = dsa_port_pre_bridge_flags(dp, attr->u.brport_flags, + trans); + break; + case SWITCHDEV_ATTR_ID_PORT_BRIDGE_FLAGS: + ret = dsa_port_bridge_flags(dp, attr->u.brport_flags, trans); + break; default: ret = -EOPNOTSUPP; break; @@ -362,24 +372,15 @@ static int dsa_slave_port_obj_del(struct net_device *dev, return err; } -static int dsa_slave_port_attr_get(struct net_device *dev, - struct switchdev_attr *attr) +static int dsa_slave_get_port_parent_id(struct net_device *dev, + struct netdev_phys_item_id *ppid) { struct dsa_port *dp = dsa_slave_to_port(dev); struct dsa_switch *ds = dp->ds; struct dsa_switch_tree *dst = ds->dst; - switch (attr->id) { - case SWITCHDEV_ATTR_ID_PORT_PARENT_ID: - attr->u.ppid.id_len = sizeof(dst->index); - memcpy(&attr->u.ppid.id, &dst->index, attr->u.ppid.id_len); - break; - case SWITCHDEV_ATTR_ID_PORT_BRIDGE_FLAGS_SUPPORT: - attr->u.brport_flags_support = 0; - break; - default: - return -EOPNOTSUPP; - } + ppid->id_len = sizeof(dst->index); + memcpy(&ppid->id, &dst->index, ppid->id_len); return 0; } @@ -639,7 +640,7 @@ static int dsa_slave_set_eee(struct net_device *dev, struct ethtool_eee *e) int ret; /* Port's PHY and MAC both need to be EEE capable */ - if (!dev->phydev && !dp->pl) + if (!dev->phydev || !dp->pl) return -ENODEV; if (!ds->ops->set_mac_eee) @@ -659,7 +660,7 @@ static int dsa_slave_get_eee(struct net_device *dev, struct ethtool_eee *e) int ret; /* Port's PHY and MAC both need to be EEE capable */ - if (!dev->phydev && !dp->pl) + if (!dev->phydev || !dp->pl) return -ENODEV; if (!ds->ops->get_mac_eee) @@ -982,6 +983,75 @@ static int dsa_slave_get_ts_info(struct net_device *dev, return ds->ops->get_ts_info(ds, p->dp->index, ts); } +static int dsa_slave_vlan_rx_add_vid(struct net_device *dev, __be16 proto, + u16 vid) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct switchdev_obj_port_vlan vlan = { + .vid_begin = vid, + .vid_end = vid, + /* This API only allows programming tagged, non-PVID VIDs */ + .flags = 0, + }; + struct switchdev_trans trans; + struct bridge_vlan_info info; + int ret; + + /* Check for a possible bridge VLAN entry now since there is no + * need to emulate the switchdev prepare + commit phase. + */ + if (dp->bridge_dev) { + /* br_vlan_get_info() returns -EINVAL or -ENOENT if the + * device, respectively the VID is not found, returning + * 0 means success, which is a failure for us here. + */ + ret = br_vlan_get_info(dp->bridge_dev, vid, &info); + if (ret == 0) + return -EBUSY; + } + + trans.ph_prepare = true; + ret = dsa_port_vlan_add(dp, &vlan, &trans); + if (ret == -EOPNOTSUPP) + return 0; + + trans.ph_prepare = false; + return dsa_port_vlan_add(dp, &vlan, &trans); +} + +static int dsa_slave_vlan_rx_kill_vid(struct net_device *dev, __be16 proto, + u16 vid) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct switchdev_obj_port_vlan vlan = { + .vid_begin = vid, + .vid_end = vid, + /* This API only allows programming tagged, non-PVID VIDs */ + .flags = 0, + }; + struct bridge_vlan_info info; + int ret; + + /* Check for a possible bridge VLAN entry now since there is no + * need to emulate the switchdev prepare + commit phase. + */ + if (dp->bridge_dev) { + /* br_vlan_get_info() returns -EINVAL or -ENOENT if the + * device, respectively the VID is not found, returning + * 0 means success, which is a failure for us here. + */ + ret = br_vlan_get_info(dp->bridge_dev, vid, &info); + if (ret == 0) + return -EBUSY; + } + + ret = dsa_port_vlan_del(dp, &vlan); + if (ret == -EOPNOTSUPP) + ret = 0; + + return ret; +} + static const struct ethtool_ops dsa_slave_ethtool_ops = { .get_drvinfo = dsa_slave_get_drvinfo, .get_regs_len = dsa_slave_get_regs_len, @@ -1009,7 +1079,8 @@ static const struct ethtool_ops dsa_slave_ethtool_ops = { int dsa_legacy_fdb_add(struct ndmsg *ndm, struct nlattr *tb[], struct net_device *dev, const unsigned char *addr, u16 vid, - u16 flags) + u16 flags, + struct netlink_ext_ack *extack) { struct dsa_port *dp = dsa_slave_to_port(dev); @@ -1045,11 +1116,9 @@ static const struct net_device_ops dsa_slave_netdev_ops = { .ndo_get_phys_port_name = dsa_slave_get_phys_port_name, .ndo_setup_tc = dsa_slave_setup_tc, .ndo_get_stats64 = dsa_slave_get_stats64, -}; - -static const struct switchdev_ops dsa_slave_switchdev_ops = { - .switchdev_port_attr_get = dsa_slave_port_attr_get, - .switchdev_port_attr_set = dsa_slave_port_attr_set, + .ndo_get_port_parent_id = dsa_slave_get_port_parent_id, + .ndo_vlan_rx_add_vid = dsa_slave_vlan_rx_add_vid, + .ndo_vlan_rx_kill_vid = dsa_slave_vlan_rx_kill_vid, }; static struct device_type dsa_type = { @@ -1305,13 +1374,13 @@ int dsa_slave_create(struct dsa_port *port) if (slave_dev == NULL) return -ENOMEM; - slave_dev->features = master->vlan_features | NETIF_F_HW_TC; + slave_dev->features = master->vlan_features | NETIF_F_HW_TC | + NETIF_F_HW_VLAN_CTAG_FILTER; slave_dev->hw_features |= NETIF_F_HW_TC; slave_dev->ethtool_ops = &dsa_slave_ethtool_ops; eth_hw_addr_inherit(slave_dev, master); slave_dev->priv_flags |= IFF_NO_QUEUE; slave_dev->netdev_ops = &dsa_slave_netdev_ops; - slave_dev->switchdev_ops = &dsa_slave_switchdev_ops; slave_dev->min_mtu = 0; slave_dev->max_mtu = ETH_MAX_MTU; SET_NETDEV_DEVTYPE(slave_dev, &dsa_type); @@ -1406,20 +1475,66 @@ static int dsa_slave_changeupper(struct net_device *dev, return err; } +static int dsa_slave_upper_vlan_check(struct net_device *dev, + struct netdev_notifier_changeupper_info * + info) +{ + struct netlink_ext_ack *ext_ack; + struct net_device *slave; + struct dsa_port *dp; + + ext_ack = netdev_notifier_info_to_extack(&info->info); + + if (!is_vlan_dev(dev)) + return NOTIFY_DONE; + + slave = vlan_dev_real_dev(dev); + if (!dsa_slave_dev_check(slave)) + return NOTIFY_DONE; + + dp = dsa_slave_to_port(slave); + if (!dp->bridge_dev) + return NOTIFY_DONE; + + /* Deny enslaving a VLAN device into a VLAN-aware bridge */ + if (br_vlan_enabled(dp->bridge_dev) && + netif_is_bridge_master(info->upper_dev) && info->linking) { + NL_SET_ERR_MSG_MOD(ext_ack, + "Cannot enslave VLAN device into VLAN aware bridge"); + return notifier_from_errno(-EINVAL); + } + + return NOTIFY_DONE; +} + static int dsa_slave_netdevice_event(struct notifier_block *nb, unsigned long event, void *ptr) { struct net_device *dev = netdev_notifier_info_to_dev(ptr); - if (!dsa_slave_dev_check(dev)) - return NOTIFY_DONE; + if (event == NETDEV_CHANGEUPPER) { + if (!dsa_slave_dev_check(dev)) + return dsa_slave_upper_vlan_check(dev, ptr); - if (event == NETDEV_CHANGEUPPER) return dsa_slave_changeupper(dev, ptr); + } return NOTIFY_DONE; } +static int +dsa_slave_switchdev_port_attr_set_event(struct net_device *netdev, + struct switchdev_notifier_port_attr_info *port_attr_info) +{ + int err; + + err = dsa_slave_port_attr_set(netdev, port_attr_info->attr, + port_attr_info->trans); + + port_attr_info->handled = true; + return notifier_from_errno(err); +} + struct dsa_switchdev_event_work { struct work_struct work; struct switchdev_notifier_fdb_info fdb_info; @@ -1450,7 +1565,7 @@ static void dsa_slave_switchdev_event_work(struct work_struct *work) } fdb_info->offloaded = true; call_switchdev_notifiers(SWITCHDEV_FDB_OFFLOADED, dev, - &fdb_info->info); + &fdb_info->info, NULL); break; case SWITCHDEV_FDB_DEL_TO_DEVICE: @@ -1498,6 +1613,9 @@ static int dsa_slave_switchdev_event(struct notifier_block *unused, if (!dsa_slave_dev_check(dev)) return NOTIFY_DONE; + if (event == SWITCHDEV_PORT_ATTR_SET) + return dsa_slave_switchdev_port_attr_set_event(dev, ptr); + switchdev_work = kzalloc(sizeof(*switchdev_work), GFP_ATOMIC); if (!switchdev_work) return NOTIFY_BAD; @@ -1560,6 +1678,8 @@ static int dsa_slave_switchdev_blocking_event(struct notifier_block *unused, case SWITCHDEV_PORT_OBJ_ADD: /* fall through */ case SWITCHDEV_PORT_OBJ_DEL: return dsa_slave_switchdev_port_obj_event(event, dev, ptr); + case SWITCHDEV_PORT_ATTR_SET: + return dsa_slave_switchdev_port_attr_set_event(dev, ptr); } return NOTIFY_DONE; diff --git a/net/dsa/switch.c b/net/dsa/switch.c index 142b294d3446..e1fae969aa73 100644 --- a/net/dsa/switch.c +++ b/net/dsa/switch.c @@ -12,6 +12,7 @@ #include <linux/netdevice.h> #include <linux/notifier.h> +#include <linux/if_vlan.h> #include <net/switchdev.h> #include "dsa_priv.h" @@ -168,6 +169,43 @@ static int dsa_switch_mdb_del(struct dsa_switch *ds, return 0; } +static int dsa_port_vlan_device_check(struct net_device *vlan_dev, + int vlan_dev_vid, + void *arg) +{ + struct switchdev_obj_port_vlan *vlan = arg; + u16 vid; + + for (vid = vlan->vid_begin; vid <= vlan->vid_end; ++vid) { + if (vid == vlan_dev_vid) + return -EBUSY; + } + + return 0; +} + +static int dsa_port_vlan_check(struct dsa_switch *ds, int port, + const struct switchdev_obj_port_vlan *vlan) +{ + const struct dsa_port *dp = dsa_to_port(ds, port); + int err = 0; + + /* Device is not bridged, let it proceed with the VLAN device + * creation. + */ + if (!dp->bridge_dev) + return err; + + /* dsa_slave_vlan_rx_{add,kill}_vid() cannot use the prepare pharse and + * already checks whether there is an overlapping bridge VLAN entry + * with the same VID, so here we only need to check that if we are + * adding a bridge VLAN entry there is not an overlapping VLAN device + * claiming that VID. + */ + return vlan_for_each(dp->slave, dsa_port_vlan_device_check, + (void *)vlan); +} + static int dsa_switch_vlan_prepare_bitmap(struct dsa_switch *ds, const struct switchdev_obj_port_vlan *vlan, @@ -179,6 +217,10 @@ dsa_switch_vlan_prepare_bitmap(struct dsa_switch *ds, return -EOPNOTSUPP; for_each_set_bit(port, bitmap, ds->num_ports) { + err = dsa_port_vlan_check(ds, port, vlan); + if (err) + return err; + err = ds->ops->port_vlan_prepare(ds, port, vlan); if (err) return err; diff --git a/net/dsa/tag_dsa.c b/net/dsa/tag_dsa.c index 8b2f92e3f3a2..67ff3fae18d8 100644 --- a/net/dsa/tag_dsa.c +++ b/net/dsa/tag_dsa.c @@ -146,8 +146,17 @@ static struct sk_buff *dsa_rcv(struct sk_buff *skb, struct net_device *dev, return skb; } +static int dsa_tag_flow_dissect(const struct sk_buff *skb, __be16 *proto, + int *offset) +{ + *offset = 4; + *proto = ((__be16 *)skb->data)[1]; + return 0; +} + const struct dsa_device_ops dsa_netdev_ops = { .xmit = dsa_xmit, .rcv = dsa_rcv, + .flow_dissect = dsa_tag_flow_dissect, .overhead = DSA_HLEN, }; diff --git a/net/dsa/tag_edsa.c b/net/dsa/tag_edsa.c index f5b87ee5c94e..234585ec116e 100644 --- a/net/dsa/tag_edsa.c +++ b/net/dsa/tag_edsa.c @@ -165,8 +165,17 @@ static struct sk_buff *edsa_rcv(struct sk_buff *skb, struct net_device *dev, return skb; } +static int edsa_tag_flow_dissect(const struct sk_buff *skb, __be16 *proto, + int *offset) +{ + *offset = 8; + *proto = ((__be16 *)skb->data)[3]; + return 0; +} + const struct dsa_device_ops edsa_netdev_ops = { .xmit = edsa_xmit, .rcv = edsa_rcv, + .flow_dissect = edsa_tag_flow_dissect, .overhead = EDSA_HLEN, }; diff --git a/net/dsa/tag_ksz.c b/net/dsa/tag_ksz.c index da71b9e2af52..de246c93d3bb 100644 --- a/net/dsa/tag_ksz.c +++ b/net/dsa/tag_ksz.c @@ -16,6 +16,7 @@ /* Typically only one byte is used for tail tag. */ #define KSZ_EGRESS_TAG_LEN 1 +#define KSZ_INGRESS_TAG_LEN 1 static struct sk_buff *ksz_common_xmit(struct sk_buff *skb, struct net_device *dev, int len) @@ -67,6 +68,8 @@ static struct sk_buff *ksz_common_rcv(struct sk_buff *skb, pskb_trim_rcsum(skb, skb->len - len); + skb->offload_fwd_mark = true; + return skb; } @@ -139,3 +142,36 @@ const struct dsa_device_ops ksz9477_netdev_ops = { .rcv = ksz9477_rcv, .overhead = KSZ9477_INGRESS_TAG_LEN, }; + +#define KSZ9893_TAIL_TAG_OVERRIDE BIT(5) +#define KSZ9893_TAIL_TAG_LOOKUP BIT(6) + +static struct sk_buff *ksz9893_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct sk_buff *nskb; + u8 *addr; + u8 *tag; + + nskb = ksz_common_xmit(skb, dev, KSZ_INGRESS_TAG_LEN); + if (!nskb) + return NULL; + + /* Tag encoding */ + tag = skb_put(nskb, KSZ_INGRESS_TAG_LEN); + addr = skb_mac_header(nskb); + + *tag = BIT(dp->index); + + if (is_link_local_ether_addr(addr)) + *tag |= KSZ9893_TAIL_TAG_OVERRIDE; + + return nskb; +} + +const struct dsa_device_ops ksz9893_netdev_ops = { + .xmit = ksz9893_xmit, + .rcv = ksz9477_rcv, + .overhead = KSZ_INGRESS_TAG_LEN, +}; diff --git a/net/ethernet/eth.c b/net/ethernet/eth.c index 4c520110b04f..f7a3d7a171c7 100644 --- a/net/ethernet/eth.c +++ b/net/ethernet/eth.c @@ -265,6 +265,18 @@ void eth_header_cache_update(struct hh_cache *hh, EXPORT_SYMBOL(eth_header_cache_update); /** + * eth_header_parser_protocol - extract protocol from L2 header + * @skb: packet to extract protocol from + */ +__be16 eth_header_parse_protocol(const struct sk_buff *skb) +{ + const struct ethhdr *eth = eth_hdr(skb); + + return eth->h_proto; +} +EXPORT_SYMBOL(eth_header_parse_protocol); + +/** * eth_prepare_mac_addr_change - prepare for mac change * @dev: network device * @p: socket address @@ -346,6 +358,7 @@ const struct header_ops eth_header_ops ____cacheline_aligned = { .parse = eth_header_parse, .cache = eth_header_cache, .cache_update = eth_header_cache_update, + .parse_protocol = eth_header_parse_protocol, }; /** diff --git a/net/hsr/hsr_device.c b/net/hsr/hsr_device.c index b8cd43c9ed5b..a97bf326b231 100644 --- a/net/hsr/hsr_device.c +++ b/net/hsr/hsr_device.c @@ -94,9 +94,8 @@ static void hsr_check_announce(struct net_device *hsr_dev, && (old_operstate != IF_OPER_UP)) { /* Went up */ hsr->announce_count = 0; - hsr->announce_timer.expires = jiffies + - msecs_to_jiffies(HSR_ANNOUNCE_INTERVAL); - add_timer(&hsr->announce_timer); + mod_timer(&hsr->announce_timer, + jiffies + msecs_to_jiffies(HSR_ANNOUNCE_INTERVAL)); } if ((hsr_dev->operstate != IF_OPER_UP) && (old_operstate == IF_OPER_UP)) @@ -332,6 +331,7 @@ static void hsr_announce(struct timer_list *t) { struct hsr_priv *hsr; struct hsr_port *master; + unsigned long interval; hsr = from_timer(hsr, t, announce_timer); @@ -343,18 +343,16 @@ static void hsr_announce(struct timer_list *t) hsr->protVersion); hsr->announce_count++; - hsr->announce_timer.expires = jiffies + - msecs_to_jiffies(HSR_ANNOUNCE_INTERVAL); + interval = msecs_to_jiffies(HSR_ANNOUNCE_INTERVAL); } else { send_hsr_supervision_frame(master, HSR_TLV_LIFE_CHECK, hsr->protVersion); - hsr->announce_timer.expires = jiffies + - msecs_to_jiffies(HSR_LIFE_CHECK_INTERVAL); + interval = msecs_to_jiffies(HSR_LIFE_CHECK_INTERVAL); } if (is_admin_up(master->dev)) - add_timer(&hsr->announce_timer); + mod_timer(&hsr->announce_timer, jiffies + interval); rcu_read_unlock(); } @@ -486,7 +484,7 @@ int hsr_dev_finalize(struct net_device *hsr_dev, struct net_device *slave[2], res = hsr_add_port(hsr, hsr_dev, HSR_PT_MASTER); if (res) - return res; + goto err_add_port; res = register_netdevice(hsr_dev); if (res) @@ -506,6 +504,8 @@ int hsr_dev_finalize(struct net_device *hsr_dev, struct net_device *slave[2], fail: hsr_for_each_port(hsr, port) hsr_del_port(port); +err_add_port: + hsr_del_node(&hsr->self_node_db); return res; } diff --git a/net/hsr/hsr_framereg.c b/net/hsr/hsr_framereg.c index 286ceb41ac0c..9af16cb68f76 100644 --- a/net/hsr/hsr_framereg.c +++ b/net/hsr/hsr_framereg.c @@ -124,6 +124,18 @@ int hsr_create_self_node(struct list_head *self_node_db, return 0; } +void hsr_del_node(struct list_head *self_node_db) +{ + struct hsr_node *node; + + rcu_read_lock(); + node = list_first_or_null_rcu(self_node_db, struct hsr_node, mac_list); + rcu_read_unlock(); + if (node) { + list_del_rcu(&node->mac_list); + kfree(node); + } +} /* Allocate an hsr_node and add it to node_db. 'addr' is the node's AddressA; * seq_out is used to initialize filtering of outgoing duplicate frames diff --git a/net/hsr/hsr_framereg.h b/net/hsr/hsr_framereg.h index 370b45998121..531fd3dfcac1 100644 --- a/net/hsr/hsr_framereg.h +++ b/net/hsr/hsr_framereg.h @@ -16,6 +16,7 @@ struct hsr_node; +void hsr_del_node(struct list_head *self_node_db); struct hsr_node *hsr_add_node(struct list_head *node_db, unsigned char addr[], u16 seq_out); struct hsr_node *hsr_get_node(struct hsr_port *port, struct sk_buff *skb, diff --git a/net/ieee802154/6lowpan/reassembly.c b/net/ieee802154/6lowpan/reassembly.c index d14226ecfde4..4196bcd4105a 100644 --- a/net/ieee802154/6lowpan/reassembly.c +++ b/net/ieee802154/6lowpan/reassembly.c @@ -27,6 +27,7 @@ #include <net/6lowpan.h> #include <net/ipv6_frag.h> #include <net/inet_frag.h> +#include <net/ip.h> #include "6lowpan_i.h" @@ -34,8 +35,8 @@ static const char lowpan_frags_cache_name[] = "lowpan-frags"; static struct inet_frags lowpan_frags; -static int lowpan_frag_reasm(struct lowpan_frag_queue *fq, - struct sk_buff *prev, struct net_device *ldev); +static int lowpan_frag_reasm(struct lowpan_frag_queue *fq, struct sk_buff *skb, + struct sk_buff *prev, struct net_device *ldev); static void lowpan_frag_init(struct inet_frag_queue *q, const void *a) { @@ -88,9 +89,15 @@ fq_find(struct net *net, const struct lowpan_802154_cb *cb, static int lowpan_frag_queue(struct lowpan_frag_queue *fq, struct sk_buff *skb, u8 frag_type) { - struct sk_buff *prev, *next; + struct sk_buff *prev_tail; struct net_device *ldev; - int end, offset; + int end, offset, err; + + /* inet_frag_queue_* functions use skb->cb; see struct ipfrag_skb_cb + * in inet_fragment.c + */ + BUILD_BUG_ON(sizeof(struct lowpan_802154_cb) > sizeof(struct inet_skb_parm)); + BUILD_BUG_ON(sizeof(struct lowpan_802154_cb) > sizeof(struct inet6_skb_parm)); if (fq->q.flags & INET_FRAG_COMPLETE) goto err; @@ -117,38 +124,15 @@ static int lowpan_frag_queue(struct lowpan_frag_queue *fq, } } - /* Find out which fragments are in front and at the back of us - * in the chain of fragments so far. We must know where to put - * this fragment, right? - */ - prev = fq->q.fragments_tail; - if (!prev || - lowpan_802154_cb(prev)->d_offset < - lowpan_802154_cb(skb)->d_offset) { - next = NULL; - goto found; - } - prev = NULL; - for (next = fq->q.fragments; next != NULL; next = next->next) { - if (lowpan_802154_cb(next)->d_offset >= - lowpan_802154_cb(skb)->d_offset) - break; /* bingo! */ - prev = next; - } - -found: - /* Insert this fragment in the chain of fragments. */ - skb->next = next; - if (!next) - fq->q.fragments_tail = skb; - if (prev) - prev->next = skb; - else - fq->q.fragments = skb; - ldev = skb->dev; if (ldev) skb->dev = NULL; + barrier(); + + prev_tail = fq->q.fragments_tail; + err = inet_frag_queue_insert(&fq->q, skb, offset, end); + if (err) + goto err; fq->q.stamp = skb->tstamp; if (frag_type == LOWPAN_DISPATCH_FRAG1) @@ -163,10 +147,11 @@ found: unsigned long orefdst = skb->_skb_refdst; skb->_skb_refdst = 0UL; - res = lowpan_frag_reasm(fq, prev, ldev); + res = lowpan_frag_reasm(fq, skb, prev_tail, ldev); skb->_skb_refdst = orefdst; return res; } + skb_dst_drop(skb); return -1; err: @@ -175,97 +160,28 @@ err: } /* Check if this packet is complete. - * Returns NULL on failure by any reason, and pointer - * to current nexthdr field in reassembled frame. * * It is called with locked fq, and caller must check that * queue is eligible for reassembly i.e. it is not COMPLETE, * the last and the first frames arrived and all the bits are here. */ -static int lowpan_frag_reasm(struct lowpan_frag_queue *fq, struct sk_buff *prev, - struct net_device *ldev) +static int lowpan_frag_reasm(struct lowpan_frag_queue *fq, struct sk_buff *skb, + struct sk_buff *prev_tail, struct net_device *ldev) { - struct sk_buff *fp, *head = fq->q.fragments; - int sum_truesize; + void *reasm_data; inet_frag_kill(&fq->q); - /* Make the one we just received the head. */ - if (prev) { - head = prev->next; - fp = skb_clone(head, GFP_ATOMIC); - - if (!fp) - goto out_oom; - - fp->next = head->next; - if (!fp->next) - fq->q.fragments_tail = fp; - prev->next = fp; - - skb_morph(head, fq->q.fragments); - head->next = fq->q.fragments->next; - - consume_skb(fq->q.fragments); - fq->q.fragments = head; - } - - /* Head of list must not be cloned. */ - if (skb_unclone(head, GFP_ATOMIC)) + reasm_data = inet_frag_reasm_prepare(&fq->q, skb, prev_tail); + if (!reasm_data) goto out_oom; + inet_frag_reasm_finish(&fq->q, skb, reasm_data); - /* If the first fragment is fragmented itself, we split - * it to two chunks: the first with data and paged part - * and the second, holding only fragments. - */ - if (skb_has_frag_list(head)) { - struct sk_buff *clone; - int i, plen = 0; - - clone = alloc_skb(0, GFP_ATOMIC); - if (!clone) - goto out_oom; - clone->next = head->next; - head->next = clone; - skb_shinfo(clone)->frag_list = skb_shinfo(head)->frag_list; - skb_frag_list_init(head); - for (i = 0; i < skb_shinfo(head)->nr_frags; i++) - plen += skb_frag_size(&skb_shinfo(head)->frags[i]); - clone->len = head->data_len - plen; - clone->data_len = clone->len; - head->data_len -= clone->len; - head->len -= clone->len; - add_frag_mem_limit(fq->q.net, clone->truesize); - } - - WARN_ON(head == NULL); - - sum_truesize = head->truesize; - for (fp = head->next; fp;) { - bool headstolen; - int delta; - struct sk_buff *next = fp->next; - - sum_truesize += fp->truesize; - if (skb_try_coalesce(head, fp, &headstolen, &delta)) { - kfree_skb_partial(fp, headstolen); - } else { - if (!skb_shinfo(head)->frag_list) - skb_shinfo(head)->frag_list = fp; - head->data_len += fp->len; - head->len += fp->len; - head->truesize += fp->truesize; - } - fp = next; - } - sub_frag_mem_limit(fq->q.net, sum_truesize); - - skb_mark_not_on_list(head); - head->dev = ldev; - head->tstamp = fq->q.stamp; - - fq->q.fragments = NULL; + skb->dev = ldev; + skb->tstamp = fq->q.stamp; + fq->q.rb_fragments = RB_ROOT; fq->q.fragments_tail = NULL; + fq->q.last_run_head = NULL; return 1; out_oom: diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index 0dfb72c46671..eab3ebde981e 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -1385,6 +1385,15 @@ out: } EXPORT_SYMBOL(inet_gso_segment); +static struct sk_buff *ipip_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_IPXIP4)) + return ERR_PTR(-EINVAL); + + return inet_gso_segment(skb, features); +} + INDIRECT_CALLABLE_DECLARE(struct sk_buff *tcp4_gro_receive(struct list_head *, struct sk_buff *)); INDIRECT_CALLABLE_DECLARE(struct sk_buff *udp4_gro_receive(struct list_head *, @@ -1861,7 +1870,7 @@ static struct packet_offload ip_packet_offload __read_mostly = { static const struct net_offload ipip_offload = { .callbacks = { - .gso_segment = inet_gso_segment, + .gso_segment = ipip_gso_segment, .gro_receive = ipip_gro_receive, .gro_complete = ipip_gro_complete, }, diff --git a/net/ipv4/cipso_ipv4.c b/net/ipv4/cipso_ipv4.c index 777fa3b7fb13..f0165c5f376b 100644 --- a/net/ipv4/cipso_ipv4.c +++ b/net/ipv4/cipso_ipv4.c @@ -667,7 +667,8 @@ static int cipso_v4_map_lvl_valid(const struct cipso_v4_doi *doi_def, u8 level) case CIPSO_V4_MAP_PASS: return 0; case CIPSO_V4_MAP_TRANS: - if (doi_def->map.std->lvl.cipso[level] < CIPSO_V4_INV_LVL) + if ((level < doi_def->map.std->lvl.cipso_size) && + (doi_def->map.std->lvl.cipso[level] < CIPSO_V4_INV_LVL)) return 0; break; } @@ -1735,13 +1736,26 @@ validate_return: */ void cipso_v4_error(struct sk_buff *skb, int error, u32 gateway) { + unsigned char optbuf[sizeof(struct ip_options) + 40]; + struct ip_options *opt = (struct ip_options *)optbuf; + if (ip_hdr(skb)->protocol == IPPROTO_ICMP || error != -EACCES) return; + /* + * We might be called above the IP layer, + * so we can not use icmp_send and IPCB here. + */ + + memset(opt, 0, sizeof(struct ip_options)); + opt->optlen = ip_hdr(skb)->ihl*4 - sizeof(struct iphdr); + if (__ip_options_compile(dev_net(skb->dev), opt, skb, NULL)) + return; + if (gateway) - icmp_send(skb, ICMP_DEST_UNREACH, ICMP_NET_ANO, 0); + __icmp_send(skb, ICMP_DEST_UNREACH, ICMP_NET_ANO, 0, opt); else - icmp_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_ANO, 0); + __icmp_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_ANO, 0, opt); } /** diff --git a/net/ipv4/devinet.c b/net/ipv4/devinet.c index e258a00b4a3d..eb514f312e6f 100644 --- a/net/ipv4/devinet.c +++ b/net/ipv4/devinet.c @@ -2063,13 +2063,49 @@ static const struct nla_policy devconf_ipv4_policy[NETCONFA_MAX+1] = { [NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN] = { .len = sizeof(int) }, }; +static int inet_netconf_valid_get_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(struct netconfmsg))) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid header for netconf get request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse(nlh, sizeof(struct netconfmsg), tb, + NETCONFA_MAX, devconf_ipv4_policy, extack); + + err = nlmsg_parse_strict(nlh, sizeof(struct netconfmsg), tb, + NETCONFA_MAX, devconf_ipv4_policy, extack); + if (err) + return err; + + for (i = 0; i <= NETCONFA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case NETCONFA_IFINDEX: + break; + default: + NL_SET_ERR_MSG(extack, "ipv4: Unsupported attribute in netconf get request"); + return -EINVAL; + } + } + + return 0; +} + static int inet_netconf_get_devconf(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { struct net *net = sock_net(in_skb->sk); struct nlattr *tb[NETCONFA_MAX+1]; - struct netconfmsg *ncm; struct sk_buff *skb; struct ipv4_devconf *devconf; struct in_device *in_dev; @@ -2077,9 +2113,8 @@ static int inet_netconf_get_devconf(struct sk_buff *in_skb, int ifindex; int err; - err = nlmsg_parse(nlh, sizeof(*ncm), tb, NETCONFA_MAX, - devconf_ipv4_policy, extack); - if (err < 0) + err = inet_netconf_valid_get_req(in_skb, nlh, tb, extack); + if (err) goto errout; err = -EINVAL; @@ -2556,32 +2591,34 @@ static __net_init int devinet_init_net(struct net *net) int err; struct ipv4_devconf *all, *dflt; #ifdef CONFIG_SYSCTL - struct ctl_table *tbl = ctl_forward_entry; + struct ctl_table *tbl; struct ctl_table_header *forw_hdr; #endif err = -ENOMEM; - all = &ipv4_devconf; - dflt = &ipv4_devconf_dflt; + all = kmemdup(&ipv4_devconf, sizeof(ipv4_devconf), GFP_KERNEL); + if (!all) + goto err_alloc_all; - if (!net_eq(net, &init_net)) { - all = kmemdup(all, sizeof(ipv4_devconf), GFP_KERNEL); - if (!all) - goto err_alloc_all; - - dflt = kmemdup(dflt, sizeof(ipv4_devconf_dflt), GFP_KERNEL); - if (!dflt) - goto err_alloc_dflt; + dflt = kmemdup(&ipv4_devconf_dflt, sizeof(ipv4_devconf_dflt), GFP_KERNEL); + if (!dflt) + goto err_alloc_dflt; #ifdef CONFIG_SYSCTL - tbl = kmemdup(tbl, sizeof(ctl_forward_entry), GFP_KERNEL); - if (!tbl) - goto err_alloc_ctl; + tbl = kmemdup(ctl_forward_entry, sizeof(ctl_forward_entry), GFP_KERNEL); + if (!tbl) + goto err_alloc_ctl; - tbl[0].data = &all->data[IPV4_DEVCONF_FORWARDING - 1]; - tbl[0].extra1 = all; - tbl[0].extra2 = net; + tbl[0].data = &all->data[IPV4_DEVCONF_FORWARDING - 1]; + tbl[0].extra1 = all; + tbl[0].extra2 = net; #endif + + if ((!IS_ENABLED(CONFIG_SYSCTL) || + sysctl_devconf_inherit_init_net != 2) && + !net_eq(net, &init_net)) { + memcpy(all, init_net.ipv4.devconf_all, sizeof(ipv4_devconf)); + memcpy(dflt, init_net.ipv4.devconf_dflt, sizeof(ipv4_devconf_dflt)); } #ifdef CONFIG_SYSCTL @@ -2611,15 +2648,12 @@ err_reg_ctl: err_reg_dflt: __devinet_sysctl_unregister(net, all, NETCONFA_IFINDEX_ALL); err_reg_all: - if (tbl != ctl_forward_entry) - kfree(tbl); + kfree(tbl); err_alloc_ctl: #endif - if (dflt != &ipv4_devconf_dflt) - kfree(dflt); + kfree(dflt); err_alloc_dflt: - if (all != &ipv4_devconf) - kfree(all); + kfree(all); err_alloc_all: return err; } diff --git a/net/ipv4/esp4.c b/net/ipv4/esp4.c index 5459f41fc26f..10e809b296ec 100644 --- a/net/ipv4/esp4.c +++ b/net/ipv4/esp4.c @@ -328,7 +328,7 @@ int esp_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info * skb->len += tailen; skb->data_len += tailen; skb->truesize += tailen; - if (sk) + if (sk && sk_fullsock(sk)) refcount_add(tailen, &sk->sk_wmem_alloc); goto out; diff --git a/net/ipv4/fib_frontend.c b/net/ipv4/fib_frontend.c index fe4f6a624238..ed14ec245584 100644 --- a/net/ipv4/fib_frontend.c +++ b/net/ipv4/fib_frontend.c @@ -710,6 +710,10 @@ static int rtm_to_fib_config(struct net *net, struct sk_buff *skb, case RTA_GATEWAY: cfg->fc_gw = nla_get_be32(attr); break; + case RTA_VIA: + NL_SET_ERR_MSG(extack, "IPv4 does not support RTA_VIA attribute"); + err = -EINVAL; + goto errout; case RTA_PRIORITY: cfg->fc_priority = nla_get_u32(attr); break; diff --git a/net/ipv4/fib_semantics.c b/net/ipv4/fib_semantics.c index 5022bc63863a..8e185b5a2bf6 100644 --- a/net/ipv4/fib_semantics.c +++ b/net/ipv4/fib_semantics.c @@ -1072,7 +1072,7 @@ struct fib_info *fib_create_info(struct fib_config *cfg, goto failure; } - fi = kzalloc(sizeof(*fi)+nhs*sizeof(struct fib_nh), GFP_KERNEL); + fi = kzalloc(struct_size(fi, fib_nh, nhs), GFP_KERNEL); if (!fi) goto failure; fi->fib_metrics = ip_fib_metrics_init(fi->fib_net, cfg->fc_mx, diff --git a/net/ipv4/fou.c b/net/ipv4/fou.c index 437070d1ffb1..79e98e21cdd7 100644 --- a/net/ipv4/fou.c +++ b/net/ipv4/fou.c @@ -1024,7 +1024,7 @@ static int gue_err(struct sk_buff *skb, u32 info) int ret; len = sizeof(struct udphdr) + sizeof(struct guehdr); - if (!pskb_may_pull(skb, len)) + if (!pskb_may_pull(skb, transport_offset + len)) return -EINVAL; guehdr = (struct guehdr *)&udp_hdr(skb)[1]; @@ -1059,7 +1059,7 @@ static int gue_err(struct sk_buff *skb, u32 info) optlen = guehdr->hlen << 2; - if (!pskb_may_pull(skb, len + optlen)) + if (!pskb_may_pull(skb, transport_offset + len + optlen)) return -EINVAL; guehdr = (struct guehdr *)&udp_hdr(skb)[1]; diff --git a/net/ipv4/gre_demux.c b/net/ipv4/gre_demux.c index a4bf22ee3aed..7c4a41dc04bb 100644 --- a/net/ipv4/gre_demux.c +++ b/net/ipv4/gre_demux.c @@ -25,6 +25,7 @@ #include <linux/spinlock.h> #include <net/protocol.h> #include <net/gre.h> +#include <net/erspan.h> #include <net/icmp.h> #include <net/route.h> @@ -119,6 +120,22 @@ int gre_parse_header(struct sk_buff *skb, struct tnl_ptk_info *tpi, hdr_len += 4; } tpi->hdr_len = hdr_len; + + /* ERSPAN ver 1 and 2 protocol sets GRE key field + * to 0 and sets the configured key in the + * inner erspan header field + */ + if (greh->protocol == htons(ETH_P_ERSPAN) || + greh->protocol == htons(ETH_P_ERSPAN2)) { + struct erspan_base_hdr *ershdr; + + if (!pskb_may_pull(skb, nhs + hdr_len + sizeof(*ershdr))) + return -EINVAL; + + ershdr = (struct erspan_base_hdr *)options; + tpi->key = cpu_to_be32(get_session_id(ershdr)); + } + return hdr_len; } EXPORT_SYMBOL(gre_parse_header); diff --git a/net/ipv4/icmp.c b/net/ipv4/icmp.c index 065997f414e6..f3a5893b1e86 100644 --- a/net/ipv4/icmp.c +++ b/net/ipv4/icmp.c @@ -570,7 +570,8 @@ relookup_failed: * MUST reply to only the first fragment. */ -void icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info) +void __icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info, + const struct ip_options *opt) { struct iphdr *iph; int room; @@ -691,7 +692,7 @@ void icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info) iph->tos; mark = IP4_REPLY_MARK(net, skb_in->mark); - if (ip_options_echo(net, &icmp_param.replyopts.opt.opt, skb_in)) + if (__ip_options_echo(net, &icmp_param.replyopts.opt.opt, skb_in, opt)) goto out_unlock; @@ -742,7 +743,7 @@ out_bh_enable: local_bh_enable(); out:; } -EXPORT_SYMBOL(icmp_send); +EXPORT_SYMBOL(__icmp_send); static void icmp_socket_deliver(struct sk_buff *skb, u32 info) @@ -1245,9 +1246,7 @@ static int __net_init icmp_sk_init(struct net *net) return 0; fail: - for_each_possible_cpu(i) - inet_ctl_sock_destroy(*per_cpu_ptr(net->ipv4.icmp_sk, i)); - free_percpu(net->ipv4.icmp_sk); + icmp_sk_exit(net); return err; } diff --git a/net/ipv4/igmp.c b/net/ipv4/igmp.c index 765b2b32c4a4..6c2febc39dca 100644 --- a/net/ipv4/igmp.c +++ b/net/ipv4/igmp.c @@ -159,7 +159,8 @@ static int unsolicited_report_interval(struct in_device *in_dev) return interval_jiffies; } -static void igmpv3_add_delrec(struct in_device *in_dev, struct ip_mc_list *im); +static void igmpv3_add_delrec(struct in_device *in_dev, struct ip_mc_list *im, + gfp_t gfp); static void igmpv3_del_delrec(struct in_device *in_dev, struct ip_mc_list *im); static void igmpv3_clear_delrec(struct in_device *in_dev); static int sf_setstate(struct ip_mc_list *pmc); @@ -1145,7 +1146,8 @@ static void ip_mc_filter_del(struct in_device *in_dev, __be32 addr) /* * deleted ip_mc_list manipulation */ -static void igmpv3_add_delrec(struct in_device *in_dev, struct ip_mc_list *im) +static void igmpv3_add_delrec(struct in_device *in_dev, struct ip_mc_list *im, + gfp_t gfp) { struct ip_mc_list *pmc; struct net *net = dev_net(in_dev->dev); @@ -1156,7 +1158,7 @@ static void igmpv3_add_delrec(struct in_device *in_dev, struct ip_mc_list *im) * for deleted items allows change reports to use common code with * non-deleted or query-response MCA's. */ - pmc = kzalloc(sizeof(*pmc), GFP_KERNEL); + pmc = kzalloc(sizeof(*pmc), gfp); if (!pmc) return; spin_lock_init(&pmc->lock); @@ -1261,7 +1263,7 @@ static void igmpv3_clear_delrec(struct in_device *in_dev) } #endif -static void igmp_group_dropped(struct ip_mc_list *im) +static void __igmp_group_dropped(struct ip_mc_list *im, gfp_t gfp) { struct in_device *in_dev = im->interface; #ifdef CONFIG_IP_MULTICAST @@ -1292,13 +1294,18 @@ static void igmp_group_dropped(struct ip_mc_list *im) return; } /* IGMPv3 */ - igmpv3_add_delrec(in_dev, im); + igmpv3_add_delrec(in_dev, im, gfp); igmp_ifc_event(in_dev); } #endif } +static void igmp_group_dropped(struct ip_mc_list *im) +{ + __igmp_group_dropped(im, GFP_KERNEL); +} + static void igmp_group_added(struct ip_mc_list *im) { struct in_device *in_dev = im->interface; @@ -1400,8 +1407,8 @@ static void ip_mc_hash_remove(struct in_device *in_dev, /* * A socket has joined a multicast group on device dev. */ -static void __ip_mc_inc_group(struct in_device *in_dev, __be32 addr, - unsigned int mode) +static void ____ip_mc_inc_group(struct in_device *in_dev, __be32 addr, + unsigned int mode, gfp_t gfp) { struct ip_mc_list *im; @@ -1415,7 +1422,7 @@ static void __ip_mc_inc_group(struct in_device *in_dev, __be32 addr, } } - im = kzalloc(sizeof(*im), GFP_KERNEL); + im = kzalloc(sizeof(*im), gfp); if (!im) goto out; @@ -1448,6 +1455,12 @@ out: return; } +void __ip_mc_inc_group(struct in_device *in_dev, __be32 addr, gfp_t gfp) +{ + ____ip_mc_inc_group(in_dev, addr, MCAST_EXCLUDE, gfp); +} +EXPORT_SYMBOL(__ip_mc_inc_group); + void ip_mc_inc_group(struct in_device *in_dev, __be32 addr) { __ip_mc_inc_group(in_dev, addr, MCAST_EXCLUDE); @@ -1493,22 +1506,22 @@ static int ip_mc_check_igmp_reportv3(struct sk_buff *skb) len += sizeof(struct igmpv3_report); - return pskb_may_pull(skb, len) ? 0 : -EINVAL; + return ip_mc_may_pull(skb, len) ? 0 : -EINVAL; } static int ip_mc_check_igmp_query(struct sk_buff *skb) { - unsigned int len = skb_transport_offset(skb); - - len += sizeof(struct igmphdr); - if (skb->len < len) - return -EINVAL; + unsigned int transport_len = ip_transport_len(skb); + unsigned int len; /* IGMPv{1,2}? */ - if (skb->len != len) { + if (transport_len != sizeof(struct igmphdr)) { /* or IGMPv3? */ - len += sizeof(struct igmpv3_query) - sizeof(struct igmphdr); - if (skb->len < len || !pskb_may_pull(skb, len)) + if (transport_len < sizeof(struct igmpv3_query)) + return -EINVAL; + + len = skb_transport_offset(skb) + sizeof(struct igmpv3_query); + if (!ip_mc_may_pull(skb, len)) return -EINVAL; } @@ -1528,7 +1541,6 @@ static int ip_mc_check_igmp_msg(struct sk_buff *skb) case IGMP_HOST_LEAVE_MESSAGE: case IGMP_HOST_MEMBERSHIP_REPORT: case IGMPV2_HOST_MEMBERSHIP_REPORT: - /* fall through */ return 0; case IGMPV3_HOST_MEMBERSHIP_REPORT: return ip_mc_check_igmp_reportv3(skb); @@ -1544,47 +1556,29 @@ static inline __sum16 ip_mc_validate_checksum(struct sk_buff *skb) return skb_checksum_simple_validate(skb); } -static int __ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed) - +static int ip_mc_check_igmp_csum(struct sk_buff *skb) { - struct sk_buff *skb_chk; - unsigned int transport_len; unsigned int len = skb_transport_offset(skb) + sizeof(struct igmphdr); - int ret = -EINVAL; + unsigned int transport_len = ip_transport_len(skb); + struct sk_buff *skb_chk; - transport_len = ntohs(ip_hdr(skb)->tot_len) - ip_hdrlen(skb); + if (!ip_mc_may_pull(skb, len)) + return -EINVAL; skb_chk = skb_checksum_trimmed(skb, transport_len, ip_mc_validate_checksum); if (!skb_chk) - goto err; - - if (!pskb_may_pull(skb_chk, len)) - goto err; - - ret = ip_mc_check_igmp_msg(skb_chk); - if (ret) - goto err; - - if (skb_trimmed) - *skb_trimmed = skb_chk; - /* free now unneeded clone */ - else if (skb_chk != skb) - kfree_skb(skb_chk); - - ret = 0; + return -EINVAL; -err: - if (ret && skb_chk && skb_chk != skb) + if (skb_chk != skb) kfree_skb(skb_chk); - return ret; + return 0; } /** * ip_mc_check_igmp - checks whether this is a sane IGMP packet * @skb: the skb to validate - * @skb_trimmed: to store an skb pointer trimmed to IPv4 packet tail (optional) * * Checks whether an IPv4 packet is a valid IGMP packet. If so sets * skb transport header accordingly and returns zero. @@ -1594,18 +1588,10 @@ err: * -ENOMSG: IP header validation succeeded but it is not an IGMP packet. * -ENOMEM: A memory allocation failure happened. * - * Optionally, an skb pointer might be provided via skb_trimmed (or set it - * to NULL): After parsing an IGMP packet successfully it will point to - * an skb which has its tail aligned to the IP packet end. This might - * either be the originally provided skb or a trimmed, cloned version if - * the skb frame had data beyond the IP packet. A cloned skb allows us - * to leave the original skb and its full frame unchanged (which might be - * desirable for layer 2 frame jugglers). - * * Caller needs to set the skb network header and free any returned skb if it * differs from the provided skb. */ -int ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed) +int ip_mc_check_igmp(struct sk_buff *skb) { int ret = ip_mc_check_iphdr(skb); @@ -1615,7 +1601,11 @@ int ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed) if (ip_hdr(skb)->protocol != IPPROTO_IGMP) return -ENOMSG; - return __ip_mc_check_igmp(skb, skb_trimmed); + ret = ip_mc_check_igmp_csum(skb); + if (ret < 0) + return ret; + + return ip_mc_check_igmp_msg(skb); } EXPORT_SYMBOL(ip_mc_check_igmp); @@ -1656,7 +1646,7 @@ static void ip_mc_rejoin_groups(struct in_device *in_dev) * A socket has left a multicast group on device dev */ -void ip_mc_dec_group(struct in_device *in_dev, __be32 addr) +void __ip_mc_dec_group(struct in_device *in_dev, __be32 addr, gfp_t gfp) { struct ip_mc_list *i; struct ip_mc_list __rcu **ip; @@ -1671,7 +1661,7 @@ void ip_mc_dec_group(struct in_device *in_dev, __be32 addr) ip_mc_hash_remove(in_dev, i); *ip = i->next_rcu; in_dev->mc_count--; - igmp_group_dropped(i); + __igmp_group_dropped(i, gfp); ip_mc_clear_src(i); if (!in_dev->dead) @@ -1684,7 +1674,7 @@ void ip_mc_dec_group(struct in_device *in_dev, __be32 addr) } } } -EXPORT_SYMBOL(ip_mc_dec_group); +EXPORT_SYMBOL(__ip_mc_dec_group); /* Device changing type */ diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c index 1a4e9ff02762..5731670c560b 100644 --- a/net/ipv4/inet_diag.c +++ b/net/ipv4/inet_diag.c @@ -108,6 +108,7 @@ static size_t inet_sk_attr_size(struct sock *sk, + nla_total_size(1) /* INET_DIAG_TOS */ + nla_total_size(1) /* INET_DIAG_TCLASS */ + nla_total_size(4) /* INET_DIAG_MARK */ + + nla_total_size(4) /* INET_DIAG_CLASS_ID */ + nla_total_size(sizeof(struct inet_diag_meminfo)) + nla_total_size(sizeof(struct inet_diag_msg)) + nla_total_size(SK_MEMINFO_VARS * sizeof(u32)) @@ -287,12 +288,19 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, goto errout; } - if (ext & (1 << (INET_DIAG_CLASS_ID - 1))) { + if (ext & (1 << (INET_DIAG_CLASS_ID - 1)) || + ext & (1 << (INET_DIAG_TCLASS - 1))) { u32 classid = 0; #ifdef CONFIG_SOCK_CGROUP_DATA classid = sock_cgroup_classid(&sk->sk_cgrp_data); #endif + /* Fallback to socket priority if class id isn't set. + * Classful qdiscs use it as direct reference to class. + * For cgroup2 classid is always zero. + */ + if (!classid) + classid = sk->sk_priority; if (nla_put_u32(skb, INET_DIAG_CLASS_ID, classid)) goto errout; diff --git a/net/ipv4/inet_fragment.c b/net/ipv4/inet_fragment.c index 760a9e52e02b..737808e27f8b 100644 --- a/net/ipv4/inet_fragment.c +++ b/net/ipv4/inet_fragment.c @@ -25,6 +25,62 @@ #include <net/sock.h> #include <net/inet_frag.h> #include <net/inet_ecn.h> +#include <net/ip.h> +#include <net/ipv6.h> + +/* Use skb->cb to track consecutive/adjacent fragments coming at + * the end of the queue. Nodes in the rb-tree queue will + * contain "runs" of one or more adjacent fragments. + * + * Invariants: + * - next_frag is NULL at the tail of a "run"; + * - the head of a "run" has the sum of all fragment lengths in frag_run_len. + */ +struct ipfrag_skb_cb { + union { + struct inet_skb_parm h4; + struct inet6_skb_parm h6; + }; + struct sk_buff *next_frag; + int frag_run_len; +}; + +#define FRAG_CB(skb) ((struct ipfrag_skb_cb *)((skb)->cb)) + +static void fragcb_clear(struct sk_buff *skb) +{ + RB_CLEAR_NODE(&skb->rbnode); + FRAG_CB(skb)->next_frag = NULL; + FRAG_CB(skb)->frag_run_len = skb->len; +} + +/* Append skb to the last "run". */ +static void fragrun_append_to_last(struct inet_frag_queue *q, + struct sk_buff *skb) +{ + fragcb_clear(skb); + + FRAG_CB(q->last_run_head)->frag_run_len += skb->len; + FRAG_CB(q->fragments_tail)->next_frag = skb; + q->fragments_tail = skb; +} + +/* Create a new "run" with the skb. */ +static void fragrun_create(struct inet_frag_queue *q, struct sk_buff *skb) +{ + BUILD_BUG_ON(sizeof(struct ipfrag_skb_cb) > sizeof(skb->cb)); + fragcb_clear(skb); + + if (q->last_run_head) + rb_link_node(&skb->rbnode, &q->last_run_head->rbnode, + &q->last_run_head->rbnode.rb_right); + else + rb_link_node(&skb->rbnode, NULL, &q->rb_fragments.rb_node); + rb_insert_color(&skb->rbnode, &q->rb_fragments); + + q->fragments_tail = skb; + q->last_run_head = skb; +} /* Given the OR values of all fragments, apply RFC 3168 5.3 requirements * Value : 0xff if frame should be dropped. @@ -123,9 +179,30 @@ static void inet_frag_destroy_rcu(struct rcu_head *head) kmem_cache_free(f->frags_cachep, q); } +unsigned int inet_frag_rbtree_purge(struct rb_root *root) +{ + struct rb_node *p = rb_first(root); + unsigned int sum = 0; + + while (p) { + struct sk_buff *skb = rb_entry(p, struct sk_buff, rbnode); + + p = rb_next(p); + rb_erase(&skb->rbnode, root); + while (skb) { + struct sk_buff *next = FRAG_CB(skb)->next_frag; + + sum += skb->truesize; + kfree_skb(skb); + skb = next; + } + } + return sum; +} +EXPORT_SYMBOL(inet_frag_rbtree_purge); + void inet_frag_destroy(struct inet_frag_queue *q) { - struct sk_buff *fp; struct netns_frags *nf; unsigned int sum, sum_truesize = 0; struct inet_frags *f; @@ -134,20 +211,9 @@ void inet_frag_destroy(struct inet_frag_queue *q) WARN_ON(del_timer(&q->timer) != 0); /* Release all fragment data. */ - fp = q->fragments; nf = q->net; f = nf->f; - if (fp) { - do { - struct sk_buff *xp = fp->next; - - sum_truesize += fp->truesize; - kfree_skb(fp); - fp = xp; - } while (fp); - } else { - sum_truesize = inet_frag_rbtree_purge(&q->rb_fragments); - } + sum_truesize = inet_frag_rbtree_purge(&q->rb_fragments); sum = sum_truesize + f->qsize; call_rcu(&q->rcu, inet_frag_destroy_rcu); @@ -224,3 +290,212 @@ struct inet_frag_queue *inet_frag_find(struct netns_frags *nf, void *key) return fq; } EXPORT_SYMBOL(inet_frag_find); + +int inet_frag_queue_insert(struct inet_frag_queue *q, struct sk_buff *skb, + int offset, int end) +{ + struct sk_buff *last = q->fragments_tail; + + /* RFC5722, Section 4, amended by Errata ID : 3089 + * When reassembling an IPv6 datagram, if + * one or more its constituent fragments is determined to be an + * overlapping fragment, the entire datagram (and any constituent + * fragments) MUST be silently discarded. + * + * Duplicates, however, should be ignored (i.e. skb dropped, but the + * queue/fragments kept for later reassembly). + */ + if (!last) + fragrun_create(q, skb); /* First fragment. */ + else if (last->ip_defrag_offset + last->len < end) { + /* This is the common case: skb goes to the end. */ + /* Detect and discard overlaps. */ + if (offset < last->ip_defrag_offset + last->len) + return IPFRAG_OVERLAP; + if (offset == last->ip_defrag_offset + last->len) + fragrun_append_to_last(q, skb); + else + fragrun_create(q, skb); + } else { + /* Binary search. Note that skb can become the first fragment, + * but not the last (covered above). + */ + struct rb_node **rbn, *parent; + + rbn = &q->rb_fragments.rb_node; + do { + struct sk_buff *curr; + int curr_run_end; + + parent = *rbn; + curr = rb_to_skb(parent); + curr_run_end = curr->ip_defrag_offset + + FRAG_CB(curr)->frag_run_len; + if (end <= curr->ip_defrag_offset) + rbn = &parent->rb_left; + else if (offset >= curr_run_end) + rbn = &parent->rb_right; + else if (offset >= curr->ip_defrag_offset && + end <= curr_run_end) + return IPFRAG_DUP; + else + return IPFRAG_OVERLAP; + } while (*rbn); + /* Here we have parent properly set, and rbn pointing to + * one of its NULL left/right children. Insert skb. + */ + fragcb_clear(skb); + rb_link_node(&skb->rbnode, parent, rbn); + rb_insert_color(&skb->rbnode, &q->rb_fragments); + } + + skb->ip_defrag_offset = offset; + + return IPFRAG_OK; +} +EXPORT_SYMBOL(inet_frag_queue_insert); + +void *inet_frag_reasm_prepare(struct inet_frag_queue *q, struct sk_buff *skb, + struct sk_buff *parent) +{ + struct sk_buff *fp, *head = skb_rb_first(&q->rb_fragments); + struct sk_buff **nextp; + int delta; + + if (head != skb) { + fp = skb_clone(skb, GFP_ATOMIC); + if (!fp) + return NULL; + FRAG_CB(fp)->next_frag = FRAG_CB(skb)->next_frag; + if (RB_EMPTY_NODE(&skb->rbnode)) + FRAG_CB(parent)->next_frag = fp; + else + rb_replace_node(&skb->rbnode, &fp->rbnode, + &q->rb_fragments); + if (q->fragments_tail == skb) + q->fragments_tail = fp; + skb_morph(skb, head); + FRAG_CB(skb)->next_frag = FRAG_CB(head)->next_frag; + rb_replace_node(&head->rbnode, &skb->rbnode, + &q->rb_fragments); + consume_skb(head); + head = skb; + } + WARN_ON(head->ip_defrag_offset != 0); + + delta = -head->truesize; + + /* Head of list must not be cloned. */ + if (skb_unclone(head, GFP_ATOMIC)) + return NULL; + + delta += head->truesize; + if (delta) + add_frag_mem_limit(q->net, delta); + + /* If the first fragment is fragmented itself, we split + * it to two chunks: the first with data and paged part + * and the second, holding only fragments. + */ + if (skb_has_frag_list(head)) { + struct sk_buff *clone; + int i, plen = 0; + + clone = alloc_skb(0, GFP_ATOMIC); + if (!clone) + return NULL; + skb_shinfo(clone)->frag_list = skb_shinfo(head)->frag_list; + skb_frag_list_init(head); + for (i = 0; i < skb_shinfo(head)->nr_frags; i++) + plen += skb_frag_size(&skb_shinfo(head)->frags[i]); + clone->data_len = head->data_len - plen; + clone->len = clone->data_len; + head->truesize += clone->truesize; + clone->csum = 0; + clone->ip_summed = head->ip_summed; + add_frag_mem_limit(q->net, clone->truesize); + skb_shinfo(head)->frag_list = clone; + nextp = &clone->next; + } else { + nextp = &skb_shinfo(head)->frag_list; + } + + return nextp; +} +EXPORT_SYMBOL(inet_frag_reasm_prepare); + +void inet_frag_reasm_finish(struct inet_frag_queue *q, struct sk_buff *head, + void *reasm_data) +{ + struct sk_buff **nextp = (struct sk_buff **)reasm_data; + struct rb_node *rbn; + struct sk_buff *fp; + + skb_push(head, head->data - skb_network_header(head)); + + /* Traverse the tree in order, to build frag_list. */ + fp = FRAG_CB(head)->next_frag; + rbn = rb_next(&head->rbnode); + rb_erase(&head->rbnode, &q->rb_fragments); + while (rbn || fp) { + /* fp points to the next sk_buff in the current run; + * rbn points to the next run. + */ + /* Go through the current run. */ + while (fp) { + *nextp = fp; + nextp = &fp->next; + fp->prev = NULL; + memset(&fp->rbnode, 0, sizeof(fp->rbnode)); + fp->sk = NULL; + head->data_len += fp->len; + head->len += fp->len; + if (head->ip_summed != fp->ip_summed) + head->ip_summed = CHECKSUM_NONE; + else if (head->ip_summed == CHECKSUM_COMPLETE) + head->csum = csum_add(head->csum, fp->csum); + head->truesize += fp->truesize; + fp = FRAG_CB(fp)->next_frag; + } + /* Move to the next run. */ + if (rbn) { + struct rb_node *rbnext = rb_next(rbn); + + fp = rb_to_skb(rbn); + rb_erase(rbn, &q->rb_fragments); + rbn = rbnext; + } + } + sub_frag_mem_limit(q->net, head->truesize); + + *nextp = NULL; + skb_mark_not_on_list(head); + head->prev = NULL; + head->tstamp = q->stamp; +} +EXPORT_SYMBOL(inet_frag_reasm_finish); + +struct sk_buff *inet_frag_pull_head(struct inet_frag_queue *q) +{ + struct sk_buff *head, *skb; + + head = skb_rb_first(&q->rb_fragments); + if (!head) + return NULL; + skb = FRAG_CB(head)->next_frag; + if (skb) + rb_replace_node(&head->rbnode, &skb->rbnode, + &q->rb_fragments); + else + rb_erase(&head->rbnode, &q->rb_fragments); + memset(&head->rbnode, 0, sizeof(head->rbnode)); + barrier(); + + if (head == q->fragments_tail) + q->fragments_tail = NULL; + + sub_frag_mem_limit(q->net, head->truesize); + + return head; +} +EXPORT_SYMBOL(inet_frag_pull_head); diff --git a/net/ipv4/inetpeer.c b/net/ipv4/inetpeer.c index d757b9642d0d..be778599bfed 100644 --- a/net/ipv4/inetpeer.c +++ b/net/ipv4/inetpeer.c @@ -216,6 +216,7 @@ struct inet_peer *inet_getpeer(struct inet_peer_base *base, atomic_set(&p->rid, 0); p->metrics[RTAX_LOCK-1] = INETPEER_METRICS_NEW; p->rate_tokens = 0; + p->n_redirects = 0; /* 60*HZ is arbitrary, but chosen enough high so that the first * calculation of tokens is at its maximum. */ diff --git a/net/ipv4/ip_fragment.c b/net/ipv4/ip_fragment.c index 867be8f7f1fa..cf2b0a6a3337 100644 --- a/net/ipv4/ip_fragment.c +++ b/net/ipv4/ip_fragment.c @@ -57,57 +57,6 @@ */ static const char ip_frag_cache_name[] = "ip4-frags"; -/* Use skb->cb to track consecutive/adjacent fragments coming at - * the end of the queue. Nodes in the rb-tree queue will - * contain "runs" of one or more adjacent fragments. - * - * Invariants: - * - next_frag is NULL at the tail of a "run"; - * - the head of a "run" has the sum of all fragment lengths in frag_run_len. - */ -struct ipfrag_skb_cb { - struct inet_skb_parm h; - struct sk_buff *next_frag; - int frag_run_len; -}; - -#define FRAG_CB(skb) ((struct ipfrag_skb_cb *)((skb)->cb)) - -static void ip4_frag_init_run(struct sk_buff *skb) -{ - BUILD_BUG_ON(sizeof(struct ipfrag_skb_cb) > sizeof(skb->cb)); - - FRAG_CB(skb)->next_frag = NULL; - FRAG_CB(skb)->frag_run_len = skb->len; -} - -/* Append skb to the last "run". */ -static void ip4_frag_append_to_last_run(struct inet_frag_queue *q, - struct sk_buff *skb) -{ - RB_CLEAR_NODE(&skb->rbnode); - FRAG_CB(skb)->next_frag = NULL; - - FRAG_CB(q->last_run_head)->frag_run_len += skb->len; - FRAG_CB(q->fragments_tail)->next_frag = skb; - q->fragments_tail = skb; -} - -/* Create a new "run" with the skb. */ -static void ip4_frag_create_run(struct inet_frag_queue *q, struct sk_buff *skb) -{ - if (q->last_run_head) - rb_link_node(&skb->rbnode, &q->last_run_head->rbnode, - &q->last_run_head->rbnode.rb_right); - else - rb_link_node(&skb->rbnode, NULL, &q->rb_fragments.rb_node); - rb_insert_color(&skb->rbnode, &q->rb_fragments); - - ip4_frag_init_run(skb); - q->fragments_tail = skb; - q->last_run_head = skb; -} - /* Describe an entry in the "incomplete datagrams" queue. */ struct ipq { struct inet_frag_queue q; @@ -212,27 +161,9 @@ static void ip_expire(struct timer_list *t) * pull the head out of the tree in order to be able to * deal with head->dev. */ - if (qp->q.fragments) { - head = qp->q.fragments; - qp->q.fragments = head->next; - } else { - head = skb_rb_first(&qp->q.rb_fragments); - if (!head) - goto out; - if (FRAG_CB(head)->next_frag) - rb_replace_node(&head->rbnode, - &FRAG_CB(head)->next_frag->rbnode, - &qp->q.rb_fragments); - else - rb_erase(&head->rbnode, &qp->q.rb_fragments); - memset(&head->rbnode, 0, sizeof(head->rbnode)); - barrier(); - } - if (head == qp->q.fragments_tail) - qp->q.fragments_tail = NULL; - - sub_frag_mem_limit(qp->q.net, head->truesize); - + head = inet_frag_pull_head(&qp->q); + if (!head) + goto out; head->dev = dev_get_by_index_rcu(net, qp->iif); if (!head->dev) goto out; @@ -330,7 +261,6 @@ static int ip_frag_reinit(struct ipq *qp) qp->q.flags = 0; qp->q.len = 0; qp->q.meat = 0; - qp->q.fragments = NULL; qp->q.rb_fragments = RB_ROOT; qp->q.fragments_tail = NULL; qp->q.last_run_head = NULL; @@ -344,12 +274,10 @@ static int ip_frag_reinit(struct ipq *qp) static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb) { struct net *net = container_of(qp->q.net, struct net, ipv4.frags); - struct rb_node **rbn, *parent; - struct sk_buff *skb1, *prev_tail; - int ihl, end, skb1_run_end; + int ihl, end, flags, offset; + struct sk_buff *prev_tail; struct net_device *dev; unsigned int fragsize; - int flags, offset; int err = -ENOENT; u8 ecn; @@ -413,62 +341,13 @@ static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb) /* Makes sure compiler wont do silly aliasing games */ barrier(); - /* RFC5722, Section 4, amended by Errata ID : 3089 - * When reassembling an IPv6 datagram, if - * one or more its constituent fragments is determined to be an - * overlapping fragment, the entire datagram (and any constituent - * fragments) MUST be silently discarded. - * - * We do the same here for IPv4 (and increment an snmp counter) but - * we do not want to drop the whole queue in response to a duplicate - * fragment. - */ - - err = -EINVAL; - /* Find out where to put this fragment. */ prev_tail = qp->q.fragments_tail; - if (!prev_tail) - ip4_frag_create_run(&qp->q, skb); /* First fragment. */ - else if (prev_tail->ip_defrag_offset + prev_tail->len < end) { - /* This is the common case: skb goes to the end. */ - /* Detect and discard overlaps. */ - if (offset < prev_tail->ip_defrag_offset + prev_tail->len) - goto overlap; - if (offset == prev_tail->ip_defrag_offset + prev_tail->len) - ip4_frag_append_to_last_run(&qp->q, skb); - else - ip4_frag_create_run(&qp->q, skb); - } else { - /* Binary search. Note that skb can become the first fragment, - * but not the last (covered above). - */ - rbn = &qp->q.rb_fragments.rb_node; - do { - parent = *rbn; - skb1 = rb_to_skb(parent); - skb1_run_end = skb1->ip_defrag_offset + - FRAG_CB(skb1)->frag_run_len; - if (end <= skb1->ip_defrag_offset) - rbn = &parent->rb_left; - else if (offset >= skb1_run_end) - rbn = &parent->rb_right; - else if (offset >= skb1->ip_defrag_offset && - end <= skb1_run_end) - goto err; /* No new data, potential duplicate */ - else - goto overlap; /* Found an overlap */ - } while (*rbn); - /* Here we have parent properly set, and rbn pointing to - * one of its NULL left/right children. Insert skb. - */ - ip4_frag_init_run(skb); - rb_link_node(&skb->rbnode, parent, rbn); - rb_insert_color(&skb->rbnode, &qp->q.rb_fragments); - } + err = inet_frag_queue_insert(&qp->q, skb, offset, end); + if (err) + goto insert_error; if (dev) qp->iif = dev->ifindex; - skb->ip_defrag_offset = offset; qp->q.stamp = skb->tstamp; qp->q.meat += skb->len; @@ -501,10 +380,16 @@ static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb) skb_dst_drop(skb); return -EINPROGRESS; -overlap: +insert_error: + if (err == IPFRAG_DUP) { + kfree_skb(skb); + return -EINVAL; + } + err = -EINVAL; __IP_INC_STATS(net, IPSTATS_MIB_REASM_OVERLAPS); discard_qp: inet_frag_kill(&qp->q); + __IP_INC_STATS(net, IPSTATS_MIB_REASMFAILS); err: kfree_skb(skb); return err; @@ -516,13 +401,8 @@ static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb, { struct net *net = container_of(qp->q.net, struct net, ipv4.frags); struct iphdr *iph; - struct sk_buff *fp, *head = skb_rb_first(&qp->q.rb_fragments); - struct sk_buff **nextp; /* To build frag_list. */ - struct rb_node *rbn; - int len; - int ihlen; - int delta; - int err; + void *reasm_data; + int len, err; u8 ecn; ipq_kill(qp); @@ -532,117 +412,23 @@ static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb, err = -EINVAL; goto out_fail; } - /* Make the one we just received the head. */ - if (head != skb) { - fp = skb_clone(skb, GFP_ATOMIC); - if (!fp) - goto out_nomem; - FRAG_CB(fp)->next_frag = FRAG_CB(skb)->next_frag; - if (RB_EMPTY_NODE(&skb->rbnode)) - FRAG_CB(prev_tail)->next_frag = fp; - else - rb_replace_node(&skb->rbnode, &fp->rbnode, - &qp->q.rb_fragments); - if (qp->q.fragments_tail == skb) - qp->q.fragments_tail = fp; - skb_morph(skb, head); - FRAG_CB(skb)->next_frag = FRAG_CB(head)->next_frag; - rb_replace_node(&head->rbnode, &skb->rbnode, - &qp->q.rb_fragments); - consume_skb(head); - head = skb; - } - WARN_ON(head->ip_defrag_offset != 0); - - /* Allocate a new buffer for the datagram. */ - ihlen = ip_hdrlen(head); - len = ihlen + qp->q.len; + /* Make the one we just received the head. */ + reasm_data = inet_frag_reasm_prepare(&qp->q, skb, prev_tail); + if (!reasm_data) + goto out_nomem; + len = ip_hdrlen(skb) + qp->q.len; err = -E2BIG; if (len > 65535) goto out_oversize; - delta = - head->truesize; - - /* Head of list must not be cloned. */ - if (skb_unclone(head, GFP_ATOMIC)) - goto out_nomem; - - delta += head->truesize; - if (delta) - add_frag_mem_limit(qp->q.net, delta); - - /* If the first fragment is fragmented itself, we split - * it to two chunks: the first with data and paged part - * and the second, holding only fragments. */ - if (skb_has_frag_list(head)) { - struct sk_buff *clone; - int i, plen = 0; - - clone = alloc_skb(0, GFP_ATOMIC); - if (!clone) - goto out_nomem; - skb_shinfo(clone)->frag_list = skb_shinfo(head)->frag_list; - skb_frag_list_init(head); - for (i = 0; i < skb_shinfo(head)->nr_frags; i++) - plen += skb_frag_size(&skb_shinfo(head)->frags[i]); - clone->len = clone->data_len = head->data_len - plen; - head->truesize += clone->truesize; - clone->csum = 0; - clone->ip_summed = head->ip_summed; - add_frag_mem_limit(qp->q.net, clone->truesize); - skb_shinfo(head)->frag_list = clone; - nextp = &clone->next; - } else { - nextp = &skb_shinfo(head)->frag_list; - } - - skb_push(head, head->data - skb_network_header(head)); + inet_frag_reasm_finish(&qp->q, skb, reasm_data); - /* Traverse the tree in order, to build frag_list. */ - fp = FRAG_CB(head)->next_frag; - rbn = rb_next(&head->rbnode); - rb_erase(&head->rbnode, &qp->q.rb_fragments); - while (rbn || fp) { - /* fp points to the next sk_buff in the current run; - * rbn points to the next run. - */ - /* Go through the current run. */ - while (fp) { - *nextp = fp; - nextp = &fp->next; - fp->prev = NULL; - memset(&fp->rbnode, 0, sizeof(fp->rbnode)); - fp->sk = NULL; - head->data_len += fp->len; - head->len += fp->len; - if (head->ip_summed != fp->ip_summed) - head->ip_summed = CHECKSUM_NONE; - else if (head->ip_summed == CHECKSUM_COMPLETE) - head->csum = csum_add(head->csum, fp->csum); - head->truesize += fp->truesize; - fp = FRAG_CB(fp)->next_frag; - } - /* Move to the next run. */ - if (rbn) { - struct rb_node *rbnext = rb_next(rbn); - - fp = rb_to_skb(rbn); - rb_erase(rbn, &qp->q.rb_fragments); - rbn = rbnext; - } - } - sub_frag_mem_limit(qp->q.net, head->truesize); - - *nextp = NULL; - skb_mark_not_on_list(head); - head->prev = NULL; - head->dev = dev; - head->tstamp = qp->q.stamp; - IPCB(head)->frag_max_size = max(qp->max_df_size, qp->q.max_size); + skb->dev = dev; + IPCB(skb)->frag_max_size = max(qp->max_df_size, qp->q.max_size); - iph = ip_hdr(head); + iph = ip_hdr(skb); iph->tot_len = htons(len); iph->tos |= ecn; @@ -655,7 +441,7 @@ static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb, * from one very small df-fragment and one large non-df frag. */ if (qp->max_df_size == qp->q.max_size) { - IPCB(head)->flags |= IPSKB_FRAG_PMTU; + IPCB(skb)->flags |= IPSKB_FRAG_PMTU; iph->frag_off = htons(IP_DF); } else { iph->frag_off = 0; @@ -664,7 +450,6 @@ static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb, ip_send_check(iph); __IP_INC_STATS(net, IPSTATS_MIB_REASMOKS); - qp->q.fragments = NULL; qp->q.rb_fragments = RB_ROOT; qp->q.fragments_tail = NULL; qp->q.last_run_head = NULL; @@ -753,28 +538,6 @@ struct sk_buff *ip_check_defrag(struct net *net, struct sk_buff *skb, u32 user) } EXPORT_SYMBOL(ip_check_defrag); -unsigned int inet_frag_rbtree_purge(struct rb_root *root) -{ - struct rb_node *p = rb_first(root); - unsigned int sum = 0; - - while (p) { - struct sk_buff *skb = rb_entry(p, struct sk_buff, rbnode); - - p = rb_next(p); - rb_erase(&skb->rbnode, root); - while (skb) { - struct sk_buff *next = FRAG_CB(skb)->next_frag; - - sum += skb->truesize; - kfree_skb(skb); - skb = next; - } - } - return sum; -} -EXPORT_SYMBOL(inet_frag_rbtree_purge); - #ifdef CONFIG_SYSCTL static int dist_min; diff --git a/net/ipv4/ip_gre.c b/net/ipv4/ip_gre.c index b1a74d80d868..fd219f7bd3ea 100644 --- a/net/ipv4/ip_gre.c +++ b/net/ipv4/ip_gre.c @@ -268,20 +268,11 @@ static int erspan_rcv(struct sk_buff *skb, struct tnl_ptk_info *tpi, int len; itn = net_generic(net, erspan_net_id); - len = gre_hdr_len + sizeof(*ershdr); - - /* Check based hdr len */ - if (unlikely(!pskb_may_pull(skb, len))) - return PACKET_REJECT; iph = ip_hdr(skb); ershdr = (struct erspan_base_hdr *)(skb->data + gre_hdr_len); ver = ershdr->ver; - /* The original GRE header does not have key field, - * Use ERSPAN 10-bit session ID as key. - */ - tpi->key = cpu_to_be32(get_session_id(ershdr)); tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, tpi->flags | TUNNEL_KEY, iph->saddr, iph->daddr, tpi->key); @@ -458,81 +449,14 @@ static int gre_handle_offloads(struct sk_buff *skb, bool csum) return iptunnel_handle_offloads(skb, csum ? SKB_GSO_GRE_CSUM : SKB_GSO_GRE); } -static struct rtable *gre_get_rt(struct sk_buff *skb, - struct net_device *dev, - struct flowi4 *fl, - const struct ip_tunnel_key *key) -{ - struct net *net = dev_net(dev); - - memset(fl, 0, sizeof(*fl)); - fl->daddr = key->u.ipv4.dst; - fl->saddr = key->u.ipv4.src; - fl->flowi4_tos = RT_TOS(key->tos); - fl->flowi4_mark = skb->mark; - fl->flowi4_proto = IPPROTO_GRE; - - return ip_route_output_key(net, fl); -} - -static struct rtable *prepare_fb_xmit(struct sk_buff *skb, - struct net_device *dev, - struct flowi4 *fl, - int tunnel_hlen) -{ - struct ip_tunnel_info *tun_info; - const struct ip_tunnel_key *key; - struct rtable *rt = NULL; - int min_headroom; - bool use_cache; - int err; - - tun_info = skb_tunnel_info(skb); - key = &tun_info->key; - use_cache = ip_tunnel_dst_cache_usable(skb, tun_info); - - if (use_cache) - rt = dst_cache_get_ip4(&tun_info->dst_cache, &fl->saddr); - if (!rt) { - rt = gre_get_rt(skb, dev, fl, key); - if (IS_ERR(rt)) - goto err_free_skb; - if (use_cache) - dst_cache_set_ip4(&tun_info->dst_cache, &rt->dst, - fl->saddr); - } - - min_headroom = LL_RESERVED_SPACE(rt->dst.dev) + rt->dst.header_len - + tunnel_hlen + sizeof(struct iphdr); - if (skb_headroom(skb) < min_headroom || skb_header_cloned(skb)) { - int head_delta = SKB_DATA_ALIGN(min_headroom - - skb_headroom(skb) + - 16); - err = pskb_expand_head(skb, max_t(int, head_delta, 0), - 0, GFP_ATOMIC); - if (unlikely(err)) - goto err_free_rt; - } - return rt; - -err_free_rt: - ip_rt_put(rt); -err_free_skb: - kfree_skb(skb); - dev->stats.tx_dropped++; - return NULL; -} - static void gre_fb_xmit(struct sk_buff *skb, struct net_device *dev, __be16 proto) { struct ip_tunnel *tunnel = netdev_priv(dev); struct ip_tunnel_info *tun_info; const struct ip_tunnel_key *key; - struct rtable *rt = NULL; - struct flowi4 fl; int tunnel_hlen; - __be16 df, flags; + __be16 flags; tun_info = skb_tunnel_info(skb); if (unlikely(!tun_info || !(tun_info->mode & IP_TUNNEL_INFO_TX) || @@ -542,13 +466,12 @@ static void gre_fb_xmit(struct sk_buff *skb, struct net_device *dev, key = &tun_info->key; tunnel_hlen = gre_calc_hlen(key->tun_flags); - rt = prepare_fb_xmit(skb, dev, &fl, tunnel_hlen); - if (!rt) - return; + if (skb_cow_head(skb, dev->needed_headroom)) + goto err_free_skb; /* Push Tunnel header. */ if (gre_handle_offloads(skb, !!(tun_info->key.tun_flags & TUNNEL_CSUM))) - goto err_free_rt; + goto err_free_skb; flags = tun_info->key.tun_flags & (TUNNEL_CSUM | TUNNEL_KEY | TUNNEL_SEQ); @@ -556,14 +479,10 @@ static void gre_fb_xmit(struct sk_buff *skb, struct net_device *dev, tunnel_id_to_key32(tun_info->key.tun_id), (flags & TUNNEL_SEQ) ? htonl(tunnel->o_seqno++) : 0); - df = key->tun_flags & TUNNEL_DONT_FRAGMENT ? htons(IP_DF) : 0; + ip_md_tunnel_xmit(skb, dev, IPPROTO_GRE, tunnel_hlen); - iptunnel_xmit(skb->sk, rt, skb, fl.saddr, key->u.ipv4.dst, IPPROTO_GRE, - key->tos, key->ttl, df, false); return; -err_free_rt: - ip_rt_put(rt); err_free_skb: kfree_skb(skb); dev->stats.tx_dropped++; @@ -575,10 +494,8 @@ static void erspan_fb_xmit(struct sk_buff *skb, struct net_device *dev) struct ip_tunnel_info *tun_info; const struct ip_tunnel_key *key; struct erspan_metadata *md; - struct rtable *rt = NULL; bool truncate = false; - __be16 df, proto; - struct flowi4 fl; + __be16 proto; int tunnel_hlen; int version; int nhoff; @@ -591,21 +508,20 @@ static void erspan_fb_xmit(struct sk_buff *skb, struct net_device *dev) key = &tun_info->key; if (!(tun_info->key.tun_flags & TUNNEL_ERSPAN_OPT)) - goto err_free_rt; + goto err_free_skb; md = ip_tunnel_info_opts(tun_info); if (!md) - goto err_free_rt; + goto err_free_skb; /* ERSPAN has fixed 8 byte GRE header */ version = md->version; tunnel_hlen = 8 + erspan_hdr_len(version); - rt = prepare_fb_xmit(skb, dev, &fl, tunnel_hlen); - if (!rt) - return; + if (skb_cow_head(skb, dev->needed_headroom)) + goto err_free_skb; if (gre_handle_offloads(skb, false)) - goto err_free_rt; + goto err_free_skb; if (skb->len > dev->mtu + dev->hard_header_len) { pskb_trim(skb, dev->mtu + dev->hard_header_len); @@ -634,20 +550,16 @@ static void erspan_fb_xmit(struct sk_buff *skb, struct net_device *dev) truncate, true); proto = htons(ETH_P_ERSPAN2); } else { - goto err_free_rt; + goto err_free_skb; } gre_build_header(skb, 8, TUNNEL_SEQ, proto, 0, htonl(tunnel->o_seqno++)); - df = key->tun_flags & TUNNEL_DONT_FRAGMENT ? htons(IP_DF) : 0; + ip_md_tunnel_xmit(skb, dev, IPPROTO_GRE, tunnel_hlen); - iptunnel_xmit(skb->sk, rt, skb, fl.saddr, key->u.ipv4.dst, IPPROTO_GRE, - key->tos, key->ttl, df, false); return; -err_free_rt: - ip_rt_put(rt); err_free_skb: kfree_skb(skb); dev->stats.tx_dropped++; @@ -656,13 +568,18 @@ err_free_skb: static int gre_fill_metadata_dst(struct net_device *dev, struct sk_buff *skb) { struct ip_tunnel_info *info = skb_tunnel_info(skb); + const struct ip_tunnel_key *key; struct rtable *rt; struct flowi4 fl4; if (ip_tunnel_info_af(info) != AF_INET) return -EINVAL; - rt = gre_get_rt(skb, dev, &fl4, &info->key); + key = &info->key; + ip_tunnel_init_flow(&fl4, IPPROTO_GRE, key->u.ipv4.dst, key->u.ipv4.src, + tunnel_id_to_key32(key->tun_id), key->tos, 0, + skb->mark, skb_get_hash(skb)); + rt = ip_route_output_key(dev_net(dev), &fl4); if (IS_ERR(rt)) return PTR_ERR(rt); @@ -1464,12 +1381,31 @@ static int ipgre_fill_info(struct sk_buff *skb, const struct net_device *dev) { struct ip_tunnel *t = netdev_priv(dev); struct ip_tunnel_parm *p = &t->parms; + __be16 o_flags = p->o_flags; + + if (t->erspan_ver == 1 || t->erspan_ver == 2) { + if (!t->collect_md) + o_flags |= TUNNEL_KEY; + + if (nla_put_u8(skb, IFLA_GRE_ERSPAN_VER, t->erspan_ver)) + goto nla_put_failure; + + if (t->erspan_ver == 1) { + if (nla_put_u32(skb, IFLA_GRE_ERSPAN_INDEX, t->index)) + goto nla_put_failure; + } else { + if (nla_put_u8(skb, IFLA_GRE_ERSPAN_DIR, t->dir)) + goto nla_put_failure; + if (nla_put_u16(skb, IFLA_GRE_ERSPAN_HWID, t->hwid)) + goto nla_put_failure; + } + } if (nla_put_u32(skb, IFLA_GRE_LINK, p->link) || nla_put_be16(skb, IFLA_GRE_IFLAGS, gre_tnl_flags_to_gre_flags(p->i_flags)) || nla_put_be16(skb, IFLA_GRE_OFLAGS, - gre_tnl_flags_to_gre_flags(p->o_flags)) || + gre_tnl_flags_to_gre_flags(o_flags)) || nla_put_be32(skb, IFLA_GRE_IKEY, p->i_key) || nla_put_be32(skb, IFLA_GRE_OKEY, p->o_key) || nla_put_in_addr(skb, IFLA_GRE_LOCAL, p->iph.saddr) || @@ -1499,19 +1435,6 @@ static int ipgre_fill_info(struct sk_buff *skb, const struct net_device *dev) goto nla_put_failure; } - if (nla_put_u8(skb, IFLA_GRE_ERSPAN_VER, t->erspan_ver)) - goto nla_put_failure; - - if (t->erspan_ver == 1) { - if (nla_put_u32(skb, IFLA_GRE_ERSPAN_INDEX, t->index)) - goto nla_put_failure; - } else if (t->erspan_ver == 2) { - if (nla_put_u8(skb, IFLA_GRE_ERSPAN_DIR, t->dir)) - goto nla_put_failure; - if (nla_put_u16(skb, IFLA_GRE_ERSPAN_HWID, t->hwid)) - goto nla_put_failure; - } - return 0; nla_put_failure: diff --git a/net/ipv4/ip_input.c b/net/ipv4/ip_input.c index 51d8efba6de2..ecce2dc78f17 100644 --- a/net/ipv4/ip_input.c +++ b/net/ipv4/ip_input.c @@ -307,11 +307,10 @@ drop: } static int ip_rcv_finish_core(struct net *net, struct sock *sk, - struct sk_buff *skb) + struct sk_buff *skb, struct net_device *dev) { const struct iphdr *iph = ip_hdr(skb); int (*edemux)(struct sk_buff *skb); - struct net_device *dev = skb->dev; struct rtable *rt; int err; @@ -400,6 +399,7 @@ drop_error: static int ip_rcv_finish(struct net *net, struct sock *sk, struct sk_buff *skb) { + struct net_device *dev = skb->dev; int ret; /* if ingress device is enslaved to an L3 master device pass the @@ -409,7 +409,7 @@ static int ip_rcv_finish(struct net *net, struct sock *sk, struct sk_buff *skb) if (!skb) return NET_RX_SUCCESS; - ret = ip_rcv_finish_core(net, sk, skb); + ret = ip_rcv_finish_core(net, sk, skb, dev); if (ret != NET_RX_DROP) ret = dst_input(skb); return ret; @@ -429,7 +429,6 @@ static struct sk_buff *ip_rcv_core(struct sk_buff *skb, struct net *net) if (skb->pkt_type == PACKET_OTHERHOST) goto drop; - __IP_UPD_PO_STATS(net, IPSTATS_MIB_IN, skb->len); skb = skb_share_check(skb, GFP_ATOMIC); @@ -521,6 +520,7 @@ int ip_rcv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt, skb = ip_rcv_core(skb, net); if (skb == NULL) return NET_RX_DROP; + return NF_HOOK(NFPROTO_IPV4, NF_INET_PRE_ROUTING, net, NULL, skb, dev, NULL, ip_rcv_finish); @@ -545,6 +545,7 @@ static void ip_list_rcv_finish(struct net *net, struct sock *sk, INIT_LIST_HEAD(&sublist); list_for_each_entry_safe(skb, next, head, list) { + struct net_device *dev = skb->dev; struct dst_entry *dst; skb_list_del_init(skb); @@ -554,7 +555,7 @@ static void ip_list_rcv_finish(struct net *net, struct sock *sk, skb = l3mdev_ip_rcv(skb); if (!skb) continue; - if (ip_rcv_finish_core(net, sk, skb) == NET_RX_DROP) + if (ip_rcv_finish_core(net, sk, skb, dev) == NET_RX_DROP) continue; dst = skb_dst(skb); diff --git a/net/ipv4/ip_options.c b/net/ipv4/ip_options.c index ed194d46c00e..32a35043c9f5 100644 --- a/net/ipv4/ip_options.c +++ b/net/ipv4/ip_options.c @@ -251,8 +251,9 @@ static void spec_dst_fill(__be32 *spec_dst, struct sk_buff *skb) * If opt == NULL, then skb->data should point to IP header. */ -int ip_options_compile(struct net *net, - struct ip_options *opt, struct sk_buff *skb) +int __ip_options_compile(struct net *net, + struct ip_options *opt, struct sk_buff *skb, + __be32 *info) { __be32 spec_dst = htonl(INADDR_ANY); unsigned char *pp_ptr = NULL; @@ -468,11 +469,22 @@ eol: return 0; error: - if (skb) { - icmp_send(skb, ICMP_PARAMETERPROB, 0, htonl((pp_ptr-iph)<<24)); - } + if (info) + *info = htonl((pp_ptr-iph)<<24); return -EINVAL; } + +int ip_options_compile(struct net *net, + struct ip_options *opt, struct sk_buff *skb) +{ + int ret; + __be32 info; + + ret = __ip_options_compile(net, opt, skb, &info); + if (ret != 0 && skb) + icmp_send(skb, ICMP_PARAMETERPROB, 0, info); + return ret; +} EXPORT_SYMBOL(ip_options_compile); /* diff --git a/net/ipv4/ip_tunnel.c b/net/ipv4/ip_tunnel.c index c4f5602308ed..a5d8cad18ead 100644 --- a/net/ipv4/ip_tunnel.c +++ b/net/ipv4/ip_tunnel.c @@ -310,7 +310,7 @@ static int ip_tunnel_bind_dev(struct net_device *dev) ip_tunnel_init_flow(&fl4, iph->protocol, iph->daddr, iph->saddr, tunnel->parms.o_key, RT_TOS(iph->tos), tunnel->parms.link, - tunnel->fwmark); + tunnel->fwmark, 0); rt = ip_route_output_key(tunnel->net, &fl4); if (!IS_ERR(rt)) { @@ -501,19 +501,24 @@ EXPORT_SYMBOL_GPL(ip_tunnel_encap_setup); static int tnl_update_pmtu(struct net_device *dev, struct sk_buff *skb, struct rtable *rt, __be16 df, - const struct iphdr *inner_iph) + const struct iphdr *inner_iph, + int tunnel_hlen, __be32 dst, bool md) { struct ip_tunnel *tunnel = netdev_priv(dev); - int pkt_size = skb->len - tunnel->hlen - dev->hard_header_len; + int pkt_size; int mtu; + tunnel_hlen = md ? tunnel_hlen : tunnel->hlen; + pkt_size = skb->len - tunnel_hlen - dev->hard_header_len; + if (df) mtu = dst_mtu(&rt->dst) - dev->hard_header_len - - sizeof(struct iphdr) - tunnel->hlen; + - sizeof(struct iphdr) - tunnel_hlen; else - mtu = skb_dst(skb) ? dst_mtu(skb_dst(skb)) : dev->mtu; + mtu = skb_valid_dst(skb) ? dst_mtu(skb_dst(skb)) : dev->mtu; - skb_dst_update_pmtu(skb, mtu); + if (skb_valid_dst(skb)) + skb_dst_update_pmtu(skb, mtu); if (skb->protocol == htons(ETH_P_IP)) { if (!skb_is_gso(skb) && @@ -526,12 +531,16 @@ static int tnl_update_pmtu(struct net_device *dev, struct sk_buff *skb, } #if IS_ENABLED(CONFIG_IPV6) else if (skb->protocol == htons(ETH_P_IPV6)) { - struct rt6_info *rt6 = (struct rt6_info *)skb_dst(skb); + struct rt6_info *rt6; + __be32 daddr; + + rt6 = skb_valid_dst(skb) ? (struct rt6_info *)skb_dst(skb) : + NULL; + daddr = md ? dst : tunnel->parms.iph.daddr; if (rt6 && mtu < dst_mtu(skb_dst(skb)) && mtu >= IPV6_MIN_MTU) { - if ((tunnel->parms.iph.daddr && - !ipv4_is_multicast(tunnel->parms.iph.daddr)) || + if ((daddr && !ipv4_is_multicast(daddr)) || rt6->rt6i_dst.plen == 128) { rt6->rt6i_flags |= RTF_MODIFIED; dst_metric_set(skb_dst(skb), RTAX_MTU, mtu); @@ -548,17 +557,19 @@ static int tnl_update_pmtu(struct net_device *dev, struct sk_buff *skb, return 0; } -void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, u8 proto) +void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, + u8 proto, int tunnel_hlen) { struct ip_tunnel *tunnel = netdev_priv(dev); u32 headroom = sizeof(struct iphdr); struct ip_tunnel_info *tun_info; const struct ip_tunnel_key *key; const struct iphdr *inner_iph; - struct rtable *rt; + struct rtable *rt = NULL; struct flowi4 fl4; __be16 df = 0; u8 tos, ttl; + bool use_cache; tun_info = skb_tunnel_info(skb); if (unlikely(!tun_info || !(tun_info->mode & IP_TUNNEL_INFO_TX) || @@ -574,20 +585,39 @@ void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, u8 proto) else if (skb->protocol == htons(ETH_P_IPV6)) tos = ipv6_get_dsfield((const struct ipv6hdr *)inner_iph); } - ip_tunnel_init_flow(&fl4, proto, key->u.ipv4.dst, key->u.ipv4.src, 0, - RT_TOS(tos), tunnel->parms.link, tunnel->fwmark); + ip_tunnel_init_flow(&fl4, proto, key->u.ipv4.dst, key->u.ipv4.src, + tunnel_id_to_key32(key->tun_id), RT_TOS(tos), + 0, skb->mark, skb_get_hash(skb)); if (tunnel->encap.type != TUNNEL_ENCAP_NONE) goto tx_error; - rt = ip_route_output_key(tunnel->net, &fl4); - if (IS_ERR(rt)) { - dev->stats.tx_carrier_errors++; - goto tx_error; + + use_cache = ip_tunnel_dst_cache_usable(skb, tun_info); + if (use_cache) + rt = dst_cache_get_ip4(&tun_info->dst_cache, &fl4.saddr); + if (!rt) { + rt = ip_route_output_key(tunnel->net, &fl4); + if (IS_ERR(rt)) { + dev->stats.tx_carrier_errors++; + goto tx_error; + } + if (use_cache) + dst_cache_set_ip4(&tun_info->dst_cache, &rt->dst, + fl4.saddr); } if (rt->dst.dev == dev) { ip_rt_put(rt); dev->stats.collisions++; goto tx_error; } + + if (key->tun_flags & TUNNEL_DONT_FRAGMENT) + df = htons(IP_DF); + if (tnl_update_pmtu(dev, skb, rt, df, inner_iph, tunnel_hlen, + key->u.ipv4.dst, true)) { + ip_rt_put(rt); + goto tx_error; + } + tos = ip_tunnel_ecn_encap(tos, inner_iph, skb); ttl = key->ttl; if (ttl == 0) { @@ -598,10 +628,10 @@ void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, u8 proto) else ttl = ip4_dst_hoplimit(&rt->dst); } - if (key->tun_flags & TUNNEL_DONT_FRAGMENT) - df = htons(IP_DF); - else if (skb->protocol == htons(ETH_P_IP)) + + if (!df && skb->protocol == htons(ETH_P_IP)) df = inner_iph->frag_off & htons(IP_DF); + headroom += LL_RESERVED_SPACE(rt->dst.dev) + rt->dst.header_len; if (headroom > dev->needed_headroom) dev->needed_headroom = headroom; @@ -627,14 +657,17 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, const struct iphdr *tnl_params, u8 protocol) { struct ip_tunnel *tunnel = netdev_priv(dev); + struct ip_tunnel_info *tun_info = NULL; const struct iphdr *inner_iph; - struct flowi4 fl4; - u8 tos, ttl; - __be16 df; - struct rtable *rt; /* Route to the other host */ unsigned int max_headroom; /* The extra header space needed */ - __be32 dst; + struct rtable *rt = NULL; /* Route to the other host */ + bool use_cache = false; + struct flowi4 fl4; + bool md = false; bool connected; + u8 tos, ttl; + __be32 dst; + __be16 df; inner_iph = (const struct iphdr *)skb_inner_network_header(skb); connected = (tunnel->parms.iph.daddr != 0); @@ -650,7 +683,15 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, goto tx_error; } - if (skb->protocol == htons(ETH_P_IP)) { + tun_info = skb_tunnel_info(skb); + if (tun_info && (tun_info->mode & IP_TUNNEL_INFO_TX) && + ip_tunnel_info_af(tun_info) == AF_INET && + tun_info->key.u.ipv4.dst) { + dst = tun_info->key.u.ipv4.dst; + md = true; + connected = true; + } + else if (skb->protocol == htons(ETH_P_IP)) { rt = skb_rtable(skb); dst = rt_nexthop(rt, inner_iph->daddr); } @@ -688,7 +729,8 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, else goto tx_error; - connected = false; + if (!md) + connected = false; } tos = tnl_params->tos; @@ -705,13 +747,20 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, ip_tunnel_init_flow(&fl4, protocol, dst, tnl_params->saddr, tunnel->parms.o_key, RT_TOS(tos), tunnel->parms.link, - tunnel->fwmark); + tunnel->fwmark, skb_get_hash(skb)); if (ip_tunnel_encap(skb, tunnel, &protocol, &fl4) < 0) goto tx_error; - rt = connected ? dst_cache_get_ip4(&tunnel->dst_cache, &fl4.saddr) : - NULL; + if (connected && md) { + use_cache = ip_tunnel_dst_cache_usable(skb, tun_info); + if (use_cache) + rt = dst_cache_get_ip4(&tun_info->dst_cache, + &fl4.saddr); + } else { + rt = connected ? dst_cache_get_ip4(&tunnel->dst_cache, + &fl4.saddr) : NULL; + } if (!rt) { rt = ip_route_output_key(tunnel->net, &fl4); @@ -720,7 +769,10 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, dev->stats.tx_carrier_errors++; goto tx_error; } - if (connected) + if (use_cache) + dst_cache_set_ip4(&tun_info->dst_cache, &rt->dst, + fl4.saddr); + else if (!md && connected) dst_cache_set_ip4(&tunnel->dst_cache, &rt->dst, fl4.saddr); } @@ -731,7 +783,8 @@ void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, goto tx_error; } - if (tnl_update_pmtu(dev, skb, rt, tnl_params->frag_off, inner_iph)) { + if (tnl_update_pmtu(dev, skb, rt, tnl_params->frag_off, inner_iph, + 0, 0, false)) { ip_rt_put(rt); goto tx_error; } diff --git a/net/ipv4/ip_tunnel_core.c b/net/ipv4/ip_tunnel_core.c index 9a0e67b52a4e..c3f3d28d1087 100644 --- a/net/ipv4/ip_tunnel_core.c +++ b/net/ipv4/ip_tunnel_core.c @@ -252,6 +252,14 @@ static int ip_tun_build_state(struct nlattr *attr, tun_info = lwt_tun_info(new_state); +#ifdef CONFIG_DST_CACHE + err = dst_cache_init(&tun_info->dst_cache, GFP_KERNEL); + if (err) { + lwtstate_free(new_state); + return err; + } +#endif + if (tb[LWTUNNEL_IP_ID]) tun_info->key.tun_id = nla_get_be64(tb[LWTUNNEL_IP_ID]); @@ -278,6 +286,15 @@ static int ip_tun_build_state(struct nlattr *attr, return 0; } +static void ip_tun_destroy_state(struct lwtunnel_state *lwtstate) +{ +#ifdef CONFIG_DST_CACHE + struct ip_tunnel_info *tun_info = lwt_tun_info(lwtstate); + + dst_cache_destroy(&tun_info->dst_cache); +#endif +} + static int ip_tun_fill_encap_info(struct sk_buff *skb, struct lwtunnel_state *lwtstate) { @@ -313,6 +330,7 @@ static int ip_tun_cmp_encap(struct lwtunnel_state *a, struct lwtunnel_state *b) static const struct lwtunnel_encap_ops ip_tun_lwt_ops = { .build_state = ip_tun_build_state, + .destroy_state = ip_tun_destroy_state, .fill_encap = ip_tun_fill_encap_info, .get_encap_size = ip_tun_encap_nlsize, .cmp_encap = ip_tun_cmp_encap, diff --git a/net/ipv4/ip_vti.c b/net/ipv4/ip_vti.c index d7b43e700023..68a21bf75dd0 100644 --- a/net/ipv4/ip_vti.c +++ b/net/ipv4/ip_vti.c @@ -74,6 +74,33 @@ drop: return 0; } +static int vti_input_ipip(struct sk_buff *skb, int nexthdr, __be32 spi, + int encap_type) +{ + struct ip_tunnel *tunnel; + const struct iphdr *iph = ip_hdr(skb); + struct net *net = dev_net(skb->dev); + struct ip_tunnel_net *itn = net_generic(net, vti_net_id); + + tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, TUNNEL_NO_KEY, + iph->saddr, iph->daddr, 0); + if (tunnel) { + if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) + goto drop; + + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4 = tunnel; + + skb->dev = tunnel->dev; + + return xfrm_input(skb, nexthdr, spi, encap_type); + } + + return -EINVAL; +drop: + kfree_skb(skb); + return 0; +} + static int vti_rcv(struct sk_buff *skb) { XFRM_SPI_SKB_CB(skb)->family = AF_INET; @@ -82,6 +109,14 @@ static int vti_rcv(struct sk_buff *skb) return vti_input(skb, ip_hdr(skb)->protocol, 0, 0); } +static int vti_rcv_ipip(struct sk_buff *skb) +{ + XFRM_SPI_SKB_CB(skb)->family = AF_INET; + XFRM_SPI_SKB_CB(skb)->daddroff = offsetof(struct iphdr, daddr); + + return vti_input_ipip(skb, ip_hdr(skb)->protocol, ip_hdr(skb)->saddr, 0); +} + static int vti_rcv_cb(struct sk_buff *skb, int err) { unsigned short family; @@ -435,6 +470,12 @@ static struct xfrm4_protocol vti_ipcomp4_protocol __read_mostly = { .priority = 100, }; +static struct xfrm_tunnel ipip_handler __read_mostly = { + .handler = vti_rcv_ipip, + .err_handler = vti4_err, + .priority = 0, +}; + static int __net_init vti_init_net(struct net *net) { int err; @@ -603,6 +644,13 @@ static int __init vti_init(void) if (err < 0) goto xfrm_proto_comp_failed; + msg = "ipip tunnel"; + err = xfrm4_tunnel_register(&ipip_handler, AF_INET); + if (err < 0) { + pr_info("%s: cant't register tunnel\n",__func__); + goto xfrm_tunnel_failed; + } + msg = "netlink interface"; err = rtnl_link_register(&vti_link_ops); if (err < 0) @@ -612,6 +660,8 @@ static int __init vti_init(void) rtnl_link_failed: xfrm4_protocol_deregister(&vti_ipcomp4_protocol, IPPROTO_COMP); +xfrm_tunnel_failed: + xfrm4_tunnel_deregister(&ipip_handler, AF_INET); xfrm_proto_comp_failed: xfrm4_protocol_deregister(&vti_ah4_protocol, IPPROTO_AH); xfrm_proto_ah_failed: diff --git a/net/ipv4/ipconfig.c b/net/ipv4/ipconfig.c index b9a9873c25c6..9bcca08efec9 100644 --- a/net/ipv4/ipconfig.c +++ b/net/ipv4/ipconfig.c @@ -85,7 +85,6 @@ /* Define the friendly delay before and after opening net devices */ #define CONF_POST_OPEN 10 /* After opening: 10 msecs */ -#define CONF_CARRIER_TIMEOUT 120000 /* Wait for carrier timeout */ /* Define the timeout for waiting for a DHCP/BOOTP/RARP reply */ #define CONF_OPEN_RETRIES 2 /* (Re)open devices twice */ @@ -101,6 +100,9 @@ #define NONE cpu_to_be32(INADDR_NONE) #define ANY cpu_to_be32(INADDR_ANY) +/* Wait for carrier timeout default in seconds */ +static unsigned int carrier_timeout = 120; + /* * Public IP configuration */ @@ -268,9 +270,9 @@ static int __init ic_open_devs(void) /* wait for a carrier on at least one device */ start = jiffies; - next_msg = start + msecs_to_jiffies(CONF_CARRIER_TIMEOUT/12); + next_msg = start + msecs_to_jiffies(20000); while (time_before(jiffies, start + - msecs_to_jiffies(CONF_CARRIER_TIMEOUT))) { + msecs_to_jiffies(carrier_timeout * 1000))) { int wait, elapsed; for_each_netdev(&init_net, dev) @@ -283,9 +285,9 @@ static int __init ic_open_devs(void) continue; elapsed = jiffies_to_msecs(jiffies - start); - wait = (CONF_CARRIER_TIMEOUT - elapsed + 500)/1000; + wait = (carrier_timeout * 1000 - elapsed + 500) / 1000; pr_info("Waiting up to %d more seconds for network.\n", wait); - next_msg = jiffies + msecs_to_jiffies(CONF_CARRIER_TIMEOUT/12); + next_msg = jiffies + msecs_to_jiffies(20000); } have_carrier: rtnl_unlock(); @@ -1780,3 +1782,18 @@ static int __init vendor_class_identifier_setup(char *addrs) return 1; } __setup("dhcpclass=", vendor_class_identifier_setup); + +static int __init set_carrier_timeout(char *str) +{ + ssize_t ret; + + if (!str) + return 0; + + ret = kstrtouint(str, 0, &carrier_timeout); + if (ret) + return 0; + + return 1; +} +__setup("carrier_timeout=", set_carrier_timeout); diff --git a/net/ipv4/ipip.c b/net/ipv4/ipip.c index 57c5dd283a2c..fe10b9a2efc8 100644 --- a/net/ipv4/ipip.c +++ b/net/ipv4/ipip.c @@ -302,7 +302,7 @@ static netdev_tx_t ipip_tunnel_xmit(struct sk_buff *skb, skb_set_inner_ipproto(skb, ipproto); if (tunnel->collect_md) - ip_md_tunnel_xmit(skb, dev, ipproto); + ip_md_tunnel_xmit(skb, dev, ipproto, 0); else ip_tunnel_xmit(skb, dev, tiph, ipproto); return NETDEV_TX_OK; diff --git a/net/ipv4/ipmr.c b/net/ipv4/ipmr.c index ddbf8c9a1abb..2c931120c494 100644 --- a/net/ipv4/ipmr.c +++ b/net/ipv4/ipmr.c @@ -67,7 +67,6 @@ #include <net/fib_rules.h> #include <linux/netconf.h> #include <net/nexthop.h> -#include <net/switchdev.h> #include <linux/nospec.h> @@ -111,7 +110,7 @@ static int ipmr_cache_report(struct mr_table *mrt, static void mroute_netlink_event(struct mr_table *mrt, struct mfc_cache *mfc, int cmd); static void igmpmsg_netlink_event(struct mr_table *mrt, struct sk_buff *pkt); -static void mroute_clean_tables(struct mr_table *mrt, bool all); +static void mroute_clean_tables(struct mr_table *mrt, int flags); static void ipmr_expire_process(struct timer_list *t); #ifdef CONFIG_IP_MROUTE_MULTIPLE_TABLES @@ -416,7 +415,8 @@ static struct mr_table *ipmr_new_table(struct net *net, u32 id) static void ipmr_free_table(struct mr_table *mrt) { del_timer_sync(&mrt->ipmr_expire_timer); - mroute_clean_tables(mrt, true); + mroute_clean_tables(mrt, MRT_FLUSH_VIFS | MRT_FLUSH_VIFS_STATIC | + MRT_FLUSH_MFC | MRT_FLUSH_MFC_STATIC); rhltable_destroy(&mrt->mfc_hash); kfree(mrt); } @@ -837,10 +837,8 @@ static void ipmr_update_thresholds(struct mr_table *mrt, struct mr_mfc *cache, static int vif_add(struct net *net, struct mr_table *mrt, struct vifctl *vifc, int mrtsock) { + struct netdev_phys_item_id ppid = { }; int vifi = vifc->vifc_vifi; - struct switchdev_attr attr = { - .id = SWITCHDEV_ATTR_ID_PORT_PARENT_ID, - }; struct vif_device *v = &mrt->vif_table[vifi]; struct net_device *dev; struct in_device *in_dev; @@ -919,10 +917,10 @@ static int vif_add(struct net *net, struct mr_table *mrt, vifc->vifc_flags | (!mrtsock ? VIFF_STATIC : 0), (VIFF_TUNNEL | VIFF_REGISTER)); - attr.orig_dev = dev; - if (!switchdev_port_attr_get(dev, &attr)) { - memcpy(v->dev_parent_id.id, attr.u.ppid.id, attr.u.ppid.id_len); - v->dev_parent_id.id_len = attr.u.ppid.id_len; + err = dev_get_port_parent_id(dev, &ppid, true); + if (err == 0) { + memcpy(v->dev_parent_id.id, ppid.id, ppid.id_len); + v->dev_parent_id.id_len = ppid.id_len; } else { v->dev_parent_id.id_len = 0; } @@ -1299,7 +1297,7 @@ static int ipmr_mfc_add(struct net *net, struct mr_table *mrt, } /* Close the multicast socket, and clear the vif tables etc */ -static void mroute_clean_tables(struct mr_table *mrt, bool all) +static void mroute_clean_tables(struct mr_table *mrt, int flags) { struct net *net = read_pnet(&mrt->net); struct mr_mfc *c, *tmp; @@ -1308,35 +1306,44 @@ static void mroute_clean_tables(struct mr_table *mrt, bool all) int i; /* Shut down all active vif entries */ - for (i = 0; i < mrt->maxvif; i++) { - if (!all && (mrt->vif_table[i].flags & VIFF_STATIC)) - continue; - vif_delete(mrt, i, 0, &list); + if (flags & (MRT_FLUSH_VIFS | MRT_FLUSH_VIFS_STATIC)) { + for (i = 0; i < mrt->maxvif; i++) { + if (((mrt->vif_table[i].flags & VIFF_STATIC) && + !(flags & MRT_FLUSH_VIFS_STATIC)) || + (!(mrt->vif_table[i].flags & VIFF_STATIC) && !(flags & MRT_FLUSH_VIFS))) + continue; + vif_delete(mrt, i, 0, &list); + } + unregister_netdevice_many(&list); } - unregister_netdevice_many(&list); /* Wipe the cache */ - list_for_each_entry_safe(c, tmp, &mrt->mfc_cache_list, list) { - if (!all && (c->mfc_flags & MFC_STATIC)) - continue; - rhltable_remove(&mrt->mfc_hash, &c->mnode, ipmr_rht_params); - list_del_rcu(&c->list); - cache = (struct mfc_cache *)c; - call_ipmr_mfc_entry_notifiers(net, FIB_EVENT_ENTRY_DEL, cache, - mrt->id); - mroute_netlink_event(mrt, cache, RTM_DELROUTE); - mr_cache_put(c); - } - - if (atomic_read(&mrt->cache_resolve_queue_len) != 0) { - spin_lock_bh(&mfc_unres_lock); - list_for_each_entry_safe(c, tmp, &mrt->mfc_unres_queue, list) { - list_del(&c->list); + if (flags & (MRT_FLUSH_MFC | MRT_FLUSH_MFC_STATIC)) { + list_for_each_entry_safe(c, tmp, &mrt->mfc_cache_list, list) { + if (((c->mfc_flags & MFC_STATIC) && !(flags & MRT_FLUSH_MFC_STATIC)) || + (!(c->mfc_flags & MFC_STATIC) && !(flags & MRT_FLUSH_MFC))) + continue; + rhltable_remove(&mrt->mfc_hash, &c->mnode, ipmr_rht_params); + list_del_rcu(&c->list); cache = (struct mfc_cache *)c; + call_ipmr_mfc_entry_notifiers(net, FIB_EVENT_ENTRY_DEL, cache, + mrt->id); mroute_netlink_event(mrt, cache, RTM_DELROUTE); - ipmr_destroy_unres(mrt, cache); + mr_cache_put(c); + } + } + + if (flags & MRT_FLUSH_MFC) { + if (atomic_read(&mrt->cache_resolve_queue_len) != 0) { + spin_lock_bh(&mfc_unres_lock); + list_for_each_entry_safe(c, tmp, &mrt->mfc_unres_queue, list) { + list_del(&c->list); + cache = (struct mfc_cache *)c; + mroute_netlink_event(mrt, cache, RTM_DELROUTE); + ipmr_destroy_unres(mrt, cache); + } + spin_unlock_bh(&mfc_unres_lock); } - spin_unlock_bh(&mfc_unres_lock); } } @@ -1357,7 +1364,7 @@ static void mrtsock_destruct(struct sock *sk) NETCONFA_IFINDEX_ALL, net->ipv4.devconf_all); RCU_INIT_POINTER(mrt->mroute_sk, NULL); - mroute_clean_tables(mrt, false); + mroute_clean_tables(mrt, MRT_FLUSH_VIFS | MRT_FLUSH_MFC); } } rtnl_unlock(); @@ -1482,6 +1489,17 @@ int ip_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, sk == rtnl_dereference(mrt->mroute_sk), parent); break; + case MRT_FLUSH: + if (optlen != sizeof(val)) { + ret = -EINVAL; + break; + } + if (get_user(val, (int __user *)optval)) { + ret = -EFAULT; + break; + } + mroute_clean_tables(mrt, val); + break; /* Control PIM assert. */ case MRT_ASSERT: if (optlen != sizeof(val)) { @@ -2467,6 +2485,61 @@ errout: rtnl_set_sk_err(net, RTNLGRP_IPV4_MROUTE_R, -ENOBUFS); } +static int ipmr_rtm_valid_getroute_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct rtmsg *rtm; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid header for multicast route get request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv4_policy, extack); + + rtm = nlmsg_data(nlh); + if ((rtm->rtm_src_len && rtm->rtm_src_len != 32) || + (rtm->rtm_dst_len && rtm->rtm_dst_len != 32) || + rtm->rtm_tos || rtm->rtm_table || rtm->rtm_protocol || + rtm->rtm_scope || rtm->rtm_type || rtm->rtm_flags) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid values in header for multicast route get request"); + return -EINVAL; + } + + err = nlmsg_parse_strict(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv4_policy, extack); + if (err) + return err; + + if ((tb[RTA_SRC] && !rtm->rtm_src_len) || + (tb[RTA_DST] && !rtm->rtm_dst_len)) { + NL_SET_ERR_MSG(extack, "ipv4: rtm_src_len and rtm_dst_len must be 32 for IPv4"); + return -EINVAL; + } + + for (i = 0; i <= RTA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case RTA_SRC: + case RTA_DST: + case RTA_TABLE: + break; + default: + NL_SET_ERR_MSG(extack, "ipv4: Unsupported attribute in multicast route get request"); + return -EINVAL; + } + } + + return 0; +} + static int ipmr_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { @@ -2475,18 +2548,14 @@ static int ipmr_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct sk_buff *skb = NULL; struct mfc_cache *cache; struct mr_table *mrt; - struct rtmsg *rtm; __be32 src, grp; u32 tableid; int err; - err = nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, - rtm_ipv4_policy, extack); + err = ipmr_rtm_valid_getroute_req(in_skb, nlh, tb, extack); if (err < 0) goto errout; - rtm = nlmsg_data(nlh); - src = tb[RTA_SRC] ? nla_get_in_addr(tb[RTA_SRC]) : 0; grp = tb[RTA_DST] ? nla_get_in_addr(tb[RTA_DST]) : 0; tableid = tb[RTA_TABLE] ? nla_get_u32(tb[RTA_TABLE]) : 0; diff --git a/net/ipv4/netfilter.c b/net/ipv4/netfilter.c index 8d2e5dc9a827..a058213b77a7 100644 --- a/net/ipv4/netfilter.c +++ b/net/ipv4/netfilter.c @@ -80,24 +80,6 @@ int ip_route_me_harder(struct net *net, struct sk_buff *skb, unsigned int addr_t } EXPORT_SYMBOL(ip_route_me_harder); -int nf_ip_reroute(struct sk_buff *skb, const struct nf_queue_entry *entry) -{ - const struct ip_rt_info *rt_info = nf_queue_entry_reroute(entry); - - if (entry->state.hook == NF_INET_LOCAL_OUT) { - const struct iphdr *iph = ip_hdr(skb); - - if (!(iph->tos == rt_info->tos && - skb->mark == rt_info->mark && - iph->daddr == rt_info->daddr && - iph->saddr == rt_info->saddr)) - return ip_route_me_harder(entry->state.net, skb, - RTN_UNSPEC); - } - return 0; -} -EXPORT_SYMBOL_GPL(nf_ip_reroute); - int nf_ip_route(struct net *net, struct dst_entry **dst, struct flowi *fl, bool strict __always_unused) { diff --git a/net/ipv4/netfilter/Kconfig b/net/ipv4/netfilter/Kconfig index 80f72cc5ca8d..c98391d49200 100644 --- a/net/ipv4/netfilter/Kconfig +++ b/net/ipv4/netfilter/Kconfig @@ -94,50 +94,7 @@ config NF_REJECT_IPV4 tristate "IPv4 packet rejection" default m if NETFILTER_ADVANCED=n -config NF_NAT_IPV4 - tristate "IPv4 NAT" - depends on NF_CONNTRACK - default m if NETFILTER_ADVANCED=n - select NF_NAT - help - The IPv4 NAT option allows masquerading, port forwarding and other - forms of full Network Address Port Translation. This can be - controlled by iptables or nft. - -if NF_NAT_IPV4 - -config NF_NAT_MASQUERADE_IPV4 - bool - -if NF_TABLES -config NFT_CHAIN_NAT_IPV4 - depends on NF_TABLES_IPV4 - tristate "IPv4 nf_tables nat chain support" - help - This option enables the "nat" chain for IPv4 in nf_tables. This - chain type is used to perform Network Address Translation (NAT) - packet transformations such as the source, destination address and - source and destination ports. - -config NFT_MASQ_IPV4 - tristate "IPv4 masquerading support for nf_tables" - depends on NF_TABLES_IPV4 - depends on NFT_MASQ - select NF_NAT_MASQUERADE_IPV4 - help - This is the expression that provides IPv4 masquerading support for - nf_tables. - -config NFT_REDIR_IPV4 - tristate "IPv4 redirect support for nf_tables" - depends on NF_TABLES_IPV4 - depends on NFT_REDIR - select NF_NAT_REDIRECT - help - This is the expression that provides IPv4 redirect support for - nf_tables. -endif # NF_TABLES - +if NF_NAT config NF_NAT_SNMP_BASIC tristate "Basic SNMP-ALG support" depends on NF_CONNTRACK_SNMP @@ -166,7 +123,7 @@ config NF_NAT_H323 depends on NF_CONNTRACK default NF_CONNTRACK_H323 -endif # NF_NAT_IPV4 +endif # NF_NAT config IP_NF_IPTABLES tristate "IP tables support (required for filtering/masq/NAT)" @@ -263,7 +220,6 @@ config IP_NF_NAT depends on NF_CONNTRACK default m if NETFILTER_ADVANCED=n select NF_NAT - select NF_NAT_IPV4 select NETFILTER_XT_NAT help This enables the `nat' table in iptables. This allows masquerading, @@ -276,7 +232,7 @@ if IP_NF_NAT config IP_NF_TARGET_MASQUERADE tristate "MASQUERADE target support" - select NF_NAT_MASQUERADE_IPV4 + select NF_NAT_MASQUERADE default m if NETFILTER_ADVANCED=n help Masquerading is a special case of NAT: all outgoing connections are diff --git a/net/ipv4/netfilter/Makefile b/net/ipv4/netfilter/Makefile index fd7122e0e2c9..e241f5188ebe 100644 --- a/net/ipv4/netfilter/Makefile +++ b/net/ipv4/netfilter/Makefile @@ -3,10 +3,6 @@ # Makefile for the netfilter modules on top of IPv4. # -nf_nat_ipv4-y := nf_nat_l3proto_ipv4.o -nf_nat_ipv4-$(CONFIG_NF_NAT_MASQUERADE_IPV4) += nf_nat_masquerade_ipv4.o -obj-$(CONFIG_NF_NAT_IPV4) += nf_nat_ipv4.o - # defrag obj-$(CONFIG_NF_DEFRAG_IPV4) += nf_defrag_ipv4.o @@ -29,11 +25,8 @@ $(obj)/nf_nat_snmp_basic_main.o: $(obj)/nf_nat_snmp_basic.asn1.h obj-$(CONFIG_NF_NAT_SNMP_BASIC) += nf_nat_snmp_basic.o obj-$(CONFIG_NFT_CHAIN_ROUTE_IPV4) += nft_chain_route_ipv4.o -obj-$(CONFIG_NFT_CHAIN_NAT_IPV4) += nft_chain_nat_ipv4.o obj-$(CONFIG_NFT_REJECT_IPV4) += nft_reject_ipv4.o obj-$(CONFIG_NFT_FIB_IPV4) += nft_fib_ipv4.o -obj-$(CONFIG_NFT_MASQ_IPV4) += nft_masq_ipv4.o -obj-$(CONFIG_NFT_REDIR_IPV4) += nft_redir_ipv4.o obj-$(CONFIG_NFT_DUP_IPV4) += nft_dup_ipv4.o # flow table support diff --git a/net/ipv4/netfilter/ipt_CLUSTERIP.c b/net/ipv4/netfilter/ipt_CLUSTERIP.c index b61977db9b7f..835d50b279f5 100644 --- a/net/ipv4/netfilter/ipt_CLUSTERIP.c +++ b/net/ipv4/netfilter/ipt_CLUSTERIP.c @@ -846,9 +846,9 @@ static int clusterip_net_init(struct net *net) static void clusterip_net_exit(struct net *net) { +#ifdef CONFIG_PROC_FS struct clusterip_net *cn = clusterip_pernet(net); -#ifdef CONFIG_PROC_FS mutex_lock(&cn->mutex); proc_remove(cn->procdir); cn->procdir = NULL; @@ -864,7 +864,7 @@ static struct pernet_operations clusterip_net_ops = { .size = sizeof(struct clusterip_net), }; -struct notifier_block cip_netdev_notifier = { +static struct notifier_block cip_netdev_notifier = { .notifier_call = clusterip_netdev_event }; diff --git a/net/ipv4/netfilter/iptable_nat.c b/net/ipv4/netfilter/iptable_nat.c index a317445448bf..007da0882412 100644 --- a/net/ipv4/netfilter/iptable_nat.c +++ b/net/ipv4/netfilter/iptable_nat.c @@ -15,8 +15,6 @@ #include <net/ip.h> #include <net/netfilter/nf_nat.h> -#include <net/netfilter/nf_nat_core.h> -#include <net/netfilter/nf_nat_l3proto.h> static int __net_init iptable_nat_table_init(struct net *net); @@ -70,10 +68,10 @@ static int ipt_nat_register_lookups(struct net *net) int i, ret; for (i = 0; i < ARRAY_SIZE(nf_nat_ipv4_ops); i++) { - ret = nf_nat_l3proto_ipv4_register_fn(net, &nf_nat_ipv4_ops[i]); + ret = nf_nat_ipv4_register_fn(net, &nf_nat_ipv4_ops[i]); if (ret) { while (i) - nf_nat_l3proto_ipv4_unregister_fn(net, &nf_nat_ipv4_ops[--i]); + nf_nat_ipv4_unregister_fn(net, &nf_nat_ipv4_ops[--i]); return ret; } @@ -87,7 +85,7 @@ static void ipt_nat_unregister_lookups(struct net *net) int i; for (i = 0; i < ARRAY_SIZE(nf_nat_ipv4_ops); i++) - nf_nat_l3proto_ipv4_unregister_fn(net, &nf_nat_ipv4_ops[i]); + nf_nat_ipv4_unregister_fn(net, &nf_nat_ipv4_ops[i]); } static int __net_init iptable_nat_table_init(struct net *net) diff --git a/net/ipv4/netfilter/nf_nat_l3proto_ipv4.c b/net/ipv4/netfilter/nf_nat_l3proto_ipv4.c deleted file mode 100644 index 2687db015b6f..000000000000 --- a/net/ipv4/netfilter/nf_nat_l3proto_ipv4.c +++ /dev/null @@ -1,387 +0,0 @@ -/* - * (C) 1999-2001 Paul `Rusty' Russell - * (C) 2002-2006 Netfilter Core Team <coreteam@netfilter.org> - * (C) 2011 Patrick McHardy <kaber@trash.net> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 2 as - * published by the Free Software Foundation. - */ - -#include <linux/types.h> -#include <linux/module.h> -#include <linux/skbuff.h> -#include <linux/ip.h> -#include <linux/icmp.h> -#include <linux/netfilter.h> -#include <linux/netfilter_ipv4.h> -#include <net/secure_seq.h> -#include <net/checksum.h> -#include <net/route.h> -#include <net/ip.h> - -#include <net/netfilter/nf_conntrack_core.h> -#include <net/netfilter/nf_conntrack.h> -#include <net/netfilter/nf_nat_core.h> -#include <net/netfilter/nf_nat_l3proto.h> -#include <net/netfilter/nf_nat_l4proto.h> - -static const struct nf_nat_l3proto nf_nat_l3proto_ipv4; - -#ifdef CONFIG_XFRM -static void nf_nat_ipv4_decode_session(struct sk_buff *skb, - const struct nf_conn *ct, - enum ip_conntrack_dir dir, - unsigned long statusbit, - struct flowi *fl) -{ - const struct nf_conntrack_tuple *t = &ct->tuplehash[dir].tuple; - struct flowi4 *fl4 = &fl->u.ip4; - - if (ct->status & statusbit) { - fl4->daddr = t->dst.u3.ip; - if (t->dst.protonum == IPPROTO_TCP || - t->dst.protonum == IPPROTO_UDP || - t->dst.protonum == IPPROTO_UDPLITE || - t->dst.protonum == IPPROTO_DCCP || - t->dst.protonum == IPPROTO_SCTP) - fl4->fl4_dport = t->dst.u.all; - } - - statusbit ^= IPS_NAT_MASK; - - if (ct->status & statusbit) { - fl4->saddr = t->src.u3.ip; - if (t->dst.protonum == IPPROTO_TCP || - t->dst.protonum == IPPROTO_UDP || - t->dst.protonum == IPPROTO_UDPLITE || - t->dst.protonum == IPPROTO_DCCP || - t->dst.protonum == IPPROTO_SCTP) - fl4->fl4_sport = t->src.u.all; - } -} -#endif /* CONFIG_XFRM */ - -static bool nf_nat_ipv4_manip_pkt(struct sk_buff *skb, - unsigned int iphdroff, - const struct nf_conntrack_tuple *target, - enum nf_nat_manip_type maniptype) -{ - struct iphdr *iph; - unsigned int hdroff; - - if (!skb_make_writable(skb, iphdroff + sizeof(*iph))) - return false; - - iph = (void *)skb->data + iphdroff; - hdroff = iphdroff + iph->ihl * 4; - - if (!nf_nat_l4proto_manip_pkt(skb, &nf_nat_l3proto_ipv4, iphdroff, - hdroff, target, maniptype)) - return false; - iph = (void *)skb->data + iphdroff; - - if (maniptype == NF_NAT_MANIP_SRC) { - csum_replace4(&iph->check, iph->saddr, target->src.u3.ip); - iph->saddr = target->src.u3.ip; - } else { - csum_replace4(&iph->check, iph->daddr, target->dst.u3.ip); - iph->daddr = target->dst.u3.ip; - } - return true; -} - -static void nf_nat_ipv4_csum_update(struct sk_buff *skb, - unsigned int iphdroff, __sum16 *check, - const struct nf_conntrack_tuple *t, - enum nf_nat_manip_type maniptype) -{ - struct iphdr *iph = (struct iphdr *)(skb->data + iphdroff); - __be32 oldip, newip; - - if (maniptype == NF_NAT_MANIP_SRC) { - oldip = iph->saddr; - newip = t->src.u3.ip; - } else { - oldip = iph->daddr; - newip = t->dst.u3.ip; - } - inet_proto_csum_replace4(check, skb, oldip, newip, true); -} - -static void nf_nat_ipv4_csum_recalc(struct sk_buff *skb, - u8 proto, void *data, __sum16 *check, - int datalen, int oldlen) -{ - if (skb->ip_summed != CHECKSUM_PARTIAL) { - const struct iphdr *iph = ip_hdr(skb); - - skb->ip_summed = CHECKSUM_PARTIAL; - skb->csum_start = skb_headroom(skb) + skb_network_offset(skb) + - ip_hdrlen(skb); - skb->csum_offset = (void *)check - data; - *check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen, - proto, 0); - } else - inet_proto_csum_replace2(check, skb, - htons(oldlen), htons(datalen), true); -} - -#if IS_ENABLED(CONFIG_NF_CT_NETLINK) -static int nf_nat_ipv4_nlattr_to_range(struct nlattr *tb[], - struct nf_nat_range2 *range) -{ - if (tb[CTA_NAT_V4_MINIP]) { - range->min_addr.ip = nla_get_be32(tb[CTA_NAT_V4_MINIP]); - range->flags |= NF_NAT_RANGE_MAP_IPS; - } - - if (tb[CTA_NAT_V4_MAXIP]) - range->max_addr.ip = nla_get_be32(tb[CTA_NAT_V4_MAXIP]); - else - range->max_addr.ip = range->min_addr.ip; - - return 0; -} -#endif - -static const struct nf_nat_l3proto nf_nat_l3proto_ipv4 = { - .l3proto = NFPROTO_IPV4, - .manip_pkt = nf_nat_ipv4_manip_pkt, - .csum_update = nf_nat_ipv4_csum_update, - .csum_recalc = nf_nat_ipv4_csum_recalc, -#if IS_ENABLED(CONFIG_NF_CT_NETLINK) - .nlattr_to_range = nf_nat_ipv4_nlattr_to_range, -#endif -#ifdef CONFIG_XFRM - .decode_session = nf_nat_ipv4_decode_session, -#endif -}; - -int nf_nat_icmp_reply_translation(struct sk_buff *skb, - struct nf_conn *ct, - enum ip_conntrack_info ctinfo, - unsigned int hooknum) -{ - struct { - struct icmphdr icmp; - struct iphdr ip; - } *inside; - enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); - enum nf_nat_manip_type manip = HOOK2MANIP(hooknum); - unsigned int hdrlen = ip_hdrlen(skb); - struct nf_conntrack_tuple target; - unsigned long statusbit; - - WARN_ON(ctinfo != IP_CT_RELATED && ctinfo != IP_CT_RELATED_REPLY); - - if (!skb_make_writable(skb, hdrlen + sizeof(*inside))) - return 0; - if (nf_ip_checksum(skb, hooknum, hdrlen, 0)) - return 0; - - inside = (void *)skb->data + hdrlen; - if (inside->icmp.type == ICMP_REDIRECT) { - if ((ct->status & IPS_NAT_DONE_MASK) != IPS_NAT_DONE_MASK) - return 0; - if (ct->status & IPS_NAT_MASK) - return 0; - } - - if (manip == NF_NAT_MANIP_SRC) - statusbit = IPS_SRC_NAT; - else - statusbit = IPS_DST_NAT; - - /* Invert if this is reply direction */ - if (dir == IP_CT_DIR_REPLY) - statusbit ^= IPS_NAT_MASK; - - if (!(ct->status & statusbit)) - return 1; - - if (!nf_nat_ipv4_manip_pkt(skb, hdrlen + sizeof(inside->icmp), - &ct->tuplehash[!dir].tuple, !manip)) - return 0; - - if (skb->ip_summed != CHECKSUM_PARTIAL) { - /* Reloading "inside" here since manip_pkt may reallocate */ - inside = (void *)skb->data + hdrlen; - inside->icmp.checksum = 0; - inside->icmp.checksum = - csum_fold(skb_checksum(skb, hdrlen, - skb->len - hdrlen, 0)); - } - - /* Change outer to look like the reply to an incoming packet */ - nf_ct_invert_tuplepr(&target, &ct->tuplehash[!dir].tuple); - if (!nf_nat_ipv4_manip_pkt(skb, 0, &target, manip)) - return 0; - - return 1; -} -EXPORT_SYMBOL_GPL(nf_nat_icmp_reply_translation); - -static unsigned int -nf_nat_ipv4_fn(void *priv, struct sk_buff *skb, - const struct nf_hook_state *state) -{ - struct nf_conn *ct; - enum ip_conntrack_info ctinfo; - - ct = nf_ct_get(skb, &ctinfo); - if (!ct) - return NF_ACCEPT; - - if (ctinfo == IP_CT_RELATED || ctinfo == IP_CT_RELATED_REPLY) { - if (ip_hdr(skb)->protocol == IPPROTO_ICMP) { - if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo, - state->hook)) - return NF_DROP; - else - return NF_ACCEPT; - } - } - - return nf_nat_inet_fn(priv, skb, state); -} - -static unsigned int -nf_nat_ipv4_in(void *priv, struct sk_buff *skb, - const struct nf_hook_state *state) -{ - unsigned int ret; - __be32 daddr = ip_hdr(skb)->daddr; - - ret = nf_nat_ipv4_fn(priv, skb, state); - if (ret != NF_DROP && ret != NF_STOLEN && - daddr != ip_hdr(skb)->daddr) - skb_dst_drop(skb); - - return ret; -} - -static unsigned int -nf_nat_ipv4_out(void *priv, struct sk_buff *skb, - const struct nf_hook_state *state) -{ -#ifdef CONFIG_XFRM - const struct nf_conn *ct; - enum ip_conntrack_info ctinfo; - int err; -#endif - unsigned int ret; - - ret = nf_nat_ipv4_fn(priv, skb, state); -#ifdef CONFIG_XFRM - if (ret != NF_DROP && ret != NF_STOLEN && - !(IPCB(skb)->flags & IPSKB_XFRM_TRANSFORMED) && - (ct = nf_ct_get(skb, &ctinfo)) != NULL) { - enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); - - if ((ct->tuplehash[dir].tuple.src.u3.ip != - ct->tuplehash[!dir].tuple.dst.u3.ip) || - (ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMP && - ct->tuplehash[dir].tuple.src.u.all != - ct->tuplehash[!dir].tuple.dst.u.all)) { - err = nf_xfrm_me_harder(state->net, skb, AF_INET); - if (err < 0) - ret = NF_DROP_ERR(err); - } - } -#endif - return ret; -} - -static unsigned int -nf_nat_ipv4_local_fn(void *priv, struct sk_buff *skb, - const struct nf_hook_state *state) -{ - const struct nf_conn *ct; - enum ip_conntrack_info ctinfo; - unsigned int ret; - int err; - - ret = nf_nat_ipv4_fn(priv, skb, state); - if (ret != NF_DROP && ret != NF_STOLEN && - (ct = nf_ct_get(skb, &ctinfo)) != NULL) { - enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); - - if (ct->tuplehash[dir].tuple.dst.u3.ip != - ct->tuplehash[!dir].tuple.src.u3.ip) { - err = ip_route_me_harder(state->net, skb, RTN_UNSPEC); - if (err < 0) - ret = NF_DROP_ERR(err); - } -#ifdef CONFIG_XFRM - else if (!(IPCB(skb)->flags & IPSKB_XFRM_TRANSFORMED) && - ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMP && - ct->tuplehash[dir].tuple.dst.u.all != - ct->tuplehash[!dir].tuple.src.u.all) { - err = nf_xfrm_me_harder(state->net, skb, AF_INET); - if (err < 0) - ret = NF_DROP_ERR(err); - } -#endif - } - return ret; -} - -static const struct nf_hook_ops nf_nat_ipv4_ops[] = { - /* Before packet filtering, change destination */ - { - .hook = nf_nat_ipv4_in, - .pf = NFPROTO_IPV4, - .hooknum = NF_INET_PRE_ROUTING, - .priority = NF_IP_PRI_NAT_DST, - }, - /* After packet filtering, change source */ - { - .hook = nf_nat_ipv4_out, - .pf = NFPROTO_IPV4, - .hooknum = NF_INET_POST_ROUTING, - .priority = NF_IP_PRI_NAT_SRC, - }, - /* Before packet filtering, change destination */ - { - .hook = nf_nat_ipv4_local_fn, - .pf = NFPROTO_IPV4, - .hooknum = NF_INET_LOCAL_OUT, - .priority = NF_IP_PRI_NAT_DST, - }, - /* After packet filtering, change source */ - { - .hook = nf_nat_ipv4_fn, - .pf = NFPROTO_IPV4, - .hooknum = NF_INET_LOCAL_IN, - .priority = NF_IP_PRI_NAT_SRC, - }, -}; - -int nf_nat_l3proto_ipv4_register_fn(struct net *net, const struct nf_hook_ops *ops) -{ - return nf_nat_register_fn(net, ops, nf_nat_ipv4_ops, ARRAY_SIZE(nf_nat_ipv4_ops)); -} -EXPORT_SYMBOL_GPL(nf_nat_l3proto_ipv4_register_fn); - -void nf_nat_l3proto_ipv4_unregister_fn(struct net *net, const struct nf_hook_ops *ops) -{ - nf_nat_unregister_fn(net, ops, ARRAY_SIZE(nf_nat_ipv4_ops)); -} -EXPORT_SYMBOL_GPL(nf_nat_l3proto_ipv4_unregister_fn); - -static int __init nf_nat_l3proto_ipv4_init(void) -{ - return nf_nat_l3proto_register(&nf_nat_l3proto_ipv4); -} - -static void __exit nf_nat_l3proto_ipv4_exit(void) -{ - nf_nat_l3proto_unregister(&nf_nat_l3proto_ipv4); -} - -MODULE_LICENSE("GPL"); -MODULE_ALIAS("nf-nat-" __stringify(AF_INET)); - -module_init(nf_nat_l3proto_ipv4_init); -module_exit(nf_nat_l3proto_ipv4_exit); diff --git a/net/ipv4/netfilter/nf_nat_snmp_basic_main.c b/net/ipv4/netfilter/nf_nat_snmp_basic_main.c index a0aa13bcabda..0a8a60c1bf9a 100644 --- a/net/ipv4/netfilter/nf_nat_snmp_basic_main.c +++ b/net/ipv4/netfilter/nf_nat_snmp_basic_main.c @@ -105,6 +105,8 @@ static void fast_csum(struct snmp_ctx *ctx, unsigned char offset) int snmp_version(void *context, size_t hdrlen, unsigned char tag, const void *data, size_t datalen) { + if (datalen != 1) + return -EINVAL; if (*(unsigned char *)data > 1) return -ENOTSUPP; return 1; @@ -114,8 +116,11 @@ int snmp_helper(void *context, size_t hdrlen, unsigned char tag, const void *data, size_t datalen) { struct snmp_ctx *ctx = (struct snmp_ctx *)context; - __be32 *pdata = (__be32 *)data; + __be32 *pdata; + if (datalen != 4) + return -EINVAL; + pdata = (__be32 *)data; if (*pdata == ctx->from) { pr_debug("%s: %pI4 to %pI4\n", __func__, (void *)&ctx->from, (void *)&ctx->to); diff --git a/net/ipv4/netfilter/nf_reject_ipv4.c b/net/ipv4/netfilter/nf_reject_ipv4.c index aa8304c618b8..7dc3c324b911 100644 --- a/net/ipv4/netfilter/nf_reject_ipv4.c +++ b/net/ipv4/netfilter/nf_reject_ipv4.c @@ -173,21 +173,16 @@ EXPORT_SYMBOL_GPL(nf_send_reset); void nf_send_unreach(struct sk_buff *skb_in, int code, int hook) { struct iphdr *iph = ip_hdr(skb_in); - u8 proto; + u8 proto = iph->protocol; if (iph->frag_off & htons(IP_OFFSET)) return; - if (skb_csum_unnecessary(skb_in)) { + if (skb_csum_unnecessary(skb_in) || !nf_reject_verify_csum(proto)) { icmp_send(skb_in, ICMP_DEST_UNREACH, code, 0); return; } - if (iph->protocol == IPPROTO_TCP || iph->protocol == IPPROTO_UDP) - proto = iph->protocol; - else - proto = 0; - if (nf_ip_checksum(skb_in, hook, ip_hdrlen(skb_in), proto) == 0) icmp_send(skb_in, ICMP_DEST_UNREACH, code, 0); } diff --git a/net/ipv4/netfilter/nft_chain_nat_ipv4.c b/net/ipv4/netfilter/nft_chain_nat_ipv4.c deleted file mode 100644 index a3c4ea303e3e..000000000000 --- a/net/ipv4/netfilter/nft_chain_nat_ipv4.c +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2008-2009 Patrick McHardy <kaber@trash.net> - * Copyright (c) 2012 Pablo Neira Ayuso <pablo@netfilter.org> - * Copyright (c) 2012 Intel Corporation - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 2 as - * published by the Free Software Foundation. - * - * Development of this code funded by Astaro AG (http://www.astaro.com/) - */ - -#include <linux/module.h> -#include <linux/init.h> -#include <linux/list.h> -#include <linux/skbuff.h> -#include <linux/ip.h> -#include <linux/netfilter.h> -#include <linux/netfilter_ipv4.h> -#include <linux/netfilter/nf_tables.h> -#include <net/netfilter/nf_conntrack.h> -#include <net/netfilter/nf_nat.h> -#include <net/netfilter/nf_nat_core.h> -#include <net/netfilter/nf_tables.h> -#include <net/netfilter/nf_tables_ipv4.h> -#include <net/netfilter/nf_nat_l3proto.h> -#include <net/ip.h> - -static unsigned int nft_nat_do_chain(void *priv, - struct sk_buff *skb, - const struct nf_hook_state *state) -{ - struct nft_pktinfo pkt; - - nft_set_pktinfo(&pkt, skb, state); - nft_set_pktinfo_ipv4(&pkt, skb); - - return nft_do_chain(&pkt, priv); -} - -static int nft_nat_ipv4_reg(struct net *net, const struct nf_hook_ops *ops) -{ - return nf_nat_l3proto_ipv4_register_fn(net, ops); -} - -static void nft_nat_ipv4_unreg(struct net *net, const struct nf_hook_ops *ops) -{ - nf_nat_l3proto_ipv4_unregister_fn(net, ops); -} - -static const struct nft_chain_type nft_chain_nat_ipv4 = { - .name = "nat", - .type = NFT_CHAIN_T_NAT, - .family = NFPROTO_IPV4, - .owner = THIS_MODULE, - .hook_mask = (1 << NF_INET_PRE_ROUTING) | - (1 << NF_INET_POST_ROUTING) | - (1 << NF_INET_LOCAL_OUT) | - (1 << NF_INET_LOCAL_IN), - .hooks = { - [NF_INET_PRE_ROUTING] = nft_nat_do_chain, - [NF_INET_POST_ROUTING] = nft_nat_do_chain, - [NF_INET_LOCAL_OUT] = nft_nat_do_chain, - [NF_INET_LOCAL_IN] = nft_nat_do_chain, - }, - .ops_register = nft_nat_ipv4_reg, - .ops_unregister = nft_nat_ipv4_unreg, -}; - -static int __init nft_chain_nat_init(void) -{ - nft_register_chain_type(&nft_chain_nat_ipv4); - - return 0; -} - -static void __exit nft_chain_nat_exit(void) -{ - nft_unregister_chain_type(&nft_chain_nat_ipv4); -} - -module_init(nft_chain_nat_init); -module_exit(nft_chain_nat_exit); - -MODULE_LICENSE("GPL"); -MODULE_AUTHOR("Patrick McHardy <kaber@trash.net>"); -MODULE_ALIAS_NFT_CHAIN(AF_INET, "nat"); diff --git a/net/ipv4/netfilter/nft_masq_ipv4.c b/net/ipv4/netfilter/nft_masq_ipv4.c deleted file mode 100644 index 6847de1d1db8..000000000000 --- a/net/ipv4/netfilter/nft_masq_ipv4.c +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (c) 2014 Arturo Borrero Gonzalez <arturo@debian.org> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 2 as - * published by the Free Software Foundation. - */ - -#include <linux/kernel.h> -#include <linux/init.h> -#include <linux/module.h> -#include <linux/netlink.h> -#include <linux/netfilter.h> -#include <linux/netfilter/nf_tables.h> -#include <net/netfilter/nf_tables.h> -#include <net/netfilter/nft_masq.h> -#include <net/netfilter/ipv4/nf_nat_masquerade.h> - -static void nft_masq_ipv4_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) -{ - struct nft_masq *priv = nft_expr_priv(expr); - struct nf_nat_range2 range; - - memset(&range, 0, sizeof(range)); - range.flags = priv->flags; - if (priv->sreg_proto_min) { - range.min_proto.all = (__force __be16)nft_reg_load16( - ®s->data[priv->sreg_proto_min]); - range.max_proto.all = (__force __be16)nft_reg_load16( - ®s->data[priv->sreg_proto_max]); - } - regs->verdict.code = nf_nat_masquerade_ipv4(pkt->skb, nft_hook(pkt), - &range, nft_out(pkt)); -} - -static void -nft_masq_ipv4_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) -{ - nf_ct_netns_put(ctx->net, NFPROTO_IPV4); -} - -static struct nft_expr_type nft_masq_ipv4_type; -static const struct nft_expr_ops nft_masq_ipv4_ops = { - .type = &nft_masq_ipv4_type, - .size = NFT_EXPR_SIZE(sizeof(struct nft_masq)), - .eval = nft_masq_ipv4_eval, - .init = nft_masq_init, - .destroy = nft_masq_ipv4_destroy, - .dump = nft_masq_dump, - .validate = nft_masq_validate, -}; - -static struct nft_expr_type nft_masq_ipv4_type __read_mostly = { - .family = NFPROTO_IPV4, - .name = "masq", - .ops = &nft_masq_ipv4_ops, - .policy = nft_masq_policy, - .maxattr = NFTA_MASQ_MAX, - .owner = THIS_MODULE, -}; - -static int __init nft_masq_ipv4_module_init(void) -{ - int ret; - - ret = nft_register_expr(&nft_masq_ipv4_type); - if (ret < 0) - return ret; - - ret = nf_nat_masquerade_ipv4_register_notifier(); - if (ret) - nft_unregister_expr(&nft_masq_ipv4_type); - - return ret; -} - -static void __exit nft_masq_ipv4_module_exit(void) -{ - nft_unregister_expr(&nft_masq_ipv4_type); - nf_nat_masquerade_ipv4_unregister_notifier(); -} - -module_init(nft_masq_ipv4_module_init); -module_exit(nft_masq_ipv4_module_exit); - -MODULE_LICENSE("GPL"); -MODULE_AUTHOR("Arturo Borrero Gonzalez <arturo@debian.org"); -MODULE_ALIAS_NFT_AF_EXPR(AF_INET, "masq"); diff --git a/net/ipv4/netfilter/nft_redir_ipv4.c b/net/ipv4/netfilter/nft_redir_ipv4.c deleted file mode 100644 index 5120be1d3118..000000000000 --- a/net/ipv4/netfilter/nft_redir_ipv4.c +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright (c) 2014 Arturo Borrero Gonzalez <arturo@debian.org> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 2 as - * published by the Free Software Foundation. - */ - -#include <linux/kernel.h> -#include <linux/init.h> -#include <linux/module.h> -#include <linux/netlink.h> -#include <linux/netfilter.h> -#include <linux/netfilter/nf_tables.h> -#include <net/netfilter/nf_tables.h> -#include <net/netfilter/nf_nat.h> -#include <net/netfilter/nf_nat_redirect.h> -#include <net/netfilter/nft_redir.h> - -static void nft_redir_ipv4_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) -{ - struct nft_redir *priv = nft_expr_priv(expr); - struct nf_nat_ipv4_multi_range_compat mr; - - memset(&mr, 0, sizeof(mr)); - if (priv->sreg_proto_min) { - mr.range[0].min.all = (__force __be16)nft_reg_load16( - ®s->data[priv->sreg_proto_min]); - mr.range[0].max.all = (__force __be16)nft_reg_load16( - ®s->data[priv->sreg_proto_max]); - mr.range[0].flags |= NF_NAT_RANGE_PROTO_SPECIFIED; - } - - mr.range[0].flags |= priv->flags; - - regs->verdict.code = nf_nat_redirect_ipv4(pkt->skb, &mr, nft_hook(pkt)); -} - -static void -nft_redir_ipv4_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) -{ - nf_ct_netns_put(ctx->net, NFPROTO_IPV4); -} - -static struct nft_expr_type nft_redir_ipv4_type; -static const struct nft_expr_ops nft_redir_ipv4_ops = { - .type = &nft_redir_ipv4_type, - .size = NFT_EXPR_SIZE(sizeof(struct nft_redir)), - .eval = nft_redir_ipv4_eval, - .init = nft_redir_init, - .destroy = nft_redir_ipv4_destroy, - .dump = nft_redir_dump, - .validate = nft_redir_validate, -}; - -static struct nft_expr_type nft_redir_ipv4_type __read_mostly = { - .family = NFPROTO_IPV4, - .name = "redir", - .ops = &nft_redir_ipv4_ops, - .policy = nft_redir_policy, - .maxattr = NFTA_REDIR_MAX, - .owner = THIS_MODULE, -}; - -static int __init nft_redir_ipv4_module_init(void) -{ - return nft_register_expr(&nft_redir_ipv4_type); -} - -static void __exit nft_redir_ipv4_module_exit(void) -{ - nft_unregister_expr(&nft_redir_ipv4_type); -} - -module_init(nft_redir_ipv4_module_init); -module_exit(nft_redir_ipv4_module_exit); - -MODULE_LICENSE("GPL"); -MODULE_AUTHOR("Arturo Borrero Gonzalez <arturo@debian.org>"); -MODULE_ALIAS_NFT_AF_EXPR(AF_INET, "redir"); diff --git a/net/ipv4/netlink.c b/net/ipv4/netlink.c index f86bb4f06609..d8e3a1fb8e82 100644 --- a/net/ipv4/netlink.c +++ b/net/ipv4/netlink.c @@ -3,9 +3,10 @@ #include <linux/types.h> #include <net/net_namespace.h> #include <net/netlink.h> +#include <linux/in6.h> #include <net/ip.h> -int rtm_getroute_parse_ip_proto(struct nlattr *attr, u8 *ip_proto, +int rtm_getroute_parse_ip_proto(struct nlattr *attr, u8 *ip_proto, u8 family, struct netlink_ext_ack *extack) { *ip_proto = nla_get_u8(attr); @@ -13,11 +14,19 @@ int rtm_getroute_parse_ip_proto(struct nlattr *attr, u8 *ip_proto, switch (*ip_proto) { case IPPROTO_TCP: case IPPROTO_UDP: + return 0; case IPPROTO_ICMP: + if (family != AF_INET) + break; + return 0; +#if IS_ENABLED(CONFIG_IPV6) + case IPPROTO_ICMPV6: + if (family != AF_INET6) + break; return 0; - default: - NL_SET_ERR_MSG(extack, "Unsupported ip proto"); - return -EOPNOTSUPP; +#endif } + NL_SET_ERR_MSG(extack, "Unsupported ip proto"); + return -EOPNOTSUPP; } EXPORT_SYMBOL_GPL(rtm_getroute_parse_ip_proto); diff --git a/net/ipv4/route.c b/net/ipv4/route.c index ce92f73cf104..a5da63e5faa2 100644 --- a/net/ipv4/route.c +++ b/net/ipv4/route.c @@ -887,13 +887,15 @@ void ip_rt_send_redirect(struct sk_buff *skb) /* No redirected packets during ip_rt_redirect_silence; * reset the algorithm. */ - if (time_after(jiffies, peer->rate_last + ip_rt_redirect_silence)) + if (time_after(jiffies, peer->rate_last + ip_rt_redirect_silence)) { peer->rate_tokens = 0; + peer->n_redirects = 0; + } /* Too many ignored redirects; do not send anything * set dst.rate_last to the last seen redirected packet. */ - if (peer->rate_tokens >= ip_rt_redirect_number) { + if (peer->n_redirects >= ip_rt_redirect_number) { peer->rate_last = jiffies; goto out_put_peer; } @@ -910,6 +912,7 @@ void ip_rt_send_redirect(struct sk_buff *skb) icmp_send(skb, ICMP_REDIRECT, ICMP_REDIR_HOST, gw); peer->rate_last = jiffies; ++peer->rate_tokens; + ++peer->n_redirects; #ifdef CONFIG_IP_ROUTE_VERBOSE if (log_martians && peer->rate_tokens == ip_rt_redirect_number) @@ -1300,6 +1303,10 @@ static void ip_del_fnhe(struct fib_nh *nh, __be32 daddr) if (fnhe->fnhe_daddr == daddr) { rcu_assign_pointer(*fnhe_p, rcu_dereference_protected( fnhe->fnhe_next, lockdep_is_held(&fnhe_lock))); + /* set fnhe_daddr to 0 to ensure it won't bind with + * new dsts in rt_bind_exception(). + */ + fnhe->fnhe_daddr = 0; fnhe_flush_routes(fnhe); kfree_rcu(fnhe, rcu); break; @@ -1608,7 +1615,8 @@ int ip_mc_validate_source(struct sk_buff *skb, __be32 daddr, __be32 saddr, return -EINVAL; if (ipv4_is_zeronet(saddr)) { - if (!ipv4_is_local_multicast(daddr)) + if (!ipv4_is_local_multicast(daddr) && + ip_hdr(skb)->protocol != IPPROTO_IGMP) return -EINVAL; } else { err = fib_validate_source(skb, saddr, 0, tos, 0, dev, @@ -1816,6 +1824,7 @@ out: int fib_multipath_hash(const struct net *net, const struct flowi4 *fl4, const struct sk_buff *skb, struct flow_keys *flkeys) { + u32 multipath_hash = fl4 ? fl4->flowi4_multipath_hash : 0; struct flow_keys hash_keys; u32 mhash; @@ -1866,6 +1875,9 @@ int fib_multipath_hash(const struct net *net, const struct flowi4 *fl4, } mhash = flow_hash_from_keys(&hash_keys); + if (multipath_hash) + mhash = jhash_2words(mhash, multipath_hash, 0); + return mhash >> 1; } #endif /* CONFIG_IP_ROUTE_MULTIPATH */ @@ -2141,12 +2153,13 @@ int ip_route_input_rcu(struct sk_buff *skb, __be32 daddr, __be32 saddr, int our = 0; int err = -EINVAL; - if (in_dev) - our = ip_check_mc_rcu(in_dev, daddr, saddr, - ip_hdr(skb)->protocol); + if (!in_dev) + return err; + our = ip_check_mc_rcu(in_dev, daddr, saddr, + ip_hdr(skb)->protocol); /* check l3 master if no match yet */ - if ((!in_dev || !our) && netif_is_l3_slave(dev)) { + if (!our && netif_is_l3_slave(dev)) { struct in_device *l3_in_dev; l3_in_dev = __in_dev_get_rcu(skb->dev); @@ -2763,6 +2776,75 @@ static struct sk_buff *inet_rtm_getroute_build_skb(__be32 src, __be32 dst, return skb; } +static int inet_rtm_valid_getroute_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct rtmsg *rtm; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + NL_SET_ERR_MSG(extack, + "ipv4: Invalid header for route get request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv4_policy, extack); + + rtm = nlmsg_data(nlh); + if ((rtm->rtm_src_len && rtm->rtm_src_len != 32) || + (rtm->rtm_dst_len && rtm->rtm_dst_len != 32) || + rtm->rtm_table || rtm->rtm_protocol || + rtm->rtm_scope || rtm->rtm_type) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid values in header for route get request"); + return -EINVAL; + } + + if (rtm->rtm_flags & ~(RTM_F_NOTIFY | + RTM_F_LOOKUP_TABLE | + RTM_F_FIB_MATCH)) { + NL_SET_ERR_MSG(extack, "ipv4: Unsupported rtm_flags for route get request"); + return -EINVAL; + } + + err = nlmsg_parse_strict(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv4_policy, extack); + if (err) + return err; + + if ((tb[RTA_SRC] && !rtm->rtm_src_len) || + (tb[RTA_DST] && !rtm->rtm_dst_len)) { + NL_SET_ERR_MSG(extack, "ipv4: rtm_src_len and rtm_dst_len must be 32 for IPv4"); + return -EINVAL; + } + + for (i = 0; i <= RTA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case RTA_IIF: + case RTA_OIF: + case RTA_SRC: + case RTA_DST: + case RTA_IP_PROTO: + case RTA_SPORT: + case RTA_DPORT: + case RTA_MARK: + case RTA_UID: + break; + default: + NL_SET_ERR_MSG(extack, "ipv4: Unsupported attribute in route get request"); + return -EINVAL; + } + } + + return 0; +} + static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { @@ -2783,8 +2865,7 @@ static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err; int mark; - err = nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, rtm_ipv4_policy, - extack); + err = inet_rtm_valid_getroute_req(in_skb, nlh, tb, extack); if (err < 0) return err; @@ -2800,7 +2881,7 @@ static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, if (tb[RTA_IP_PROTO]) { err = rtm_getroute_parse_ip_proto(tb[RTA_IP_PROTO], - &ip_proto, extack); + &ip_proto, AF_INET, extack); if (err) return err; } diff --git a/net/ipv4/syncookies.c b/net/ipv4/syncookies.c index 606f868d9f3f..e531344611a0 100644 --- a/net/ipv4/syncookies.c +++ b/net/ipv4/syncookies.c @@ -216,7 +216,12 @@ struct sock *tcp_get_cookie_sock(struct sock *sk, struct sk_buff *skb, refcount_set(&req->rsk_refcnt, 1); tcp_sk(child)->tsoffset = tsoff; sock_rps_save_rxhash(child, skb); - inet_csk_reqsk_queue_add(sk, req, child); + if (!inet_csk_reqsk_queue_add(sk, req, child)) { + bh_unlock_sock(child); + sock_put(child); + child = NULL; + reqsk_put(req); + } } else { reqsk_free(req); } diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 2079145a3b7c..6baa6dc1b13b 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -943,6 +943,10 @@ ssize_t do_tcp_sendpages(struct sock *sk, struct page *page, int offset, ssize_t copied; long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); + if (IS_ENABLED(CONFIG_DEBUG_VM) && + WARN_ONCE(PageSlab(page), "page must not be a Slab one")) + return -EINVAL; + /* Wait for a connection to finish. One exception is TCP Fast Open * (passive side) where data is allowed to be sent before a connection * is fully established. @@ -1127,7 +1131,8 @@ void tcp_free_fastopen_req(struct tcp_sock *tp) } static int tcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, - int *copied, size_t size) + int *copied, size_t size, + struct ubuf_info *uarg) { struct tcp_sock *tp = tcp_sk(sk); struct inet_sock *inet = inet_sk(sk); @@ -1147,6 +1152,7 @@ static int tcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, return -ENOBUFS; tp->fastopen_req->data = msg; tp->fastopen_req->size = size; + tp->fastopen_req->uarg = uarg; if (inet->defer_connect) { err = tcp_connect(sk); @@ -1186,11 +1192,6 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) flags = msg->msg_flags; if (flags & MSG_ZEROCOPY && size && sock_flag(sk, SOCK_ZEROCOPY)) { - if ((1 << sk->sk_state) & ~(TCPF_ESTABLISHED | TCPF_CLOSE_WAIT)) { - err = -EINVAL; - goto out_err; - } - skb = tcp_write_queue_tail(sk); uarg = sock_zerocopy_realloc(sk, size, skb_zcopy(skb)); if (!uarg) { @@ -1205,7 +1206,7 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) if (unlikely(flags & MSG_FASTOPEN || inet_sk(sk)->defer_connect) && !tp->repair) { - err = tcp_sendmsg_fastopen(sk, msg, &copied_syn, size); + err = tcp_sendmsg_fastopen(sk, msg, &copied_syn, size, uarg); if (err == -EINPROGRESS && copied_syn > 0) goto out; else if (err) @@ -1415,7 +1416,8 @@ do_fault: /* It is the one place in all of TCP, except connection * reset, where we can be unlinking the send_head. */ - tcp_check_send_head(sk, skb); + if (tcp_write_queue_empty(sk)) + tcp_chrono_stop(sk, TCP_CHRONO_BUSY); sk_wmem_free_skb(sk, skb); } @@ -1554,7 +1556,7 @@ static void tcp_cleanup_rbuf(struct sock *sk, int copied) (copied > 0 && ((icsk->icsk_ack.pending & ICSK_ACK_PUSHED2) || ((icsk->icsk_ack.pending & ICSK_ACK_PUSHED) && - !icsk->icsk_ack.pingpong)) && + !inet_csk_in_pingpong_mode(sk))) && !atomic_read(&sk->sk_rmem_alloc))) time_to_ack = true; } @@ -1847,57 +1849,78 @@ out: #endif static void tcp_update_recv_tstamps(struct sk_buff *skb, - struct scm_timestamping *tss) + struct scm_timestamping_internal *tss) { if (skb->tstamp) - tss->ts[0] = ktime_to_timespec(skb->tstamp); + tss->ts[0] = ktime_to_timespec64(skb->tstamp); else - tss->ts[0] = (struct timespec) {0}; + tss->ts[0] = (struct timespec64) {0}; if (skb_hwtstamps(skb)->hwtstamp) - tss->ts[2] = ktime_to_timespec(skb_hwtstamps(skb)->hwtstamp); + tss->ts[2] = ktime_to_timespec64(skb_hwtstamps(skb)->hwtstamp); else - tss->ts[2] = (struct timespec) {0}; + tss->ts[2] = (struct timespec64) {0}; } /* Similar to __sock_recv_timestamp, but does not require an skb */ static void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk, - struct scm_timestamping *tss) + struct scm_timestamping_internal *tss) { - struct timeval tv; + int new_tstamp = sock_flag(sk, SOCK_TSTAMP_NEW); bool has_timestamping = false; if (tss->ts[0].tv_sec || tss->ts[0].tv_nsec) { if (sock_flag(sk, SOCK_RCVTSTAMP)) { if (sock_flag(sk, SOCK_RCVTSTAMPNS)) { - put_cmsg(msg, SOL_SOCKET, SCM_TIMESTAMPNS, - sizeof(tss->ts[0]), &tss->ts[0]); - } else { - tv.tv_sec = tss->ts[0].tv_sec; - tv.tv_usec = tss->ts[0].tv_nsec / 1000; + if (new_tstamp) { + struct __kernel_timespec kts = {tss->ts[0].tv_sec, tss->ts[0].tv_nsec}; + + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMPNS_NEW, + sizeof(kts), &kts); + } else { + struct timespec ts_old = timespec64_to_timespec(tss->ts[0]); - put_cmsg(msg, SOL_SOCKET, SCM_TIMESTAMP, - sizeof(tv), &tv); + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMPNS_OLD, + sizeof(ts_old), &ts_old); + } + } else { + if (new_tstamp) { + struct __kernel_sock_timeval stv; + + stv.tv_sec = tss->ts[0].tv_sec; + stv.tv_usec = tss->ts[0].tv_nsec / 1000; + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMP_NEW, + sizeof(stv), &stv); + } else { + struct __kernel_old_timeval tv; + + tv.tv_sec = tss->ts[0].tv_sec; + tv.tv_usec = tss->ts[0].tv_nsec / 1000; + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMP_OLD, + sizeof(tv), &tv); + } } } if (sk->sk_tsflags & SOF_TIMESTAMPING_SOFTWARE) has_timestamping = true; else - tss->ts[0] = (struct timespec) {0}; + tss->ts[0] = (struct timespec64) {0}; } if (tss->ts[2].tv_sec || tss->ts[2].tv_nsec) { if (sk->sk_tsflags & SOF_TIMESTAMPING_RAW_HARDWARE) has_timestamping = true; else - tss->ts[2] = (struct timespec) {0}; + tss->ts[2] = (struct timespec64) {0}; } if (has_timestamping) { - tss->ts[1] = (struct timespec) {0}; - put_cmsg(msg, SOL_SOCKET, SCM_TIMESTAMPING, - sizeof(*tss), tss); + tss->ts[1] = (struct timespec64) {0}; + if (sock_flag(sk, SOCK_TSTAMP_NEW)) + put_cmsg_scm_timestamping64(msg, tss); + else + put_cmsg_scm_timestamping(msg, tss); } } @@ -1914,6 +1937,11 @@ static int tcp_inq_hint(struct sock *sk) inq = tp->rcv_nxt - tp->copied_seq; release_sock(sk); } + /* After receiving a FIN, tell the user-space to continue reading + * by returning a non-zero inq. + */ + if (inq == 0 && sock_flag(sk, SOCK_DONE)) + inq = 1; return inq; } @@ -1938,7 +1966,7 @@ int tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock, long timeo; struct sk_buff *skb, *last; u32 urg_hole = 0; - struct scm_timestamping tss; + struct scm_timestamping_internal tss; bool has_tss = false; bool has_cmsg; @@ -2528,6 +2556,7 @@ void tcp_write_queue_purge(struct sock *sk) sk_mem_reclaim(sk); tcp_clear_all_retrans_hints(tcp_sk(sk)); tcp_sk(sk)->packets_out = 0; + inet_csk(sk)->icsk_backoff = 0; } int tcp_disconnect(struct sock *sk, int flags) @@ -2572,6 +2601,7 @@ int tcp_disconnect(struct sock *sk, int flags) sk->sk_shutdown = 0; sock_reset_flag(sk, SOCK_DONE); tp->srtt_us = 0; + tp->mdev_us = jiffies_to_usecs(TCP_TIMEOUT_INIT); tp->rcv_rtt_last_tsecr = 0; tp->write_seq += tp->max_window + 2; if (tp->write_seq == 0) @@ -2579,7 +2609,9 @@ int tcp_disconnect(struct sock *sk, int flags) icsk->icsk_backoff = 0; tp->snd_cwnd = 2; icsk->icsk_probes_out = 0; + icsk->icsk_rto = TCP_TIMEOUT_INIT; tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; + tp->snd_cwnd = TCP_INIT_CWND; tp->snd_cwnd_cnt = 0; tp->window_clamp = 0; tp->delivered_ce = 0; @@ -2603,6 +2635,23 @@ int tcp_disconnect(struct sock *sk, int flags) tp->duplicate_sack[0].end_seq = 0; tp->dsack_dups = 0; tp->reord_seen = 0; + tp->retrans_out = 0; + tp->sacked_out = 0; + tp->tlp_high_seq = 0; + tp->last_oow_ack_time = 0; + /* There's a bubble in the pipe until at least the first ACK. */ + tp->app_limited = ~0U; + tp->rack.mstamp = 0; + tp->rack.advanced = 0; + tp->rack.reo_wnd_steps = 1; + tp->rack.last_delivered = 0; + tp->rack.reo_wnd_persist = 0; + tp->rack.dsack_seen = 0; + tp->syn_data_acked = 0; + tp->rx_opt.saw_tstamp = 0; + tp->rx_opt.dsack = 0; + tp->rx_opt.num_sacks = 0; + /* Clean up fastopen related fields */ tcp_free_fastopen_req(tp); @@ -2968,16 +3017,16 @@ static int do_tcp_setsockopt(struct sock *sk, int level, case TCP_QUICKACK: if (!val) { - icsk->icsk_ack.pingpong = 1; + inet_csk_enter_pingpong_mode(sk); } else { - icsk->icsk_ack.pingpong = 0; + inet_csk_exit_pingpong_mode(sk); if ((1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_CLOSE_WAIT) && inet_csk_ack_scheduled(sk)) { icsk->icsk_ack.pending |= ICSK_ACK_PUSHED; tcp_cleanup_rbuf(sk, 1); if (!(val & 1)) - icsk->icsk_ack.pingpong = 1; + inet_csk_enter_pingpong_mode(sk); } } break; @@ -3391,7 +3440,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level, return 0; } case TCP_QUICKACK: - val = !icsk->icsk_ack.pingpong; + val = !inet_csk_in_pingpong_mode(sk); break; case TCP_CONGESTION: @@ -3659,7 +3708,7 @@ bool tcp_alloc_md5sig_pool(void) if (!tcp_md5sig_pool_populated) { __tcp_alloc_md5sig_pool(); if (tcp_md5sig_pool_populated) - static_key_slow_inc(&tcp_md5_needed); + static_branch_inc(&tcp_md5_needed); } mutex_unlock(&tcp_md5sig_mutex); diff --git a/net/ipv4/tcp_bbr.c b/net/ipv4/tcp_bbr.c index 0f497fc49c3f..56be7d27f208 100644 --- a/net/ipv4/tcp_bbr.c +++ b/net/ipv4/tcp_bbr.c @@ -115,6 +115,14 @@ struct bbr { unused_b:5; u32 prior_cwnd; /* prior cwnd upon entering loss recovery */ u32 full_bw; /* recent bw, to estimate if pipe is full */ + + /* For tracking ACK aggregation: */ + u64 ack_epoch_mstamp; /* start of ACK sampling epoch */ + u16 extra_acked[2]; /* max excess data ACKed in epoch */ + u32 ack_epoch_acked:20, /* packets (S)ACKed in sampling epoch */ + extra_acked_win_rtts:5, /* age of extra_acked, in round trips */ + extra_acked_win_idx:1, /* current index in extra_acked array */ + unused_c:6; }; #define CYCLE_LEN 8 /* number of phases in a pacing gain cycle */ @@ -182,6 +190,15 @@ static const u32 bbr_lt_bw_diff = 4000 / 8; /* If we estimate we're policed, use lt_bw for this many round trips: */ static const u32 bbr_lt_bw_max_rtts = 48; +/* Gain factor for adding extra_acked to target cwnd: */ +static const int bbr_extra_acked_gain = BBR_UNIT; +/* Window length of extra_acked window. */ +static const u32 bbr_extra_acked_win_rtts = 5; +/* Max allowed val for ack_epoch_acked, after which sampling epoch is reset */ +static const u32 bbr_ack_epoch_acked_reset_thresh = 1U << 20; +/* Time period for clamping cwnd increment due to ack aggregation */ +static const u32 bbr_extra_acked_max_us = 100 * 1000; + static void bbr_check_probe_rtt_done(struct sock *sk); /* Do we estimate that STARTUP filled the pipe? */ @@ -208,6 +225,16 @@ static u32 bbr_bw(const struct sock *sk) return bbr->lt_use_bw ? bbr->lt_bw : bbr_max_bw(sk); } +/* Return maximum extra acked in past k-2k round trips, + * where k = bbr_extra_acked_win_rtts. + */ +static u16 bbr_extra_acked(const struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + return max(bbr->extra_acked[0], bbr->extra_acked[1]); +} + /* Return rate in bytes per second, optionally with a gain. * The order here is chosen carefully to avoid overflow of u64. This should * work for input rates of up to 2.9Tbit/sec and gain of 2.89x. @@ -305,6 +332,8 @@ static void bbr_cwnd_event(struct sock *sk, enum tcp_ca_event event) if (event == CA_EVENT_TX_START && tp->app_limited) { bbr->idle_restart = 1; + bbr->ack_epoch_mstamp = tp->tcp_mstamp; + bbr->ack_epoch_acked = 0; /* Avoid pointless buffer overflows: pace at est. bw if we don't * need more speed (we're restarting from idle and app-limited). */ @@ -315,30 +344,19 @@ static void bbr_cwnd_event(struct sock *sk, enum tcp_ca_event event) } } -/* Find target cwnd. Right-size the cwnd based on min RTT and the - * estimated bottleneck bandwidth: +/* Calculate bdp based on min RTT and the estimated bottleneck bandwidth: * - * cwnd = bw * min_rtt * gain = BDP * gain + * bdp = bw * min_rtt * gain * * The key factor, gain, controls the amount of queue. While a small gain * builds a smaller queue, it becomes more vulnerable to noise in RTT * measurements (e.g., delayed ACKs or other ACK compression effects). This * noise may cause BBR to under-estimate the rate. - * - * To achieve full performance in high-speed paths, we budget enough cwnd to - * fit full-sized skbs in-flight on both end hosts to fully utilize the path: - * - one skb in sending host Qdisc, - * - one skb in sending host TSO/GSO engine - * - one skb being received by receiver host LRO/GRO/delayed-ACK engine - * Don't worry, at low rates (bbr_min_tso_rate) this won't bloat cwnd because - * in such cases tso_segs_goal is 1. The minimum cwnd is 4 packets, - * which allows 2 outstanding 2-packet sequences, to try to keep pipe - * full even with ACK-every-other-packet delayed ACKs. */ -static u32 bbr_target_cwnd(struct sock *sk, u32 bw, int gain) +static u32 bbr_bdp(struct sock *sk, u32 bw, int gain) { struct bbr *bbr = inet_csk_ca(sk); - u32 cwnd; + u32 bdp; u64 w; /* If we've never had a valid RTT sample, cap cwnd at the initial @@ -353,7 +371,24 @@ static u32 bbr_target_cwnd(struct sock *sk, u32 bw, int gain) w = (u64)bw * bbr->min_rtt_us; /* Apply a gain to the given value, then remove the BW_SCALE shift. */ - cwnd = (((w * gain) >> BBR_SCALE) + BW_UNIT - 1) / BW_UNIT; + bdp = (((w * gain) >> BBR_SCALE) + BW_UNIT - 1) / BW_UNIT; + + return bdp; +} + +/* To achieve full performance in high-speed paths, we budget enough cwnd to + * fit full-sized skbs in-flight on both end hosts to fully utilize the path: + * - one skb in sending host Qdisc, + * - one skb in sending host TSO/GSO engine + * - one skb being received by receiver host LRO/GRO/delayed-ACK engine + * Don't worry, at low rates (bbr_min_tso_rate) this won't bloat cwnd because + * in such cases tso_segs_goal is 1. The minimum cwnd is 4 packets, + * which allows 2 outstanding 2-packet sequences, to try to keep pipe + * full even with ACK-every-other-packet delayed ACKs. + */ +static u32 bbr_quantization_budget(struct sock *sk, u32 cwnd, int gain) +{ + struct bbr *bbr = inet_csk_ca(sk); /* Allow enough full-sized skbs in flight to utilize end systems. */ cwnd += 3 * bbr_tso_segs_goal(sk); @@ -368,6 +403,17 @@ static u32 bbr_target_cwnd(struct sock *sk, u32 bw, int gain) return cwnd; } +/* Find inflight based on min RTT and the estimated bottleneck bandwidth. */ +static u32 bbr_inflight(struct sock *sk, u32 bw, int gain) +{ + u32 inflight; + + inflight = bbr_bdp(sk, bw, gain); + inflight = bbr_quantization_budget(sk, inflight, gain); + + return inflight; +} + /* With pacing at lower layers, there's often less data "in the network" than * "in flight". With TSQ and departure time pacing at lower layers (e.g. fq), * we often have several skbs queued in the pacing layer with a pre-scheduled @@ -401,6 +447,22 @@ static u32 bbr_packets_in_net_at_edt(struct sock *sk, u32 inflight_now) return inflight_at_edt - interval_delivered; } +/* Find the cwnd increment based on estimate of ack aggregation */ +static u32 bbr_ack_aggregation_cwnd(struct sock *sk) +{ + u32 max_aggr_cwnd, aggr_cwnd = 0; + + if (bbr_extra_acked_gain && bbr_full_bw_reached(sk)) { + max_aggr_cwnd = ((u64)bbr_bw(sk) * bbr_extra_acked_max_us) + / BW_UNIT; + aggr_cwnd = (bbr_extra_acked_gain * bbr_extra_acked(sk)) + >> BBR_SCALE; + aggr_cwnd = min(aggr_cwnd, max_aggr_cwnd); + } + + return aggr_cwnd; +} + /* An optimization in BBR to reduce losses: On the first round of recovery, we * follow the packet conservation principle: send P packets per P packets acked. * After that, we slow-start and send at most 2*P packets per P packets acked. @@ -461,8 +523,15 @@ static void bbr_set_cwnd(struct sock *sk, const struct rate_sample *rs, if (bbr_set_cwnd_to_recover_or_restore(sk, rs, acked, &cwnd)) goto done; + target_cwnd = bbr_bdp(sk, bw, gain); + + /* Increment the cwnd to account for excess ACKed data that seems + * due to aggregation (of data and/or ACKs) visible in the ACK stream. + */ + target_cwnd += bbr_ack_aggregation_cwnd(sk); + target_cwnd = bbr_quantization_budget(sk, target_cwnd, gain); + /* If we're below target cwnd, slow start cwnd toward target cwnd. */ - target_cwnd = bbr_target_cwnd(sk, bw, gain); if (bbr_full_bw_reached(sk)) /* only cut cwnd if we filled the pipe */ cwnd = min(cwnd + acked, target_cwnd); else if (cwnd < target_cwnd || tp->delivered < TCP_INIT_CWND) @@ -503,14 +572,14 @@ static bool bbr_is_next_cycle_phase(struct sock *sk, if (bbr->pacing_gain > BBR_UNIT) return is_full_length && (rs->losses || /* perhaps pacing_gain*BDP won't fit */ - inflight >= bbr_target_cwnd(sk, bw, bbr->pacing_gain)); + inflight >= bbr_inflight(sk, bw, bbr->pacing_gain)); /* A pacing_gain < 1.0 tries to drain extra queue we added if bw * probing didn't find more bw. If inflight falls to match BDP then we * estimate queue is drained; persisting would underutilize the pipe. */ return is_full_length || - inflight <= bbr_target_cwnd(sk, bw, BBR_UNIT); + inflight <= bbr_inflight(sk, bw, BBR_UNIT); } static void bbr_advance_cycle_phase(struct sock *sk) @@ -727,6 +796,67 @@ static void bbr_update_bw(struct sock *sk, const struct rate_sample *rs) } } +/* Estimates the windowed max degree of ack aggregation. + * This is used to provision extra in-flight data to keep sending during + * inter-ACK silences. + * + * Degree of ack aggregation is estimated as extra data acked beyond expected. + * + * max_extra_acked = "maximum recent excess data ACKed beyond max_bw * interval" + * cwnd += max_extra_acked + * + * Max extra_acked is clamped by cwnd and bw * bbr_extra_acked_max_us (100 ms). + * Max filter is an approximate sliding window of 5-10 (packet timed) round + * trips. + */ +static void bbr_update_ack_aggregation(struct sock *sk, + const struct rate_sample *rs) +{ + u32 epoch_us, expected_acked, extra_acked; + struct bbr *bbr = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); + + if (!bbr_extra_acked_gain || rs->acked_sacked <= 0 || + rs->delivered < 0 || rs->interval_us <= 0) + return; + + if (bbr->round_start) { + bbr->extra_acked_win_rtts = min(0x1F, + bbr->extra_acked_win_rtts + 1); + if (bbr->extra_acked_win_rtts >= bbr_extra_acked_win_rtts) { + bbr->extra_acked_win_rtts = 0; + bbr->extra_acked_win_idx = bbr->extra_acked_win_idx ? + 0 : 1; + bbr->extra_acked[bbr->extra_acked_win_idx] = 0; + } + } + + /* Compute how many packets we expected to be delivered over epoch. */ + epoch_us = tcp_stamp_us_delta(tp->delivered_mstamp, + bbr->ack_epoch_mstamp); + expected_acked = ((u64)bbr_bw(sk) * epoch_us) / BW_UNIT; + + /* Reset the aggregation epoch if ACK rate is below expected rate or + * significantly large no. of ack received since epoch (potentially + * quite old epoch). + */ + if (bbr->ack_epoch_acked <= expected_acked || + (bbr->ack_epoch_acked + rs->acked_sacked >= + bbr_ack_epoch_acked_reset_thresh)) { + bbr->ack_epoch_acked = 0; + bbr->ack_epoch_mstamp = tp->delivered_mstamp; + expected_acked = 0; + } + + /* Compute excess data delivered, beyond what was expected. */ + bbr->ack_epoch_acked = min_t(u32, 0xFFFFF, + bbr->ack_epoch_acked + rs->acked_sacked); + extra_acked = bbr->ack_epoch_acked - expected_acked; + extra_acked = min(extra_acked, tp->snd_cwnd); + if (extra_acked > bbr->extra_acked[bbr->extra_acked_win_idx]) + bbr->extra_acked[bbr->extra_acked_win_idx] = extra_acked; +} + /* Estimate when the pipe is full, using the change in delivery rate: BBR * estimates that STARTUP filled the pipe if the estimated bw hasn't changed by * at least bbr_full_bw_thresh (25%) after bbr_full_bw_cnt (3) non-app-limited @@ -762,11 +892,11 @@ static void bbr_check_drain(struct sock *sk, const struct rate_sample *rs) if (bbr->mode == BBR_STARTUP && bbr_full_bw_reached(sk)) { bbr->mode = BBR_DRAIN; /* drain queue we created */ tcp_sk(sk)->snd_ssthresh = - bbr_target_cwnd(sk, bbr_max_bw(sk), BBR_UNIT); + bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT); } /* fall through to check if in-flight is already small: */ if (bbr->mode == BBR_DRAIN && bbr_packets_in_net_at_edt(sk, tcp_packets_in_flight(tcp_sk(sk))) <= - bbr_target_cwnd(sk, bbr_max_bw(sk), BBR_UNIT)) + bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT)) bbr_reset_probe_bw_mode(sk); /* we estimate queue is drained */ } @@ -881,6 +1011,7 @@ static void bbr_update_gains(struct sock *sk) static void bbr_update_model(struct sock *sk, const struct rate_sample *rs) { bbr_update_bw(sk, rs); + bbr_update_ack_aggregation(sk, rs); bbr_update_cycle_phase(sk, rs); bbr_check_full_bw_reached(sk, rs); bbr_check_drain(sk, rs); @@ -932,6 +1063,13 @@ static void bbr_init(struct sock *sk) bbr_reset_lt_bw_sampling(sk); bbr_reset_startup_mode(sk); + bbr->ack_epoch_mstamp = tp->tcp_mstamp; + bbr->ack_epoch_acked = 0; + bbr->extra_acked_win_rtts = 0; + bbr->extra_acked_win_idx = 0; + bbr->extra_acked[0] = 0; + bbr->extra_acked[1] = 0; + cmpxchg(&sk->sk_pacing_status, SK_PACING_NONE, SK_PACING_NEEDED); } diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c index 76858b14ebe9..5def3c48870e 100644 --- a/net/ipv4/tcp_input.c +++ b/net/ipv4/tcp_input.c @@ -221,7 +221,7 @@ void tcp_enter_quickack_mode(struct sock *sk, unsigned int max_quickacks) struct inet_connection_sock *icsk = inet_csk(sk); tcp_incr_quickack(sk, max_quickacks); - icsk->icsk_ack.pingpong = 0; + inet_csk_exit_pingpong_mode(sk); icsk->icsk_ack.ato = TCP_ATO_MIN; } EXPORT_SYMBOL(tcp_enter_quickack_mode); @@ -236,7 +236,7 @@ static bool tcp_in_quickack_mode(struct sock *sk) const struct dst_entry *dst = __sk_dst_get(sk); return (dst && dst_metric(dst, RTAX_QUICKACK)) || - (icsk->icsk_ack.quick && !icsk->icsk_ack.pingpong); + (icsk->icsk_ack.quick && !inet_csk_in_pingpong_mode(sk)); } static void tcp_ecn_queue_cwr(struct tcp_sock *tp) @@ -1574,9 +1574,7 @@ static struct sk_buff *tcp_sacktag_walk(struct sk_buff *skb, struct sock *sk, return skb; } -static struct sk_buff *tcp_sacktag_bsearch(struct sock *sk, - struct tcp_sacktag_state *state, - u32 seq) +static struct sk_buff *tcp_sacktag_bsearch(struct sock *sk, u32 seq) { struct rb_node *parent, **p = &sk->tcp_rtx_queue.rb_node; struct sk_buff *skb; @@ -1598,13 +1596,12 @@ static struct sk_buff *tcp_sacktag_bsearch(struct sock *sk, } static struct sk_buff *tcp_sacktag_skip(struct sk_buff *skb, struct sock *sk, - struct tcp_sacktag_state *state, u32 skip_to_seq) { if (skb && after(TCP_SKB_CB(skb)->seq, skip_to_seq)) return skb; - return tcp_sacktag_bsearch(sk, state, skip_to_seq); + return tcp_sacktag_bsearch(sk, skip_to_seq); } static struct sk_buff *tcp_maybe_skipping_dsack(struct sk_buff *skb, @@ -1617,7 +1614,7 @@ static struct sk_buff *tcp_maybe_skipping_dsack(struct sk_buff *skb, return skb; if (before(next_dup->start_seq, skip_to_seq)) { - skb = tcp_sacktag_skip(skb, sk, state, next_dup->start_seq); + skb = tcp_sacktag_skip(skb, sk, next_dup->start_seq); skb = tcp_sacktag_walk(skb, sk, NULL, state, next_dup->start_seq, next_dup->end_seq, 1); @@ -1758,8 +1755,7 @@ tcp_sacktag_write_queue(struct sock *sk, const struct sk_buff *ack_skb, /* Head todo? */ if (before(start_seq, cache->start_seq)) { - skb = tcp_sacktag_skip(skb, sk, state, - start_seq); + skb = tcp_sacktag_skip(skb, sk, start_seq); skb = tcp_sacktag_walk(skb, sk, next_dup, state, start_seq, @@ -1785,7 +1781,7 @@ tcp_sacktag_write_queue(struct sock *sk, const struct sk_buff *ack_skb, goto walk; } - skb = tcp_sacktag_skip(skb, sk, state, cache->end_seq); + skb = tcp_sacktag_skip(skb, sk, cache->end_seq); /* Check overlap against next cached too (past this one already) */ cache++; continue; @@ -1796,7 +1792,7 @@ tcp_sacktag_write_queue(struct sock *sk, const struct sk_buff *ack_skb, if (!skb) break; } - skb = tcp_sacktag_skip(skb, sk, state, start_seq); + skb = tcp_sacktag_skip(skb, sk, start_seq); walk: skb = tcp_sacktag_walk(skb, sk, next_dup, state, @@ -3595,7 +3591,7 @@ static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) * this segment (RFC793 Section 3.9). */ if (after(ack, tp->snd_nxt)) - goto invalid_ack; + return -1; if (after(ack, prior_snd_una)) { flag |= FLAG_SND_UNA_ADVANCED; @@ -3714,10 +3710,6 @@ no_queue: tcp_process_tlp_ack(sk, ack, flag); return 1; -invalid_ack: - SOCK_DEBUG(sk, "Ack %u after %u:%u\n", ack, tp->snd_una, tp->snd_nxt); - return -1; - old_ack: /* If data was SACKed, tag it and see if we should send more data. * If data was DSACKed, see if we can undo a cwnd reduction. @@ -3731,7 +3723,6 @@ old_ack: tcp_xmit_recovery(sk, rexmit); } - SOCK_DEBUG(sk, "Ack %u before %u:%u\n", ack, tp->snd_una, tp->snd_nxt); return 0; } @@ -4094,7 +4085,7 @@ void tcp_fin(struct sock *sk) case TCP_ESTABLISHED: /* Move to CLOSE_WAIT */ tcp_set_state(sk, TCP_CLOSE_WAIT); - inet_csk(sk)->icsk_ack.pingpong = 1; + inet_csk_enter_pingpong_mode(sk); break; case TCP_CLOSE_WAIT: @@ -4432,13 +4423,9 @@ static void tcp_ofo_queue(struct sock *sk) rb_erase(&skb->rbnode, &tp->out_of_order_queue); if (unlikely(!after(TCP_SKB_CB(skb)->end_seq, tp->rcv_nxt))) { - SOCK_DEBUG(sk, "ofo packet was already received\n"); tcp_drop(sk, skb); continue; } - SOCK_DEBUG(sk, "ofo requeuing : rcv_next %X seq %X - %X\n", - tp->rcv_nxt, TCP_SKB_CB(skb)->seq, - TCP_SKB_CB(skb)->end_seq); tail = skb_peek_tail(&sk->sk_receive_queue); eaten = tail && tcp_try_coalesce(sk, tail, skb, &fragstolen); @@ -4502,8 +4489,6 @@ static void tcp_data_queue_ofo(struct sock *sk, struct sk_buff *skb) NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPOFOQUEUE); seq = TCP_SKB_CB(skb)->seq; end_seq = TCP_SKB_CB(skb)->end_seq; - SOCK_DEBUG(sk, "out of order segment: rcv_next %X seq %X - %X\n", - tp->rcv_nxt, seq, end_seq); p = &tp->out_of_order_queue.rb_node; if (RB_EMPTY_ROOT(&tp->out_of_order_queue)) { @@ -4779,10 +4764,6 @@ drop: if (before(TCP_SKB_CB(skb)->seq, tp->rcv_nxt)) { /* Partial packet, seq < rcv_next < end_seq */ - SOCK_DEBUG(sk, "partial packet: rcv_next %X seq %X - %X\n", - tp->rcv_nxt, TCP_SKB_CB(skb)->seq, - TCP_SKB_CB(skb)->end_seq); - tcp_dsack_set(sk, TCP_SKB_CB(skb)->seq, tp->rcv_nxt); /* If window is closed, drop tail of packet. But after @@ -5061,8 +5042,6 @@ static int tcp_prune_queue(struct sock *sk) { struct tcp_sock *tp = tcp_sk(sk); - SOCK_DEBUG(sk, "prune_queue: c=%x\n", tp->copied_seq); - NET_INC_STATS(sock_net(sk), LINUX_MIB_PRUNECALLED); if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf) @@ -5889,7 +5868,7 @@ static int tcp_rcv_synsent_state_process(struct sock *sk, struct sk_buff *skb, return -1; if (sk->sk_write_pending || icsk->icsk_accept_queue.rskq_defer_accept || - icsk->icsk_ack.pingpong) { + inet_csk_in_pingpong_mode(sk)) { /* Save one ACK. Data will be ready after * several ticks, if write_pending is set. * @@ -6519,7 +6498,13 @@ int tcp_conn_request(struct request_sock_ops *rsk_ops, af_ops->send_synack(fastopen_sk, dst, &fl, req, &foc, TCP_SYNACK_FASTOPEN); /* Add the child socket directly into the accept queue */ - inet_csk_reqsk_queue_add(sk, req, fastopen_sk); + if (!inet_csk_reqsk_queue_add(sk, req, fastopen_sk)) { + reqsk_fastopen_remove(fastopen_sk, req, false); + bh_unlock_sock(fastopen_sk); + sock_put(fastopen_sk); + reqsk_put(req); + goto drop; + } sk->sk_data_ready(sk); bh_unlock_sock(fastopen_sk); sock_put(fastopen_sk); diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index efc6fef692ff..277d71239d75 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -536,12 +536,15 @@ int tcp_v4_err(struct sk_buff *icmp_skb, u32 info) if (sock_owned_by_user(sk)) break; + skb = tcp_rtx_queue_head(sk); + if (WARN_ON_ONCE(!skb)) + break; + icsk->icsk_backoff--; icsk->icsk_rto = tp->srtt_us ? __tcp_set_rto(tp) : TCP_TIMEOUT_INIT; icsk->icsk_rto = inet_csk_rto_backoff(icsk, TCP_RTO_MAX); - skb = tcp_rtx_queue_head(sk); tcp_mstamp_refresh(tp); delta_us = (u32)(tp->tcp_mstamp - tcp_skb_timestamp_us(skb)); @@ -970,7 +973,7 @@ static void tcp_v4_reqsk_destructor(struct request_sock *req) * We need to maintain these in the sk structure. */ -struct static_key tcp_md5_needed __read_mostly; +DEFINE_STATIC_KEY_FALSE(tcp_md5_needed); EXPORT_SYMBOL(tcp_md5_needed); /* Find the Key structure for an address. */ @@ -1731,15 +1734,8 @@ EXPORT_SYMBOL(tcp_add_backlog); int tcp_filter(struct sock *sk, struct sk_buff *skb) { struct tcphdr *th = (struct tcphdr *)skb->data; - unsigned int eaten = skb->len; - int err; - err = sk_filter_trim_cap(sk, skb, th->doff * 4); - if (!err) { - eaten -= skb->len; - TCP_SKB_CB(skb)->end_seq -= eaten; - } - return err; + return sk_filter_trim_cap(sk, skb, th->doff * 4); } EXPORT_SYMBOL(tcp_filter); @@ -2437,7 +2433,7 @@ static void get_tcp4_sock(struct sock *sk, struct seq_file *f, int i) refcount_read(&sk->sk_refcnt), sk, jiffies_to_clock_t(icsk->icsk_rto), jiffies_to_clock_t(icsk->icsk_ack.ato), - (icsk->icsk_ack.quick << 1) | icsk->icsk_ack.pingpong, + (icsk->icsk_ack.quick << 1) | inet_csk_in_pingpong_mode(sk), tp->snd_cwnd, state == TCP_LISTEN ? fastopenq->max_qlen : diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c index 12affb7864d9..79900f783e0d 100644 --- a/net/ipv4/tcp_minisocks.c +++ b/net/ipv4/tcp_minisocks.c @@ -294,12 +294,15 @@ void tcp_time_wait(struct sock *sk, int state, int timeo) * so the timewait ack generating code has the key. */ do { - struct tcp_md5sig_key *key; tcptw->tw_md5_key = NULL; - key = tp->af_specific->md5_lookup(sk, sk); - if (key) { - tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC); - BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool()); + if (static_branch_unlikely(&tcp_md5_needed)) { + struct tcp_md5sig_key *key; + + key = tp->af_specific->md5_lookup(sk, sk); + if (key) { + tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC); + BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool()); + } } } while (0); #endif @@ -338,10 +341,12 @@ EXPORT_SYMBOL(tcp_time_wait); void tcp_twsk_destructor(struct sock *sk) { #ifdef CONFIG_TCP_MD5SIG - struct tcp_timewait_sock *twsk = tcp_twsk(sk); + if (static_branch_unlikely(&tcp_md5_needed)) { + struct tcp_timewait_sock *twsk = tcp_twsk(sk); - if (twsk->tw_md5_key) - kfree_rcu(twsk->tw_md5_key, rcu); + if (twsk->tw_md5_key) + kfree_rcu(twsk->tw_md5_key, rcu); + } #endif } EXPORT_SYMBOL_GPL(tcp_twsk_destructor); @@ -479,43 +484,16 @@ struct sock *tcp_create_openreq_child(const struct sock *sk, tcp_init_wl(newtp, treq->rcv_isn); - newtp->srtt_us = 0; - newtp->mdev_us = jiffies_to_usecs(TCP_TIMEOUT_INIT); minmax_reset(&newtp->rtt_min, tcp_jiffies32, ~0U); - newicsk->icsk_rto = TCP_TIMEOUT_INIT; newicsk->icsk_ack.lrcvtime = tcp_jiffies32; - newtp->packets_out = 0; - newtp->retrans_out = 0; - newtp->sacked_out = 0; - newtp->snd_ssthresh = TCP_INFINITE_SSTHRESH; - newtp->tlp_high_seq = 0; newtp->lsndtime = tcp_jiffies32; newsk->sk_txhash = treq->txhash; - newtp->last_oow_ack_time = 0; newtp->total_retrans = req->num_retrans; - /* So many TCP implementations out there (incorrectly) count the - * initial SYN frame in their delayed-ACK and congestion control - * algorithms that we must have the following bandaid to talk - * efficiently to them. -DaveM - */ - newtp->snd_cwnd = TCP_INIT_CWND; - newtp->snd_cwnd_cnt = 0; - - /* There's a bubble in the pipe until at least the first ACK. */ - newtp->app_limited = ~0U; - tcp_init_xmit_timers(newsk); newtp->write_seq = newtp->pushed_seq = treq->snt_isn + 1; - newtp->rx_opt.saw_tstamp = 0; - - newtp->rx_opt.dsack = 0; - newtp->rx_opt.num_sacks = 0; - - newtp->urg_data = 0; - if (sock_flag(newsk, SOCK_KEEPOPEN)) inet_csk_reset_keepalive_timer(newsk, keepalive_time_when(newtp)); @@ -556,13 +534,6 @@ struct sock *tcp_create_openreq_child(const struct sock *sk, tcp_ecn_openreq_child(newtp, req); newtp->fastopen_req = NULL; newtp->fastopen_rsk = NULL; - newtp->syn_data_acked = 0; - newtp->rack.mstamp = 0; - newtp->rack.advanced = 0; - newtp->rack.reo_wnd_steps = 1; - newtp->rack.last_delivered = 0; - newtp->rack.reo_wnd_persist = 0; - newtp->rack.dsack_seen = 0; __TCP_INC_STATS(sock_net(sk), TCP_MIB_PASSIVEOPENS); diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c index 730bc44dbad9..4522579aaca2 100644 --- a/net/ipv4/tcp_output.c +++ b/net/ipv4/tcp_output.c @@ -165,13 +165,16 @@ static void tcp_event_data_sent(struct tcp_sock *tp, if (tcp_packets_in_flight(tp) == 0) tcp_ca_event(sk, CA_EVENT_TX_START); - tp->lsndtime = now; - - /* If it is a reply for ato after last received - * packet, enter pingpong mode. + /* If this is the first data packet sent in response to the + * previous received data, + * and it is a reply for ato after last received packet, + * increase pingpong count. */ - if ((u32)(now - icsk->icsk_ack.lrcvtime) < icsk->icsk_ack.ato) - icsk->icsk_ack.pingpong = 1; + if (before(tp->lsndtime, icsk->icsk_ack.lrcvtime) && + (u32)(now - icsk->icsk_ack.lrcvtime) < icsk->icsk_ack.ato) + inet_csk_inc_pingpong_cnt(sk); + + tp->lsndtime = now; } /* Account for an ACK we sent. */ @@ -594,7 +597,7 @@ static unsigned int tcp_syn_options(struct sock *sk, struct sk_buff *skb, *md5 = NULL; #ifdef CONFIG_TCP_MD5SIG - if (static_key_false(&tcp_md5_needed) && + if (static_branch_unlikely(&tcp_md5_needed) && rcu_access_pointer(tp->md5sig_info)) { *md5 = tp->af_specific->md5_lookup(sk, sk); if (*md5) { @@ -731,7 +734,7 @@ static unsigned int tcp_established_options(struct sock *sk, struct sk_buff *skb *md5 = NULL; #ifdef CONFIG_TCP_MD5SIG - if (static_key_false(&tcp_md5_needed) && + if (static_branch_unlikely(&tcp_md5_needed) && rcu_access_pointer(tp->md5sig_info)) { *md5 = tp->af_specific->md5_lookup(sk, sk); if (*md5) { @@ -980,7 +983,6 @@ static void tcp_update_skb_after_send(struct sock *sk, struct sk_buff *skb, { struct tcp_sock *tp = tcp_sk(sk); - skb->skb_mstamp_ns = tp->tcp_wstamp_ns; if (sk->sk_pacing_status != SK_PACING_NONE) { unsigned long rate = sk->sk_pacing_rate; @@ -1028,7 +1030,9 @@ static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, BUG_ON(!skb || !tcp_skb_pcount(skb)); tp = tcp_sk(sk); - + prior_wstamp = tp->tcp_wstamp_ns; + tp->tcp_wstamp_ns = max(tp->tcp_wstamp_ns, tp->tcp_clock_cache); + skb->skb_mstamp_ns = tp->tcp_wstamp_ns; if (clone_it) { TCP_SKB_CB(skb)->tx.in_flight = TCP_SKB_CB(skb)->end_seq - tp->snd_una; @@ -1045,11 +1049,6 @@ static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, return -ENOBUFS; } - prior_wstamp = tp->tcp_wstamp_ns; - tp->tcp_wstamp_ns = max(tp->tcp_wstamp_ns, tp->tcp_clock_cache); - - skb->skb_mstamp_ns = tp->tcp_wstamp_ns; - inet = inet_sk(sk); tcb = TCP_SKB_CB(skb); memset(&opts, 0, sizeof(opts)); @@ -1847,17 +1846,17 @@ static bool tcp_snd_wnd_test(const struct tcp_sock *tp, * know that all the data is in scatter-gather pages, and that the * packet has never been sent out before (and thus is not cloned). */ -static int tso_fragment(struct sock *sk, enum tcp_queue tcp_queue, - struct sk_buff *skb, unsigned int len, +static int tso_fragment(struct sock *sk, struct sk_buff *skb, unsigned int len, unsigned int mss_now, gfp_t gfp) { - struct sk_buff *buff; int nlen = skb->len - len; + struct sk_buff *buff; u8 flags; /* All of a TSO frame must be composed of paged data. */ if (skb->len != skb->data_len) - return tcp_fragment(sk, tcp_queue, skb, len, mss_now, gfp); + return tcp_fragment(sk, TCP_FRAG_IN_WRITE_QUEUE, + skb, len, mss_now, gfp); buff = sk_stream_alloc_skb(sk, 0, gfp, true); if (unlikely(!buff)) @@ -1893,7 +1892,7 @@ static int tso_fragment(struct sock *sk, enum tcp_queue tcp_queue, /* Link BUFF into the send queue. */ __skb_header_release(buff); - tcp_insert_write_queue_after(skb, buff, sk, tcp_queue); + tcp_insert_write_queue_after(skb, buff, sk, TCP_FRAG_IN_WRITE_QUEUE); return 0; } @@ -2347,6 +2346,7 @@ static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, /* "skb_mstamp_ns" is used as a start point for the retransmit timer */ skb->skb_mstamp_ns = tp->tcp_wstamp_ns = tp->tcp_clock_cache; list_move_tail(&skb->tcp_tsorted_anchor, &tp->tsorted_sent_queue); + tcp_init_tso_segs(skb, mss_now); goto repair; /* Skip network transmission */ } @@ -2391,8 +2391,7 @@ static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, nonagle); if (skb->len > limit && - unlikely(tso_fragment(sk, TCP_FRAG_IN_WRITE_QUEUE, - skb, limit, mss_now, gfp))) + unlikely(tso_fragment(sk, skb, limit, mss_now, gfp))) break; if (tcp_small_queue_check(sk, skb, 0)) @@ -2937,12 +2936,16 @@ int __tcp_retransmit_skb(struct sock *sk, struct sk_buff *skb, int segs) err = tcp_transmit_skb(sk, skb, 1, GFP_ATOMIC); } + /* To avoid taking spuriously low RTT samples based on a timestamp + * for a transmit that never happened, always mark EVER_RETRANS + */ + TCP_SKB_CB(skb)->sacked |= TCPCB_EVER_RETRANS; + if (BPF_SOCK_OPS_TEST_FLAG(tp, BPF_SOCK_OPS_RETRANS_CB_FLAG)) tcp_call_bpf_3arg(sk, BPF_SOCK_OPS_RETRANS_CB, TCP_SKB_CB(skb)->seq, segs, err); if (likely(!err)) { - TCP_SKB_CB(skb)->sacked |= TCPCB_EVER_RETRANS; trace_tcp_retransmit_skb(sk, skb); } else if (err != -EBUSY) { NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPRETRANSFAIL, segs); @@ -2963,13 +2966,12 @@ int tcp_retransmit_skb(struct sock *sk, struct sk_buff *skb, int segs) #endif TCP_SKB_CB(skb)->sacked |= TCPCB_RETRANS; tp->retrans_out += tcp_skb_pcount(skb); - - /* Save stamp of the first retransmit. */ - if (!tp->retrans_stamp) - tp->retrans_stamp = tcp_skb_timestamp(skb); - } + /* Save stamp of the first (attempted) retransmit. */ + if (!tp->retrans_stamp) + tp->retrans_stamp = tcp_skb_timestamp(skb); + if (tp->undo_retrans < 0) tp->undo_retrans = 0; tp->undo_retrans += tcp_skb_pcount(skb); @@ -3456,6 +3458,7 @@ static int tcp_send_syn_data(struct sock *sk, struct sk_buff *syn) skb_trim(syn_data, copied); space = copied; } + skb_zcopy_set(syn_data, fo->uarg, NULL); } /* No more data pending in inet_wait_for_connect() */ if (space == fo->size) @@ -3569,7 +3572,7 @@ void tcp_send_delayed_ack(struct sock *sk) const struct tcp_sock *tp = tcp_sk(sk); int max_ato = HZ / 2; - if (icsk->icsk_ack.pingpong || + if (inet_csk_in_pingpong_mode(sk) || (icsk->icsk_ack.pending & ICSK_ACK_PUSHED)) max_ato = TCP_DELACK_MAX; @@ -3750,7 +3753,7 @@ void tcp_send_probe0(struct sock *sk) struct inet_connection_sock *icsk = inet_csk(sk); struct tcp_sock *tp = tcp_sk(sk); struct net *net = sock_net(sk); - unsigned long probe_max; + unsigned long timeout; int err; err = tcp_write_wakeup(sk, LINUX_MIB_TCPWINPROBE); @@ -3762,26 +3765,18 @@ void tcp_send_probe0(struct sock *sk) return; } + icsk->icsk_probes_out++; if (err <= 0) { if (icsk->icsk_backoff < net->ipv4.sysctl_tcp_retries2) icsk->icsk_backoff++; - icsk->icsk_probes_out++; - probe_max = TCP_RTO_MAX; + timeout = tcp_probe0_when(sk, TCP_RTO_MAX); } else { /* If packet was not sent due to local congestion, - * do not backoff and do not remember icsk_probes_out. - * Let local senders to fight for local resources. - * - * Use accumulated backoff yet. + * Let senders fight for local resources conservatively. */ - if (!icsk->icsk_probes_out) - icsk->icsk_probes_out = 1; - probe_max = TCP_RESOURCE_PROBE_INTERVAL; - } - tcp_reset_xmit_timer(sk, ICSK_TIME_PROBE0, - tcp_probe0_when(sk, probe_max), - TCP_RTO_MAX, - NULL); + timeout = TCP_RESOURCE_PROBE_INTERVAL; + } + tcp_reset_xmit_timer(sk, ICSK_TIME_PROBE0, timeout, TCP_RTO_MAX, NULL); } int tcp_rtx_synack(const struct sock *sk, struct request_sock *req) diff --git a/net/ipv4/tcp_timer.c b/net/ipv4/tcp_timer.c index 71a29e9c0620..f0c86398e6a7 100644 --- a/net/ipv4/tcp_timer.c +++ b/net/ipv4/tcp_timer.c @@ -22,28 +22,14 @@ #include <linux/gfp.h> #include <net/tcp.h> -static u32 tcp_retransmit_stamp(const struct sock *sk) -{ - u32 start_ts = tcp_sk(sk)->retrans_stamp; - - if (unlikely(!start_ts)) { - struct sk_buff *head = tcp_rtx_queue_head(sk); - - if (!head) - return 0; - start_ts = tcp_skb_timestamp(head); - } - return start_ts; -} - static u32 tcp_clamp_rto_to_user_timeout(const struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); u32 elapsed, start_ts; s32 remaining; - start_ts = tcp_retransmit_stamp(sk); - if (!icsk->icsk_user_timeout || !start_ts) + start_ts = tcp_sk(sk)->retrans_stamp; + if (!icsk->icsk_user_timeout) return icsk->icsk_rto; elapsed = tcp_time_stamp(tcp_sk(sk)) - start_ts; remaining = icsk->icsk_user_timeout - elapsed; @@ -173,7 +159,20 @@ static void tcp_mtu_probing(struct inet_connection_sock *icsk, struct sock *sk) tcp_sync_mss(sk, icsk->icsk_pmtu_cookie); } - +static unsigned int tcp_model_timeout(struct sock *sk, + unsigned int boundary, + unsigned int rto_base) +{ + unsigned int linear_backoff_thresh, timeout; + + linear_backoff_thresh = ilog2(TCP_RTO_MAX / rto_base); + if (boundary <= linear_backoff_thresh) + timeout = ((2 << boundary) - 1) * rto_base; + else + timeout = ((2 << linear_backoff_thresh) - 1) * rto_base + + (boundary - linear_backoff_thresh) * TCP_RTO_MAX; + return jiffies_to_msecs(timeout); +} /** * retransmits_timed_out() - returns true if this connection has timed out * @sk: The current socket @@ -191,26 +190,15 @@ static bool retransmits_timed_out(struct sock *sk, unsigned int boundary, unsigned int timeout) { - const unsigned int rto_base = TCP_RTO_MIN; - unsigned int linear_backoff_thresh, start_ts; + unsigned int start_ts; if (!inet_csk(sk)->icsk_retransmits) return false; - start_ts = tcp_retransmit_stamp(sk); - if (!start_ts) - return false; - - if (likely(timeout == 0)) { - linear_backoff_thresh = ilog2(TCP_RTO_MAX/rto_base); + start_ts = tcp_sk(sk)->retrans_stamp; + if (likely(timeout == 0)) + timeout = tcp_model_timeout(sk, boundary, TCP_RTO_MIN); - if (boundary <= linear_backoff_thresh) - timeout = ((2 << boundary) - 1) * rto_base; - else - timeout = ((2 << linear_backoff_thresh) - 1) * rto_base + - (boundary - linear_backoff_thresh) * TCP_RTO_MAX; - timeout = jiffies_to_msecs(timeout); - } return (s32)(tcp_time_stamp(tcp_sk(sk)) - start_ts - timeout) >= 0; } @@ -289,14 +277,14 @@ void tcp_delack_timer_handler(struct sock *sk) icsk->icsk_ack.pending &= ~ICSK_ACK_TIMER; if (inet_csk_ack_scheduled(sk)) { - if (!icsk->icsk_ack.pingpong) { + if (!inet_csk_in_pingpong_mode(sk)) { /* Delayed ACK missed: inflate ATO. */ icsk->icsk_ack.ato = min(icsk->icsk_ack.ato << 1, icsk->icsk_rto); } else { /* Delayed ACK missed: leave pingpong mode and * deflate ATO. */ - icsk->icsk_ack.pingpong = 0; + inet_csk_exit_pingpong_mode(sk); icsk->icsk_ack.ato = TCP_ATO_MIN; } tcp_mstamp_refresh(tcp_sk(sk)); @@ -345,7 +333,6 @@ static void tcp_probe_timer(struct sock *sk) struct sk_buff *skb = tcp_send_head(sk); struct tcp_sock *tp = tcp_sk(sk); int max_probes; - u32 start_ts; if (tp->packets_out || !skb) { icsk->icsk_probes_out = 0; @@ -360,12 +347,13 @@ static void tcp_probe_timer(struct sock *sk) * corresponding system limit. We also implement similar policy when * we use RTO to probe window in tcp_retransmit_timer(). */ - start_ts = tcp_skb_timestamp(skb); - if (!start_ts) - skb->skb_mstamp_ns = tp->tcp_clock_cache; - else if (icsk->icsk_user_timeout && - (s32)(tcp_time_stamp(tp) - start_ts) > icsk->icsk_user_timeout) - goto abort; + if (icsk->icsk_user_timeout) { + u32 elapsed = tcp_model_timeout(sk, icsk->icsk_probes_out, + tcp_probe0_base(sk)); + + if (elapsed >= icsk->icsk_user_timeout) + goto abort; + } max_probes = sock_net(sk)->ipv4.sysctl_tcp_retries2; if (sock_flag(sk, SOCK_DEAD)) { @@ -395,6 +383,7 @@ static void tcp_fastopen_synack_timer(struct sock *sk) struct inet_connection_sock *icsk = inet_csk(sk); int max_retries = icsk->icsk_syn_retries ? : sock_net(sk)->ipv4.sysctl_tcp_synack_retries + 1; /* add one more retry for fastopen */ + struct tcp_sock *tp = tcp_sk(sk); struct request_sock *req; req = tcp_sk(sk)->fastopen_rsk; @@ -412,6 +401,8 @@ static void tcp_fastopen_synack_timer(struct sock *sk) inet_rtx_syn_ack(sk, req); req->num_timeout++; icsk->icsk_retransmits++; + if (!tp->retrans_stamp) + tp->retrans_stamp = tcp_time_stamp(tp); inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, TCP_TIMEOUT_INIT << req->num_timeout, TCP_RTO_MAX); } @@ -443,10 +434,8 @@ void tcp_retransmit_timer(struct sock *sk) */ return; } - if (!tp->packets_out) - goto out; - - WARN_ON(tcp_rtx_queue_empty(sk)); + if (!tp->packets_out || WARN_ON_ONCE(tcp_rtx_queue_empty(sk))) + return; tp->tlp_high_seq = 0; @@ -511,14 +500,13 @@ void tcp_retransmit_timer(struct sock *sk) tcp_enter_loss(sk); + icsk->icsk_retransmits++; if (tcp_retransmit_skb(sk, tcp_rtx_queue_head(sk), 1) > 0) { /* Retransmission failed because of local congestion, - * do not backoff. + * Let senders fight for local resources conservatively. */ - if (!icsk->icsk_retransmits) - icsk->icsk_retransmits = 1; inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, - min(icsk->icsk_rto, TCP_RESOURCE_PROBE_INTERVAL), + TCP_RESOURCE_PROBE_INTERVAL, TCP_RTO_MAX); goto out; } @@ -539,7 +527,6 @@ void tcp_retransmit_timer(struct sock *sk) * the 120 second clamps though! */ icsk->icsk_backoff++; - icsk->icsk_retransmits++; out_reset_timer: /* If stream is thin, use linear timeouts. Since 'icsk_backoff' is diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 5c3cd5d84a6f..372fdc5381a9 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -562,10 +562,12 @@ static int __udp4_lib_err_encap_no_sk(struct sk_buff *skb, u32 info) for (i = 0; i < MAX_IPTUN_ENCAP_OPS; i++) { int (*handler)(struct sk_buff *skb, u32 info); + const struct ip_tunnel_encap_ops *encap; - if (!iptun_encaps[i]) + encap = rcu_dereference(iptun_encaps[i]); + if (!encap) continue; - handler = rcu_dereference(iptun_encaps[i]->err_handler); + handler = encap->err_handler; if (handler && !handler(skb, info)) return 0; } diff --git a/net/ipv4/udp_tunnel.c b/net/ipv4/udp_tunnel.c index be8b5b2157d8..e93cc0379201 100644 --- a/net/ipv4/udp_tunnel.c +++ b/net/ipv4/udp_tunnel.c @@ -21,18 +21,9 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg, goto error; if (cfg->bind_ifindex) { - struct net_device *dev; - - dev = dev_get_by_index(net, cfg->bind_ifindex); - if (!dev) { - err = -ENODEV; - goto error; - } - - err = kernel_setsockopt(sock, SOL_SOCKET, SO_BINDTODEVICE, - dev->name, strlen(dev->name) + 1); - dev_put(dev); - + err = kernel_setsockopt(sock, SOL_SOCKET, SO_BINDTOIFINDEX, + (void *)&cfg->bind_ifindex, + sizeof(cfg->bind_ifindex)); if (err < 0) goto error; } diff --git a/net/ipv6/addrconf.c b/net/ipv6/addrconf.c index 93d5ad2b1a69..4ae17a966ae3 100644 --- a/net/ipv6/addrconf.c +++ b/net/ipv6/addrconf.c @@ -597,6 +597,43 @@ static const struct nla_policy devconf_ipv6_policy[NETCONFA_MAX+1] = { [NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN] = { .len = sizeof(int) }, }; +static int inet6_netconf_valid_get_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(struct netconfmsg))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for netconf get request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse(nlh, sizeof(struct netconfmsg), tb, + NETCONFA_MAX, devconf_ipv6_policy, extack); + + err = nlmsg_parse_strict(nlh, sizeof(struct netconfmsg), tb, + NETCONFA_MAX, devconf_ipv6_policy, extack); + if (err) + return err; + + for (i = 0; i <= NETCONFA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case NETCONFA_IFINDEX: + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in netconf get request"); + return -EINVAL; + } + } + + return 0; +} + static int inet6_netconf_get_devconf(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) @@ -605,14 +642,12 @@ static int inet6_netconf_get_devconf(struct sk_buff *in_skb, struct nlattr *tb[NETCONFA_MAX+1]; struct inet6_dev *in6_dev = NULL; struct net_device *dev = NULL; - struct netconfmsg *ncm; struct sk_buff *skb; struct ipv6_devconf *devconf; int ifindex; int err; - err = nlmsg_parse(nlh, sizeof(*ncm), tb, NETCONFA_MAX, - devconf_ipv6_policy, extack); + err = inet6_netconf_valid_get_req(in_skb, nlh, tb, extack); if (err < 0) return err; @@ -1165,7 +1200,8 @@ check_cleanup_prefix_route(struct inet6_ifaddr *ifp, unsigned long *expires) list_for_each_entry(ifa, &idev->addr_list, if_list) { if (ifa == ifp) continue; - if (!ipv6_prefix_equal(&ifa->addr, &ifp->addr, + if (ifa->prefix_len != ifp->prefix_len || + !ipv6_prefix_equal(&ifa->addr, &ifp->addr, ifp->prefix_len)) continue; if (ifa->flags & (IFA_F_PERMANENT | IFA_F_NOPREFIXROUTE)) @@ -3495,8 +3531,8 @@ static int addrconf_notify(struct notifier_block *this, unsigned long event, if (!addrconf_link_ready(dev)) { /* device is not ready yet. */ - pr_info("ADDRCONF(NETDEV_UP): %s: link is not ready\n", - dev->name); + pr_debug("ADDRCONF(NETDEV_UP): %s: link is not ready\n", + dev->name); break; } @@ -5120,6 +5156,8 @@ static int inet6_dump_addr(struct sk_buff *skb, struct netlink_callback *cb, if (idev) { err = in6_dump_addrs(idev, skb, cb, s_ip_idx, &fillargs); + if (err > 0) + err = 0; } goto put_tgt_net; } @@ -5179,6 +5217,52 @@ static int inet6_dump_ifacaddr(struct sk_buff *skb, struct netlink_callback *cb) return inet6_dump_addr(skb, cb, type); } +static int inet6_rtm_valid_getaddr_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct ifaddrmsg *ifm; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for get address request"); + return -EINVAL; + } + + ifm = nlmsg_data(nlh); + if (ifm->ifa_prefixlen || ifm->ifa_flags || ifm->ifa_scope) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for get address request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse(nlh, sizeof(*ifm), tb, IFA_MAX, + ifa_ipv6_policy, extack); + + err = nlmsg_parse_strict(nlh, sizeof(*ifm), tb, IFA_MAX, + ifa_ipv6_policy, extack); + if (err) + return err; + + for (i = 0; i <= IFA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case IFA_TARGET_NETNSID: + case IFA_ADDRESS: + case IFA_LOCAL: + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in get address request"); + return -EINVAL; + } + } + + return 0; +} + static int inet6_rtm_getaddr(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { @@ -5199,8 +5283,7 @@ static int inet6_rtm_getaddr(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct sk_buff *skb; int err; - err = nlmsg_parse(nlh, sizeof(*ifm), tb, IFA_MAX, ifa_ipv6_policy, - extack); + err = inet6_rtm_valid_getaddr_req(in_skb, nlh, tb, extack); if (err < 0) return err; @@ -6822,6 +6905,12 @@ static int __net_init addrconf_init_net(struct net *net) if (!dflt) goto err_alloc_dflt; + if (IS_ENABLED(CONFIG_SYSCTL) && + sysctl_devconf_inherit_init_net == 1 && !net_eq(net, &init_net)) { + memcpy(all, init_net.ipv6.devconf_all, sizeof(ipv6_devconf)); + memcpy(dflt, init_net.ipv6.devconf_dflt, sizeof(ipv6_devconf_dflt)); + } + /* these will be inherited by all namespaces */ dflt->autoconf = ipv6_defaults.autoconf; dflt->disable_ipv6 = ipv6_defaults.disable_ipv6; diff --git a/net/ipv6/addrconf_core.c b/net/ipv6/addrconf_core.c index 5cd0029d930e..6c79af056d9b 100644 --- a/net/ipv6/addrconf_core.c +++ b/net/ipv6/addrconf_core.c @@ -134,6 +134,11 @@ static int eafnosupport_ipv6_dst_lookup(struct net *net, struct sock *u1, return -EAFNOSUPPORT; } +static int eafnosupport_ipv6_route_input(struct sk_buff *skb) +{ + return -EAFNOSUPPORT; +} + static struct fib6_table *eafnosupport_fib6_get_table(struct net *net, u32 id) { return NULL; @@ -170,6 +175,7 @@ eafnosupport_ip6_mtu_from_fib6(struct fib6_info *f6i, struct in6_addr *daddr, const struct ipv6_stub *ipv6_stub __read_mostly = &(struct ipv6_stub) { .ipv6_dst_lookup = eafnosupport_ipv6_dst_lookup, + .ipv6_route_input = eafnosupport_ipv6_route_input, .fib6_get_table = eafnosupport_fib6_get_table, .fib6_table_lookup = eafnosupport_fib6_table_lookup, .fib6_lookup = eafnosupport_fib6_lookup, diff --git a/net/ipv6/addrlabel.c b/net/ipv6/addrlabel.c index 0d1ee82ee55b..d43d076c98f5 100644 --- a/net/ipv6/addrlabel.c +++ b/net/ipv6/addrlabel.c @@ -523,6 +523,50 @@ static inline int ip6addrlbl_msgsize(void) + nla_total_size(4); /* IFAL_LABEL */ } +static int ip6addrlbl_valid_get_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct ifaddrlblmsg *ifal; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifal))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for addrlabel get request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse(nlh, sizeof(*ifal), tb, IFAL_MAX, + ifal_policy, extack); + + ifal = nlmsg_data(nlh); + if (ifal->__ifal_reserved || ifal->ifal_flags || ifal->ifal_seq) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for addrlabel get request"); + return -EINVAL; + } + + err = nlmsg_parse_strict(nlh, sizeof(*ifal), tb, IFAL_MAX, + ifal_policy, extack); + if (err) + return err; + + for (i = 0; i <= IFAL_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case IFAL_ADDRESS: + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in addrlabel get request"); + return -EINVAL; + } + } + + return 0; +} + static int ip6addrlbl_get(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { @@ -535,8 +579,7 @@ static int ip6addrlbl_get(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct ip6addrlbl_entry *p; struct sk_buff *skb; - err = nlmsg_parse(nlh, sizeof(*ifal), tb, IFAL_MAX, ifal_policy, - extack); + err = ip6addrlbl_valid_get_req(in_skb, nlh, tb, extack); if (err < 0) return err; diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c index d99753b5e39b..2f45d2a3e3a3 100644 --- a/net/ipv6/af_inet6.c +++ b/net/ipv6/af_inet6.c @@ -900,10 +900,17 @@ static struct pernet_operations inet6_net_ops = { .exit = inet6_net_exit, }; +static int ipv6_route_input(struct sk_buff *skb) +{ + ip6_route_input(skb); + return skb_dst(skb)->error; +} + static const struct ipv6_stub ipv6_stub_impl = { .ipv6_sock_mc_join = ipv6_sock_mc_join, .ipv6_sock_mc_drop = ipv6_sock_mc_drop, .ipv6_dst_lookup = ip6_dst_lookup, + .ipv6_route_input = ipv6_route_input, .fib6_get_table = fib6_get_table, .fib6_table_lookup = fib6_table_lookup, .fib6_lookup = fib6_lookup, diff --git a/net/ipv6/esp6.c b/net/ipv6/esp6.c index 5afe9f83374d..239d4a65ad6e 100644 --- a/net/ipv6/esp6.c +++ b/net/ipv6/esp6.c @@ -296,7 +296,7 @@ int esp6_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info skb->len += tailen; skb->data_len += tailen; skb->truesize += tailen; - if (sk) + if (sk && sk_fullsock(sk)) refcount_add(tailen, &sk->sk_wmem_alloc); goto out; diff --git a/net/ipv6/fou6.c b/net/ipv6/fou6.c index b858bd5280bf..ec4e2ed95f36 100644 --- a/net/ipv6/fou6.c +++ b/net/ipv6/fou6.c @@ -72,7 +72,7 @@ static int gue6_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e, static int gue6_err_proto_handler(int proto, struct sk_buff *skb, struct inet6_skb_parm *opt, - u8 type, u8 code, int offset, u32 info) + u8 type, u8 code, int offset, __be32 info) { const struct inet6_protocol *ipprot; @@ -94,7 +94,7 @@ static int gue6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, int ret; len = sizeof(struct udphdr) + sizeof(struct guehdr); - if (!pskb_may_pull(skb, len)) + if (!pskb_may_pull(skb, transport_offset + len)) return -EINVAL; guehdr = (struct guehdr *)&udp_hdr(skb)[1]; @@ -129,7 +129,7 @@ static int gue6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, optlen = guehdr->hlen << 2; - if (!pskb_may_pull(skb, len + optlen)) + if (!pskb_may_pull(skb, transport_offset + len + optlen)) return -EINVAL; guehdr = (struct guehdr *)&udp_hdr(skb)[1]; diff --git a/net/ipv6/icmp.c b/net/ipv6/icmp.c index bbcdfd299692..802faa2fcc0e 100644 --- a/net/ipv6/icmp.c +++ b/net/ipv6/icmp.c @@ -81,7 +81,7 @@ */ static inline struct sock *icmpv6_sk(struct net *net) { - return net->ipv6.icmp_sk[smp_processor_id()]; + return *this_cpu_ptr(net->ipv6.icmp_sk); } static int icmpv6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, @@ -953,13 +953,21 @@ void icmpv6_flow_init(struct sock *sk, struct flowi6 *fl6, security_sk_classify_flow(sk, flowi6_to_flowi(fl6)); } +static void __net_exit icmpv6_sk_exit(struct net *net) +{ + int i; + + for_each_possible_cpu(i) + inet_ctl_sock_destroy(*per_cpu_ptr(net->ipv6.icmp_sk, i)); + free_percpu(net->ipv6.icmp_sk); +} + static int __net_init icmpv6_sk_init(struct net *net) { struct sock *sk; - int err, i, j; + int err, i; - net->ipv6.icmp_sk = - kcalloc(nr_cpu_ids, sizeof(struct sock *), GFP_KERNEL); + net->ipv6.icmp_sk = alloc_percpu(struct sock *); if (!net->ipv6.icmp_sk) return -ENOMEM; @@ -972,7 +980,7 @@ static int __net_init icmpv6_sk_init(struct net *net) goto fail; } - net->ipv6.icmp_sk[i] = sk; + *per_cpu_ptr(net->ipv6.icmp_sk, i) = sk; /* Enough space for 2 64K ICMP packets, including * sk_buff struct overhead. @@ -982,22 +990,10 @@ static int __net_init icmpv6_sk_init(struct net *net) return 0; fail: - for (j = 0; j < i; j++) - inet_ctl_sock_destroy(net->ipv6.icmp_sk[j]); - kfree(net->ipv6.icmp_sk); + icmpv6_sk_exit(net); return err; } -static void __net_exit icmpv6_sk_exit(struct net *net) -{ - int i; - - for_each_possible_cpu(i) { - inet_ctl_sock_destroy(net->ipv6.icmp_sk[i]); - } - kfree(net->ipv6.icmp_sk); -} - static struct pernet_operations icmpv6_sk_ops = { .init = icmpv6_sk_init, .exit = icmpv6_sk_exit, diff --git a/net/ipv6/ila/ila_xlat.c b/net/ipv6/ila/ila_xlat.c index 17c455ff69ff..79d2e43c05c5 100644 --- a/net/ipv6/ila/ila_xlat.c +++ b/net/ipv6/ila/ila_xlat.c @@ -383,12 +383,9 @@ int ila_xlat_nl_cmd_flush(struct sk_buff *skb, struct genl_info *info) struct rhashtable_iter iter; struct ila_map *ila; spinlock_t *lock; - int ret; - - ret = rhashtable_walk_init(&ilan->xlat.rhash_table, &iter, GFP_KERNEL); - if (ret) - goto done; + int ret = 0; + rhashtable_walk_enter(&ilan->xlat.rhash_table, &iter); rhashtable_walk_start(&iter); for (;;) { @@ -509,23 +506,17 @@ int ila_xlat_nl_dump_start(struct netlink_callback *cb) struct net *net = sock_net(cb->skb->sk); struct ila_net *ilan = net_generic(net, ila_net_id); struct ila_dump_iter *iter; - int ret; iter = kmalloc(sizeof(*iter), GFP_KERNEL); if (!iter) return -ENOMEM; - ret = rhashtable_walk_init(&ilan->xlat.rhash_table, &iter->rhiter, - GFP_KERNEL); - if (ret) { - kfree(iter); - return ret; - } + rhashtable_walk_enter(&ilan->xlat.rhash_table, &iter->rhiter); iter->skip = 0; cb->args[0] = (long)iter; - return ret; + return 0; } int ila_xlat_nl_dump_done(struct netlink_callback *cb) diff --git a/net/ipv6/ip6_gre.c b/net/ipv6/ip6_gre.c index b1be67ca6768..b32c95f02128 100644 --- a/net/ipv6/ip6_gre.c +++ b/net/ipv6/ip6_gre.c @@ -524,7 +524,7 @@ static int ip6gre_rcv(struct sk_buff *skb, const struct tnl_ptk_info *tpi) return PACKET_REJECT; } -static int ip6erspan_rcv(struct sk_buff *skb, int gre_hdr_len, +static int ip6erspan_rcv(struct sk_buff *skb, struct tnl_ptk_info *tpi) { struct erspan_base_hdr *ershdr; @@ -534,13 +534,9 @@ static int ip6erspan_rcv(struct sk_buff *skb, int gre_hdr_len, struct ip6_tnl *tunnel; u8 ver; - if (unlikely(!pskb_may_pull(skb, sizeof(*ershdr)))) - return PACKET_REJECT; - ipv6h = ipv6_hdr(skb); ershdr = (struct erspan_base_hdr *)skb->data; ver = ershdr->ver; - tpi->key = cpu_to_be32(get_session_id(ershdr)); tunnel = ip6gre_tunnel_lookup(skb->dev, &ipv6h->saddr, &ipv6h->daddr, tpi->key, @@ -611,7 +607,7 @@ static int gre_rcv(struct sk_buff *skb) if (unlikely(tpi.proto == htons(ETH_P_ERSPAN) || tpi.proto == htons(ETH_P_ERSPAN2))) { - if (ip6erspan_rcv(skb, hdr_len, &tpi) == PACKET_RCVD) + if (ip6erspan_rcv(skb, &tpi) == PACKET_RCVD) return 0; goto out; } @@ -1723,6 +1719,27 @@ static int ip6erspan_tap_validate(struct nlattr *tb[], struct nlattr *data[], return 0; } +static void ip6erspan_set_version(struct nlattr *data[], + struct __ip6_tnl_parm *parms) +{ + if (!data) + return; + + parms->erspan_ver = 1; + if (data[IFLA_GRE_ERSPAN_VER]) + parms->erspan_ver = nla_get_u8(data[IFLA_GRE_ERSPAN_VER]); + + if (parms->erspan_ver == 1) { + if (data[IFLA_GRE_ERSPAN_INDEX]) + parms->index = nla_get_u32(data[IFLA_GRE_ERSPAN_INDEX]); + } else if (parms->erspan_ver == 2) { + if (data[IFLA_GRE_ERSPAN_DIR]) + parms->dir = nla_get_u8(data[IFLA_GRE_ERSPAN_DIR]); + if (data[IFLA_GRE_ERSPAN_HWID]) + parms->hwid = nla_get_u16(data[IFLA_GRE_ERSPAN_HWID]); + } +} + static void ip6gre_netlink_parms(struct nlattr *data[], struct __ip6_tnl_parm *parms) { @@ -1771,20 +1788,6 @@ static void ip6gre_netlink_parms(struct nlattr *data[], if (data[IFLA_GRE_COLLECT_METADATA]) parms->collect_md = true; - - parms->erspan_ver = 1; - if (data[IFLA_GRE_ERSPAN_VER]) - parms->erspan_ver = nla_get_u8(data[IFLA_GRE_ERSPAN_VER]); - - if (parms->erspan_ver == 1) { - if (data[IFLA_GRE_ERSPAN_INDEX]) - parms->index = nla_get_u32(data[IFLA_GRE_ERSPAN_INDEX]); - } else if (parms->erspan_ver == 2) { - if (data[IFLA_GRE_ERSPAN_DIR]) - parms->dir = nla_get_u8(data[IFLA_GRE_ERSPAN_DIR]); - if (data[IFLA_GRE_ERSPAN_HWID]) - parms->hwid = nla_get_u16(data[IFLA_GRE_ERSPAN_HWID]); - } } static int ip6gre_tap_init(struct net_device *dev) @@ -2102,12 +2105,31 @@ static int ip6gre_fill_info(struct sk_buff *skb, const struct net_device *dev) { struct ip6_tnl *t = netdev_priv(dev); struct __ip6_tnl_parm *p = &t->parms; + __be16 o_flags = p->o_flags; + + if (p->erspan_ver == 1 || p->erspan_ver == 2) { + if (!p->collect_md) + o_flags |= TUNNEL_KEY; + + if (nla_put_u8(skb, IFLA_GRE_ERSPAN_VER, p->erspan_ver)) + goto nla_put_failure; + + if (p->erspan_ver == 1) { + if (nla_put_u32(skb, IFLA_GRE_ERSPAN_INDEX, p->index)) + goto nla_put_failure; + } else { + if (nla_put_u8(skb, IFLA_GRE_ERSPAN_DIR, p->dir)) + goto nla_put_failure; + if (nla_put_u16(skb, IFLA_GRE_ERSPAN_HWID, p->hwid)) + goto nla_put_failure; + } + } if (nla_put_u32(skb, IFLA_GRE_LINK, p->link) || nla_put_be16(skb, IFLA_GRE_IFLAGS, gre_tnl_flags_to_gre_flags(p->i_flags)) || nla_put_be16(skb, IFLA_GRE_OFLAGS, - gre_tnl_flags_to_gre_flags(p->o_flags)) || + gre_tnl_flags_to_gre_flags(o_flags)) || nla_put_be32(skb, IFLA_GRE_IKEY, p->i_key) || nla_put_be32(skb, IFLA_GRE_OKEY, p->o_key) || nla_put_in6_addr(skb, IFLA_GRE_LOCAL, &p->laddr) || @@ -2116,8 +2138,7 @@ static int ip6gre_fill_info(struct sk_buff *skb, const struct net_device *dev) nla_put_u8(skb, IFLA_GRE_ENCAP_LIMIT, p->encap_limit) || nla_put_be32(skb, IFLA_GRE_FLOWINFO, p->flowinfo) || nla_put_u32(skb, IFLA_GRE_FLAGS, p->flags) || - nla_put_u32(skb, IFLA_GRE_FWMARK, p->fwmark) || - nla_put_u32(skb, IFLA_GRE_ERSPAN_INDEX, p->index)) + nla_put_u32(skb, IFLA_GRE_FWMARK, p->fwmark)) goto nla_put_failure; if (nla_put_u16(skb, IFLA_GRE_ENCAP_TYPE, @@ -2135,19 +2156,6 @@ static int ip6gre_fill_info(struct sk_buff *skb, const struct net_device *dev) goto nla_put_failure; } - if (nla_put_u8(skb, IFLA_GRE_ERSPAN_VER, p->erspan_ver)) - goto nla_put_failure; - - if (p->erspan_ver == 1) { - if (nla_put_u32(skb, IFLA_GRE_ERSPAN_INDEX, p->index)) - goto nla_put_failure; - } else if (p->erspan_ver == 2) { - if (nla_put_u8(skb, IFLA_GRE_ERSPAN_DIR, p->dir)) - goto nla_put_failure; - if (nla_put_u16(skb, IFLA_GRE_ERSPAN_HWID, p->hwid)) - goto nla_put_failure; - } - return 0; nla_put_failure: @@ -2202,6 +2210,7 @@ static int ip6erspan_newlink(struct net *src_net, struct net_device *dev, int err; ip6gre_netlink_parms(data, &nt->parms); + ip6erspan_set_version(data, &nt->parms); ign = net_generic(net, ip6gre_net_id); if (nt->parms.collect_md) { @@ -2247,6 +2256,7 @@ static int ip6erspan_changelink(struct net_device *dev, struct nlattr *tb[], if (IS_ERR(t)) return PTR_ERR(t); + ip6erspan_set_version(data, &p); ip6gre_tunnel_unlink_md(ign, t); ip6gre_tunnel_unlink(ign, t); ip6erspan_tnl_change(t, &p, !tb[IFLA_MTU]); diff --git a/net/ipv6/ip6_offload.c b/net/ipv6/ip6_offload.c index 5c045691c302..345882d9c061 100644 --- a/net/ipv6/ip6_offload.c +++ b/net/ipv6/ip6_offload.c @@ -383,9 +383,36 @@ static struct packet_offload ipv6_packet_offload __read_mostly = { }, }; +static struct sk_buff *sit_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_IPXIP4)) + return ERR_PTR(-EINVAL); + + return ipv6_gso_segment(skb, features); +} + +static struct sk_buff *ip4ip6_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_IPXIP6)) + return ERR_PTR(-EINVAL); + + return inet_gso_segment(skb, features); +} + +static struct sk_buff *ip6ip6_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_IPXIP6)) + return ERR_PTR(-EINVAL); + + return ipv6_gso_segment(skb, features); +} + static const struct net_offload sit_offload = { .callbacks = { - .gso_segment = ipv6_gso_segment, + .gso_segment = sit_gso_segment, .gro_receive = sit_ip6ip6_gro_receive, .gro_complete = sit_gro_complete, }, @@ -393,7 +420,7 @@ static const struct net_offload sit_offload = { static const struct net_offload ip4ip6_offload = { .callbacks = { - .gso_segment = inet_gso_segment, + .gso_segment = ip4ip6_gso_segment, .gro_receive = ip4ip6_gro_receive, .gro_complete = ip4ip6_gro_complete, }, @@ -401,7 +428,7 @@ static const struct net_offload ip4ip6_offload = { static const struct net_offload ip6ip6_offload = { .callbacks = { - .gso_segment = ipv6_gso_segment, + .gso_segment = ip6ip6_gso_segment, .gro_receive = sit_ip6ip6_gro_receive, .gro_complete = ip6ip6_gro_complete, }, diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c index 5f9fa0302b5a..edbd12067170 100644 --- a/net/ipv6/ip6_output.c +++ b/net/ipv6/ip6_output.c @@ -300,6 +300,12 @@ static int ip6_call_ra_chain(struct sk_buff *skb, int sel) if (sk && ra->sel == sel && (!sk->sk_bound_dev_if || sk->sk_bound_dev_if == skb->dev->ifindex)) { + struct ipv6_pinfo *np = inet6_sk(sk); + + if (np && np->rtalert_isolate && + !net_eq(sock_net(sk), dev_net(skb->dev))) { + continue; + } if (last) { struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); if (skb2) diff --git a/net/ipv6/ip6_udp_tunnel.c b/net/ipv6/ip6_udp_tunnel.c index ad1a9ccd4b44..25430c991cea 100644 --- a/net/ipv6/ip6_udp_tunnel.c +++ b/net/ipv6/ip6_udp_tunnel.c @@ -32,18 +32,9 @@ int udp_sock_create6(struct net *net, struct udp_port_cfg *cfg, goto error; } if (cfg->bind_ifindex) { - struct net_device *dev; - - dev = dev_get_by_index(net, cfg->bind_ifindex); - if (!dev) { - err = -ENODEV; - goto error; - } - - err = kernel_setsockopt(sock, SOL_SOCKET, SO_BINDTODEVICE, - dev->name, strlen(dev->name) + 1); - dev_put(dev); - + err = kernel_setsockopt(sock, SOL_SOCKET, SO_BINDTOIFINDEX, + (void *)&cfg->bind_ifindex, + sizeof(cfg->bind_ifindex)); if (err < 0) goto error; } diff --git a/net/ipv6/ip6mr.c b/net/ipv6/ip6mr.c index 30337b38274b..e4dd57976737 100644 --- a/net/ipv6/ip6mr.c +++ b/net/ipv6/ip6mr.c @@ -97,7 +97,7 @@ static void mr6_netlink_event(struct mr_table *mrt, struct mfc6_cache *mfc, static void mrt6msg_netlink_event(struct mr_table *mrt, struct sk_buff *pkt); static int ip6mr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb); -static void mroute_clean_tables(struct mr_table *mrt, bool all); +static void mroute_clean_tables(struct mr_table *mrt, int flags); static void ipmr_expire_process(struct timer_list *t); #ifdef CONFIG_IPV6_MROUTE_MULTIPLE_TABLES @@ -393,7 +393,8 @@ static struct mr_table *ip6mr_new_table(struct net *net, u32 id) static void ip6mr_free_table(struct mr_table *mrt) { del_timer_sync(&mrt->ipmr_expire_timer); - mroute_clean_tables(mrt, true); + mroute_clean_tables(mrt, MRT6_FLUSH_MIFS | MRT6_FLUSH_MIFS_STATIC | + MRT6_FLUSH_MFC | MRT6_FLUSH_MFC_STATIC); rhltable_destroy(&mrt->mfc_hash); kfree(mrt); } @@ -1496,43 +1497,51 @@ static int ip6mr_mfc_add(struct net *net, struct mr_table *mrt, * Close the multicast socket, and clear the vif tables etc */ -static void mroute_clean_tables(struct mr_table *mrt, bool all) +static void mroute_clean_tables(struct mr_table *mrt, int flags) { struct mr_mfc *c, *tmp; LIST_HEAD(list); int i; /* Shut down all active vif entries */ - for (i = 0; i < mrt->maxvif; i++) { - if (!all && (mrt->vif_table[i].flags & VIFF_STATIC)) - continue; - mif6_delete(mrt, i, 0, &list); + if (flags & (MRT6_FLUSH_MIFS | MRT6_FLUSH_MIFS_STATIC)) { + for (i = 0; i < mrt->maxvif; i++) { + if (((mrt->vif_table[i].flags & VIFF_STATIC) && + !(flags & MRT6_FLUSH_MIFS_STATIC)) || + (!(mrt->vif_table[i].flags & VIFF_STATIC) && !(flags & MRT6_FLUSH_MIFS))) + continue; + mif6_delete(mrt, i, 0, &list); + } + unregister_netdevice_many(&list); } - unregister_netdevice_many(&list); /* Wipe the cache */ - list_for_each_entry_safe(c, tmp, &mrt->mfc_cache_list, list) { - if (!all && (c->mfc_flags & MFC_STATIC)) - continue; - rhltable_remove(&mrt->mfc_hash, &c->mnode, ip6mr_rht_params); - list_del_rcu(&c->list); - mr6_netlink_event(mrt, (struct mfc6_cache *)c, RTM_DELROUTE); - mr_cache_put(c); - } - - if (atomic_read(&mrt->cache_resolve_queue_len) != 0) { - spin_lock_bh(&mfc_unres_lock); - list_for_each_entry_safe(c, tmp, &mrt->mfc_unres_queue, list) { - list_del(&c->list); + if (flags & (MRT6_FLUSH_MFC | MRT6_FLUSH_MFC_STATIC)) { + list_for_each_entry_safe(c, tmp, &mrt->mfc_cache_list, list) { + if (((c->mfc_flags & MFC_STATIC) && !(flags & MRT6_FLUSH_MFC_STATIC)) || + (!(c->mfc_flags & MFC_STATIC) && !(flags & MRT6_FLUSH_MFC))) + continue; + rhltable_remove(&mrt->mfc_hash, &c->mnode, ip6mr_rht_params); + list_del_rcu(&c->list); call_ip6mr_mfc_entry_notifiers(read_pnet(&mrt->net), FIB_EVENT_ENTRY_DEL, - (struct mfc6_cache *)c, - mrt->id); - mr6_netlink_event(mrt, (struct mfc6_cache *)c, - RTM_DELROUTE); - ip6mr_destroy_unres(mrt, (struct mfc6_cache *)c); + (struct mfc6_cache *)c, mrt->id); + mr6_netlink_event(mrt, (struct mfc6_cache *)c, RTM_DELROUTE); + mr_cache_put(c); + } + } + + if (flags & MRT6_FLUSH_MFC) { + if (atomic_read(&mrt->cache_resolve_queue_len) != 0) { + spin_lock_bh(&mfc_unres_lock); + list_for_each_entry_safe(c, tmp, &mrt->mfc_unres_queue, list) { + list_del(&c->list); + mr6_netlink_event(mrt, (struct mfc6_cache *)c, + RTM_DELROUTE); + ip6mr_destroy_unres(mrt, (struct mfc6_cache *)c); + } + spin_unlock_bh(&mfc_unres_lock); } - spin_unlock_bh(&mfc_unres_lock); } } @@ -1588,7 +1597,7 @@ int ip6mr_sk_done(struct sock *sk) NETCONFA_IFINDEX_ALL, net->ipv6.devconf_all); - mroute_clean_tables(mrt, false); + mroute_clean_tables(mrt, MRT6_FLUSH_MIFS | MRT6_FLUSH_MFC); err = 0; break; } @@ -1704,6 +1713,20 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, uns rtnl_unlock(); return ret; + case MRT6_FLUSH: + { + int flags; + + if (optlen != sizeof(flags)) + return -EINVAL; + if (get_user(flags, (int __user *)optval)) + return -EFAULT; + rtnl_lock(); + mroute_clean_tables(mrt, flags); + rtnl_unlock(); + return 0; + } + /* * Control PIM assert (to activate pim will activate assert) */ @@ -1965,10 +1988,10 @@ int ip6mr_compat_ioctl(struct sock *sk, unsigned int cmd, void __user *arg) static inline int ip6mr_forward2_finish(struct net *net, struct sock *sk, struct sk_buff *skb) { - __IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), - IPSTATS_MIB_OUTFORWDATAGRAMS); - __IP6_ADD_STATS(net, ip6_dst_idev(skb_dst(skb)), - IPSTATS_MIB_OUTOCTETS, skb->len); + IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), + IPSTATS_MIB_OUTFORWDATAGRAMS); + IP6_ADD_STATS(net, ip6_dst_idev(skb_dst(skb)), + IPSTATS_MIB_OUTOCTETS, skb->len); return dst_output(net, sk, skb); } diff --git a/net/ipv6/ipv6_sockglue.c b/net/ipv6/ipv6_sockglue.c index 973e215c3114..40f21fef25ff 100644 --- a/net/ipv6/ipv6_sockglue.c +++ b/net/ipv6/ipv6_sockglue.c @@ -787,6 +787,12 @@ done: goto e_inval; retv = ip6_ra_control(sk, val); break; + case IPV6_ROUTER_ALERT_ISOLATE: + if (optlen < sizeof(int)) + goto e_inval; + np->rtalert_isolate = valbool; + retv = 0; + break; case IPV6_MTU_DISCOVER: if (optlen < sizeof(int)) goto e_inval; @@ -1358,6 +1364,10 @@ static int do_ipv6_getsockopt(struct sock *sk, int level, int optname, val = np->rxopt.bits.recvfragsize; break; + case IPV6_ROUTER_ALERT_ISOLATE: + val = np->rtalert_isolate; + break; + default: return -ENOPROTOOPT; } diff --git a/net/ipv6/mcast.c b/net/ipv6/mcast.c index 21f6deb2aec9..42f3f5cd349f 100644 --- a/net/ipv6/mcast.c +++ b/net/ipv6/mcast.c @@ -940,6 +940,7 @@ int ipv6_dev_mc_inc(struct net_device *dev, const struct in6_addr *addr) { return __ipv6_dev_mc_inc(dev, addr, MCAST_EXCLUDE); } +EXPORT_SYMBOL(ipv6_dev_mc_inc); /* * device multicast group del @@ -987,6 +988,7 @@ int ipv6_dev_mc_dec(struct net_device *dev, const struct in6_addr *addr) return err; } +EXPORT_SYMBOL(ipv6_dev_mc_dec); /* * check if the interface/address pair is valid diff --git a/net/ipv6/mcast_snoop.c b/net/ipv6/mcast_snoop.c index 9405b04eecc6..dddd75d1be0e 100644 --- a/net/ipv6/mcast_snoop.c +++ b/net/ipv6/mcast_snoop.c @@ -41,6 +41,8 @@ static int ipv6_mc_check_ip6hdr(struct sk_buff *skb) if (skb->len < len || len <= offset) return -EINVAL; + skb_set_transport_header(skb, offset); + return 0; } @@ -77,27 +79,27 @@ static int ipv6_mc_check_mld_reportv2(struct sk_buff *skb) len += sizeof(struct mld2_report); - return pskb_may_pull(skb, len) ? 0 : -EINVAL; + return ipv6_mc_may_pull(skb, len) ? 0 : -EINVAL; } static int ipv6_mc_check_mld_query(struct sk_buff *skb) { + unsigned int transport_len = ipv6_transport_len(skb); struct mld_msg *mld; - unsigned int len = skb_transport_offset(skb); + unsigned int len; /* RFC2710+RFC3810 (MLDv1+MLDv2) require link-local source addresses */ if (!(ipv6_addr_type(&ipv6_hdr(skb)->saddr) & IPV6_ADDR_LINKLOCAL)) return -EINVAL; - len += sizeof(struct mld_msg); - if (skb->len < len) - return -EINVAL; - /* MLDv1? */ - if (skb->len != len) { + if (transport_len != sizeof(struct mld_msg)) { /* or MLDv2? */ - len += sizeof(struct mld2_query) - sizeof(struct mld_msg); - if (skb->len < len || !pskb_may_pull(skb, len)) + if (transport_len < sizeof(struct mld2_query)) + return -EINVAL; + + len = skb_transport_offset(skb) + sizeof(struct mld2_query); + if (!ipv6_mc_may_pull(skb, len)) return -EINVAL; } @@ -115,12 +117,17 @@ static int ipv6_mc_check_mld_query(struct sk_buff *skb) static int ipv6_mc_check_mld_msg(struct sk_buff *skb) { - struct mld_msg *mld = (struct mld_msg *)skb_transport_header(skb); + unsigned int len = skb_transport_offset(skb) + sizeof(struct mld_msg); + struct mld_msg *mld; + + if (!ipv6_mc_may_pull(skb, len)) + return -EINVAL; + + mld = (struct mld_msg *)skb_transport_header(skb); switch (mld->mld_type) { case ICMPV6_MGM_REDUCTION: case ICMPV6_MGM_REPORT: - /* fall through */ return 0; case ICMPV6_MLD2_REPORT: return ipv6_mc_check_mld_reportv2(skb); @@ -136,49 +143,30 @@ static inline __sum16 ipv6_mc_validate_checksum(struct sk_buff *skb) return skb_checksum_validate(skb, IPPROTO_ICMPV6, ip6_compute_pseudo); } -static int __ipv6_mc_check_mld(struct sk_buff *skb, - struct sk_buff **skb_trimmed) - +int ipv6_mc_check_icmpv6(struct sk_buff *skb) { - struct sk_buff *skb_chk = NULL; - unsigned int transport_len; - unsigned int len = skb_transport_offset(skb) + sizeof(struct mld_msg); - int ret = -EINVAL; + unsigned int len = skb_transport_offset(skb) + sizeof(struct icmp6hdr); + unsigned int transport_len = ipv6_transport_len(skb); + struct sk_buff *skb_chk; - transport_len = ntohs(ipv6_hdr(skb)->payload_len); - transport_len -= skb_transport_offset(skb) - sizeof(struct ipv6hdr); + if (!ipv6_mc_may_pull(skb, len)) + return -EINVAL; skb_chk = skb_checksum_trimmed(skb, transport_len, ipv6_mc_validate_checksum); if (!skb_chk) - goto err; - - if (!pskb_may_pull(skb_chk, len)) - goto err; - - ret = ipv6_mc_check_mld_msg(skb_chk); - if (ret) - goto err; - - if (skb_trimmed) - *skb_trimmed = skb_chk; - /* free now unneeded clone */ - else if (skb_chk != skb) - kfree_skb(skb_chk); - - ret = 0; + return -EINVAL; -err: - if (ret && skb_chk && skb_chk != skb) + if (skb_chk != skb) kfree_skb(skb_chk); - return ret; + return 0; } +EXPORT_SYMBOL(ipv6_mc_check_icmpv6); /** * ipv6_mc_check_mld - checks whether this is a sane MLD packet * @skb: the skb to validate - * @skb_trimmed: to store an skb pointer trimmed to IPv6 packet tail (optional) * * Checks whether an IPv6 packet is a valid MLD packet. If so sets * skb transport header accordingly and returns zero. @@ -188,18 +176,10 @@ err: * -ENOMSG: IP header validation succeeded but it is not an MLD packet. * -ENOMEM: A memory allocation failure happened. * - * Optionally, an skb pointer might be provided via skb_trimmed (or set it - * to NULL): After parsing an MLD packet successfully it will point to - * an skb which has its tail aligned to the IP packet end. This might - * either be the originally provided skb or a trimmed, cloned version if - * the skb frame had data beyond the IP packet. A cloned skb allows us - * to leave the original skb and its full frame unchanged (which might be - * desirable for layer 2 frame jugglers). - * * Caller needs to set the skb network header and free any returned skb if it * differs from the provided skb. */ -int ipv6_mc_check_mld(struct sk_buff *skb, struct sk_buff **skb_trimmed) +int ipv6_mc_check_mld(struct sk_buff *skb) { int ret; @@ -211,6 +191,10 @@ int ipv6_mc_check_mld(struct sk_buff *skb, struct sk_buff **skb_trimmed) if (ret < 0) return ret; - return __ipv6_mc_check_mld(skb, skb_trimmed); + ret = ipv6_mc_check_icmpv6(skb); + if (ret < 0) + return ret; + + return ipv6_mc_check_mld_msg(skb); } EXPORT_SYMBOL(ipv6_mc_check_mld); diff --git a/net/ipv6/netfilter.c b/net/ipv6/netfilter.c index 8b075f0bc351..1240ccd57f39 100644 --- a/net/ipv6/netfilter.c +++ b/net/ipv6/netfilter.c @@ -23,9 +23,11 @@ int ip6_route_me_harder(struct net *net, struct sk_buff *skb) struct sock *sk = sk_to_full_sk(skb->sk); unsigned int hh_len; struct dst_entry *dst; + int strict = (ipv6_addr_type(&iph->daddr) & + (IPV6_ADDR_MULTICAST | IPV6_ADDR_LINKLOCAL)); struct flowi6 fl6 = { .flowi6_oif = sk && sk->sk_bound_dev_if ? sk->sk_bound_dev_if : - rt6_need_strict(&iph->daddr) ? skb_dst(skb)->dev->ifindex : 0, + strict ? skb_dst(skb)->dev->ifindex : 0, .flowi6_mark = skb->mark, .flowi6_uid = sock_net_uid(net, sk), .daddr = iph->daddr, @@ -84,8 +86,8 @@ static int nf_ip6_reroute(struct sk_buff *skb, return 0; } -static int nf_ip6_route(struct net *net, struct dst_entry **dst, - struct flowi *fl, bool strict) +int __nf_ip6_route(struct net *net, struct dst_entry **dst, + struct flowi *fl, bool strict) { static const struct ipv6_pinfo fake_pinfo; static const struct inet_sock fake_sk = { @@ -105,12 +107,17 @@ static int nf_ip6_route(struct net *net, struct dst_entry **dst, *dst = result; return err; } +EXPORT_SYMBOL_GPL(__nf_ip6_route); static const struct nf_ipv6_ops ipv6ops = { +#if IS_MODULE(CONFIG_IPV6) .chk_addr = ipv6_chk_addr, - .route_input = ip6_route_input, + .route_me_harder = ip6_route_me_harder, + .dev_get_saddr = ipv6_dev_get_saddr, + .route = __nf_ip6_route, +#endif + .route_input = ip6_route_input, .fragment = ip6_fragment, - .route = nf_ip6_route, .reroute = nf_ip6_reroute, }; diff --git a/net/ipv6/netfilter/Kconfig b/net/ipv6/netfilter/Kconfig index 339d0762b027..ddc99a1653aa 100644 --- a/net/ipv6/netfilter/Kconfig +++ b/net/ipv6/netfilter/Kconfig @@ -31,34 +31,6 @@ config NFT_CHAIN_ROUTE_IPV6 fields such as the source, destination, flowlabel, hop-limit and the packet mark. -if NF_NAT_IPV6 - -config NFT_CHAIN_NAT_IPV6 - tristate "IPv6 nf_tables nat chain support" - help - This option enables the "nat" chain for IPv6 in nf_tables. This - chain type is used to perform Network Address Translation (NAT) - packet transformations such as the source, destination address and - source and destination ports. - -config NFT_MASQ_IPV6 - tristate "IPv6 masquerade support for nf_tables" - depends on NFT_MASQ - select NF_NAT_MASQUERADE_IPV6 - help - This is the expression that provides IPv4 masquerading support for - nf_tables. - -config NFT_REDIR_IPV6 - tristate "IPv6 redirect support for nf_tables" - depends on NFT_REDIR - select NF_NAT_REDIRECT - help - This is the expression that provides IPv4 redirect support for - nf_tables. - -endif # NF_NAT_IPV6 - config NFT_REJECT_IPV6 select NF_REJECT_IPV6 default NFT_REJECT @@ -106,23 +78,6 @@ config NF_LOG_IPV6 default m if NETFILTER_ADVANCED=n select NF_LOG_COMMON -config NF_NAT_IPV6 - tristate "IPv6 NAT" - depends on NF_CONNTRACK - depends on NETFILTER_ADVANCED - select NF_NAT - help - The IPv6 NAT option allows masquerading, port forwarding and other - forms of full Network Address Port Translation. This can be - controlled by iptables or nft. - -if NF_NAT_IPV6 - -config NF_NAT_MASQUERADE_IPV6 - bool - -endif # NF_NAT_IPV6 - config IP6_NF_IPTABLES tristate "IP6 tables support (required for filtering)" depends on INET && IPV6 @@ -311,7 +266,6 @@ config IP6_NF_NAT depends on NF_CONNTRACK depends on NETFILTER_ADVANCED select NF_NAT - select NF_NAT_IPV6 select NETFILTER_XT_NAT help This enables the `nat' table in ip6tables. This allows masquerading, @@ -324,7 +278,7 @@ if IP6_NF_NAT config IP6_NF_TARGET_MASQUERADE tristate "MASQUERADE target support" - select NF_NAT_MASQUERADE_IPV6 + select NF_NAT_MASQUERADE help Masquerading is a special case of NAT: all outgoing connections are changed to seem to come from a particular interface's address, and diff --git a/net/ipv6/netfilter/Makefile b/net/ipv6/netfilter/Makefile index 9ea43d5256e0..3853c648ebaa 100644 --- a/net/ipv6/netfilter/Makefile +++ b/net/ipv6/netfilter/Makefile @@ -11,10 +11,6 @@ obj-$(CONFIG_IP6_NF_RAW) += ip6table_raw.o obj-$(CONFIG_IP6_NF_SECURITY) += ip6table_security.o obj-$(CONFIG_IP6_NF_NAT) += ip6table_nat.o -nf_nat_ipv6-y := nf_nat_l3proto_ipv6.o -nf_nat_ipv6-$(CONFIG_NF_NAT_MASQUERADE_IPV6) += nf_nat_masquerade_ipv6.o -obj-$(CONFIG_NF_NAT_IPV6) += nf_nat_ipv6.o - # defrag nf_defrag_ipv6-y := nf_defrag_ipv6_hooks.o nf_conntrack_reasm.o obj-$(CONFIG_NF_DEFRAG_IPV6) += nf_defrag_ipv6.o @@ -32,10 +28,7 @@ obj-$(CONFIG_NF_DUP_IPV6) += nf_dup_ipv6.o # nf_tables obj-$(CONFIG_NFT_CHAIN_ROUTE_IPV6) += nft_chain_route_ipv6.o -obj-$(CONFIG_NFT_CHAIN_NAT_IPV6) += nft_chain_nat_ipv6.o obj-$(CONFIG_NFT_REJECT_IPV6) += nft_reject_ipv6.o -obj-$(CONFIG_NFT_MASQ_IPV6) += nft_masq_ipv6.o -obj-$(CONFIG_NFT_REDIR_IPV6) += nft_redir_ipv6.o obj-$(CONFIG_NFT_DUP_IPV6) += nft_dup_ipv6.o obj-$(CONFIG_NFT_FIB_IPV6) += nft_fib_ipv6.o diff --git a/net/ipv6/netfilter/ip6table_nat.c b/net/ipv6/netfilter/ip6table_nat.c index 67ba70ab9f5c..3e1fab9d7503 100644 --- a/net/ipv6/netfilter/ip6table_nat.c +++ b/net/ipv6/netfilter/ip6table_nat.c @@ -17,8 +17,6 @@ #include <net/ipv6.h> #include <net/netfilter/nf_nat.h> -#include <net/netfilter/nf_nat_core.h> -#include <net/netfilter/nf_nat_l3proto.h> static int __net_init ip6table_nat_table_init(struct net *net); @@ -72,10 +70,10 @@ static int ip6t_nat_register_lookups(struct net *net) int i, ret; for (i = 0; i < ARRAY_SIZE(nf_nat_ipv6_ops); i++) { - ret = nf_nat_l3proto_ipv6_register_fn(net, &nf_nat_ipv6_ops[i]); + ret = nf_nat_ipv6_register_fn(net, &nf_nat_ipv6_ops[i]); if (ret) { while (i) - nf_nat_l3proto_ipv6_unregister_fn(net, &nf_nat_ipv6_ops[--i]); + nf_nat_ipv6_unregister_fn(net, &nf_nat_ipv6_ops[--i]); return ret; } @@ -89,7 +87,7 @@ static void ip6t_nat_unregister_lookups(struct net *net) int i; for (i = 0; i < ARRAY_SIZE(nf_nat_ipv6_ops); i++) - nf_nat_l3proto_ipv6_unregister_fn(net, &nf_nat_ipv6_ops[i]); + nf_nat_ipv6_unregister_fn(net, &nf_nat_ipv6_ops[i]); } static int __net_init ip6table_nat_table_init(struct net *net) diff --git a/net/ipv6/netfilter/nf_conntrack_reasm.c b/net/ipv6/netfilter/nf_conntrack_reasm.c index 181da2c40f9a..3de0e9b0a482 100644 --- a/net/ipv6/netfilter/nf_conntrack_reasm.c +++ b/net/ipv6/netfilter/nf_conntrack_reasm.c @@ -136,6 +136,9 @@ static void __net_exit nf_ct_frags6_sysctl_unregister(struct net *net) } #endif +static int nf_ct_frag6_reasm(struct frag_queue *fq, struct sk_buff *skb, + struct sk_buff *prev_tail, struct net_device *dev); + static inline u8 ip6_frag_ecn(const struct ipv6hdr *ipv6h) { return 1 << (ipv6_get_dsfield(ipv6h) & INET_ECN_MASK); @@ -177,9 +180,10 @@ static struct frag_queue *fq_find(struct net *net, __be32 id, u32 user, static int nf_ct_frag6_queue(struct frag_queue *fq, struct sk_buff *skb, const struct frag_hdr *fhdr, int nhoff) { - struct sk_buff *prev, *next; unsigned int payload_len; - int offset, end; + struct net_device *dev; + struct sk_buff *prev; + int offset, end, err; u8 ecn; if (fq->q.flags & INET_FRAG_COMPLETE) { @@ -254,55 +258,18 @@ static int nf_ct_frag6_queue(struct frag_queue *fq, struct sk_buff *skb, goto err; } - /* Find out which fragments are in front and at the back of us - * in the chain of fragments so far. We must know where to put - * this fragment, right? - */ - prev = fq->q.fragments_tail; - if (!prev || prev->ip_defrag_offset < offset) { - next = NULL; - goto found; - } - prev = NULL; - for (next = fq->q.fragments; next != NULL; next = next->next) { - if (next->ip_defrag_offset >= offset) - break; /* bingo! */ - prev = next; - } - -found: - /* RFC5722, Section 4: - * When reassembling an IPv6 datagram, if - * one or more its constituent fragments is determined to be an - * overlapping fragment, the entire datagram (and any constituent - * fragments, including those not yet received) MUST be silently - * discarded. - */ - - /* Check for overlap with preceding fragment. */ - if (prev && - (prev->ip_defrag_offset + prev->len) > offset) - goto discard_fq; - - /* Look for overlap with succeeding segment. */ - if (next && next->ip_defrag_offset < end) - goto discard_fq; - - /* Note : skb->ip_defrag_offset and skb->dev share the same location */ - if (skb->dev) - fq->iif = skb->dev->ifindex; + /* Note : skb->rbnode and skb->dev share the same location. */ + dev = skb->dev; /* Makes sure compiler wont do silly aliasing games */ barrier(); - skb->ip_defrag_offset = offset; - /* Insert this fragment in the chain of fragments. */ - skb->next = next; - if (!next) - fq->q.fragments_tail = skb; - if (prev) - prev->next = skb; - else - fq->q.fragments = skb; + prev = fq->q.fragments_tail; + err = inet_frag_queue_insert(&fq->q, skb, offset, end); + if (err) + goto insert_error; + + if (dev) + fq->iif = dev->ifindex; fq->q.stamp = skb->tstamp; fq->q.meat += skb->len; @@ -319,11 +286,25 @@ found: fq->q.flags |= INET_FRAG_FIRST_IN; } - return 0; + if (fq->q.flags == (INET_FRAG_FIRST_IN | INET_FRAG_LAST_IN) && + fq->q.meat == fq->q.len) { + unsigned long orefdst = skb->_skb_refdst; + + skb->_skb_refdst = 0UL; + err = nf_ct_frag6_reasm(fq, skb, prev, dev); + skb->_skb_refdst = orefdst; + return err; + } + + skb_dst_drop(skb); + return -EINPROGRESS; -discard_fq: +insert_error: + if (err == IPFRAG_DUP) + goto err; inet_frag_kill(&fq->q); err: + skb_dst_drop(skb); return -EINVAL; } @@ -333,147 +314,66 @@ err: * It is called with locked fq, and caller must check that * queue is eligible for reassembly i.e. it is not COMPLETE, * the last and the first frames arrived and all the bits are here. - * - * returns true if *prev skb has been transformed into the reassembled - * skb, false otherwise. */ -static bool -nf_ct_frag6_reasm(struct frag_queue *fq, struct sk_buff *prev, struct net_device *dev) +static int nf_ct_frag6_reasm(struct frag_queue *fq, struct sk_buff *skb, + struct sk_buff *prev_tail, struct net_device *dev) { - struct sk_buff *fp, *head = fq->q.fragments; - int payload_len, delta; + void *reasm_data; + int payload_len; u8 ecn; inet_frag_kill(&fq->q); - WARN_ON(head == NULL); - WARN_ON(head->ip_defrag_offset != 0); - ecn = ip_frag_ecn_table[fq->ecn]; if (unlikely(ecn == 0xff)) - return false; + goto err; + + reasm_data = inet_frag_reasm_prepare(&fq->q, skb, prev_tail); + if (!reasm_data) + goto err; - /* Unfragmented part is taken from the first segment. */ - payload_len = ((head->data - skb_network_header(head)) - + payload_len = ((skb->data - skb_network_header(skb)) - sizeof(struct ipv6hdr) + fq->q.len - sizeof(struct frag_hdr)); if (payload_len > IPV6_MAXPLEN) { net_dbg_ratelimited("nf_ct_frag6_reasm: payload len = %d\n", payload_len); - return false; - } - - delta = - head->truesize; - - /* Head of list must not be cloned. */ - if (skb_unclone(head, GFP_ATOMIC)) - return false; - - delta += head->truesize; - if (delta) - add_frag_mem_limit(fq->q.net, delta); - - /* If the first fragment is fragmented itself, we split - * it to two chunks: the first with data and paged part - * and the second, holding only fragments. */ - if (skb_has_frag_list(head)) { - struct sk_buff *clone; - int i, plen = 0; - - clone = alloc_skb(0, GFP_ATOMIC); - if (clone == NULL) - return false; - - clone->next = head->next; - head->next = clone; - skb_shinfo(clone)->frag_list = skb_shinfo(head)->frag_list; - skb_frag_list_init(head); - for (i = 0; i < skb_shinfo(head)->nr_frags; i++) - plen += skb_frag_size(&skb_shinfo(head)->frags[i]); - clone->len = clone->data_len = head->data_len - plen; - head->data_len -= clone->len; - head->len -= clone->len; - clone->csum = 0; - clone->ip_summed = head->ip_summed; - - add_frag_mem_limit(fq->q.net, clone->truesize); - } - - /* morph head into last received skb: prev. - * - * This allows callers of ipv6 conntrack defrag to continue - * to use the last skb(frag) passed into the reasm engine. - * The last skb frag 'silently' turns into the full reassembled skb. - * - * Since prev is also part of q->fragments we have to clone it first. - */ - if (head != prev) { - struct sk_buff *iter; - - fp = skb_clone(prev, GFP_ATOMIC); - if (!fp) - return false; - - fp->next = prev->next; - - iter = head; - while (iter) { - if (iter->next == prev) { - iter->next = fp; - break; - } - iter = iter->next; - } - - skb_morph(prev, head); - prev->next = head->next; - consume_skb(head); - head = prev; + goto err; } /* We have to remove fragment header from datagram and to relocate * header in order to calculate ICV correctly. */ - skb_network_header(head)[fq->nhoffset] = skb_transport_header(head)[0]; - memmove(head->head + sizeof(struct frag_hdr), head->head, - (head->data - head->head) - sizeof(struct frag_hdr)); - head->mac_header += sizeof(struct frag_hdr); - head->network_header += sizeof(struct frag_hdr); - - skb_shinfo(head)->frag_list = head->next; - skb_reset_transport_header(head); - skb_push(head, head->data - skb_network_header(head)); - - for (fp = head->next; fp; fp = fp->next) { - head->data_len += fp->len; - head->len += fp->len; - if (head->ip_summed != fp->ip_summed) - head->ip_summed = CHECKSUM_NONE; - else if (head->ip_summed == CHECKSUM_COMPLETE) - head->csum = csum_add(head->csum, fp->csum); - head->truesize += fp->truesize; - fp->sk = NULL; - } - sub_frag_mem_limit(fq->q.net, head->truesize); + skb_network_header(skb)[fq->nhoffset] = skb_transport_header(skb)[0]; + memmove(skb->head + sizeof(struct frag_hdr), skb->head, + (skb->data - skb->head) - sizeof(struct frag_hdr)); + skb->mac_header += sizeof(struct frag_hdr); + skb->network_header += sizeof(struct frag_hdr); + + skb_reset_transport_header(skb); + + inet_frag_reasm_finish(&fq->q, skb, reasm_data); - head->ignore_df = 1; - skb_mark_not_on_list(head); - head->dev = dev; - head->tstamp = fq->q.stamp; - ipv6_hdr(head)->payload_len = htons(payload_len); - ipv6_change_dsfield(ipv6_hdr(head), 0xff, ecn); - IP6CB(head)->frag_max_size = sizeof(struct ipv6hdr) + fq->q.max_size; + skb->ignore_df = 1; + skb->dev = dev; + ipv6_hdr(skb)->payload_len = htons(payload_len); + ipv6_change_dsfield(ipv6_hdr(skb), 0xff, ecn); + IP6CB(skb)->frag_max_size = sizeof(struct ipv6hdr) + fq->q.max_size; /* Yes, and fold redundant checksum back. 8) */ - if (head->ip_summed == CHECKSUM_COMPLETE) - head->csum = csum_partial(skb_network_header(head), - skb_network_header_len(head), - head->csum); + if (skb->ip_summed == CHECKSUM_COMPLETE) + skb->csum = csum_partial(skb_network_header(skb), + skb_network_header_len(skb), + skb->csum); - fq->q.fragments = NULL; fq->q.rb_fragments = RB_ROOT; fq->q.fragments_tail = NULL; + fq->q.last_run_head = NULL; - return true; + return 0; + +err: + inet_frag_kill(&fq->q); + return -EINVAL; } /* @@ -542,7 +442,6 @@ find_prev_fhdr(struct sk_buff *skb, u8 *prevhdrp, int *prevhoff, int *fhoff) int nf_ct_frag6_gather(struct net *net, struct sk_buff *skb, u32 user) { u16 savethdr = skb->transport_header; - struct net_device *dev = skb->dev; int fhoff, nhoff, ret; struct frag_hdr *fhdr; struct frag_queue *fq; @@ -565,10 +464,6 @@ int nf_ct_frag6_gather(struct net *net, struct sk_buff *skb, u32 user) hdr = ipv6_hdr(skb); fhdr = (struct frag_hdr *)skb_transport_header(skb); - if (skb->len - skb_network_offset(skb) < IPV6_MIN_MTU && - fhdr->frag_off & htons(IP6_MF)) - return -EINVAL; - skb_orphan(skb); fq = fq_find(net, fhdr->identification, user, hdr, skb->dev ? skb->dev->ifindex : 0); @@ -580,31 +475,17 @@ int nf_ct_frag6_gather(struct net *net, struct sk_buff *skb, u32 user) spin_lock_bh(&fq->q.lock); ret = nf_ct_frag6_queue(fq, skb, fhdr, nhoff); - if (ret < 0) { - if (ret == -EPROTO) { - skb->transport_header = savethdr; - ret = 0; - } - goto out_unlock; + if (ret == -EPROTO) { + skb->transport_header = savethdr; + ret = 0; } /* after queue has assumed skb ownership, only 0 or -EINPROGRESS * must be returned. */ - ret = -EINPROGRESS; - if (fq->q.flags == (INET_FRAG_FIRST_IN | INET_FRAG_LAST_IN) && - fq->q.meat == fq->q.len) { - unsigned long orefdst = skb->_skb_refdst; - - skb->_skb_refdst = 0UL; - if (nf_ct_frag6_reasm(fq, skb, dev)) - ret = 0; - skb->_skb_refdst = orefdst; - } else { - skb_dst_drop(skb); - } + if (ret) + ret = -EINPROGRESS; -out_unlock: spin_unlock_bh(&fq->q.lock); inet_frag_put(&fq->q); return ret; diff --git a/net/ipv6/netfilter/nf_nat_l3proto_ipv6.c b/net/ipv6/netfilter/nf_nat_l3proto_ipv6.c deleted file mode 100644 index 23022447eb49..000000000000 --- a/net/ipv6/netfilter/nf_nat_l3proto_ipv6.c +++ /dev/null @@ -1,411 +0,0 @@ -/* - * Copyright (c) 2011 Patrick McHardy <kaber@trash.net> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 2 as - * published by the Free Software Foundation. - * - * Development of IPv6 NAT funded by Astaro. - */ -#include <linux/types.h> -#include <linux/module.h> -#include <linux/skbuff.h> -#include <linux/ipv6.h> -#include <linux/netfilter.h> -#include <linux/netfilter_ipv6.h> -#include <net/secure_seq.h> -#include <net/checksum.h> -#include <net/ip6_checksum.h> -#include <net/ip6_route.h> -#include <net/ipv6.h> - -#include <net/netfilter/nf_conntrack_core.h> -#include <net/netfilter/nf_conntrack.h> -#include <net/netfilter/nf_nat_core.h> -#include <net/netfilter/nf_nat_l3proto.h> -#include <net/netfilter/nf_nat_l4proto.h> - -static const struct nf_nat_l3proto nf_nat_l3proto_ipv6; - -#ifdef CONFIG_XFRM -static void nf_nat_ipv6_decode_session(struct sk_buff *skb, - const struct nf_conn *ct, - enum ip_conntrack_dir dir, - unsigned long statusbit, - struct flowi *fl) -{ - const struct nf_conntrack_tuple *t = &ct->tuplehash[dir].tuple; - struct flowi6 *fl6 = &fl->u.ip6; - - if (ct->status & statusbit) { - fl6->daddr = t->dst.u3.in6; - if (t->dst.protonum == IPPROTO_TCP || - t->dst.protonum == IPPROTO_UDP || - t->dst.protonum == IPPROTO_UDPLITE || - t->dst.protonum == IPPROTO_DCCP || - t->dst.protonum == IPPROTO_SCTP) - fl6->fl6_dport = t->dst.u.all; - } - - statusbit ^= IPS_NAT_MASK; - - if (ct->status & statusbit) { - fl6->saddr = t->src.u3.in6; - if (t->dst.protonum == IPPROTO_TCP || - t->dst.protonum == IPPROTO_UDP || - t->dst.protonum == IPPROTO_UDPLITE || - t->dst.protonum == IPPROTO_DCCP || - t->dst.protonum == IPPROTO_SCTP) - fl6->fl6_sport = t->src.u.all; - } -} -#endif - -static bool nf_nat_ipv6_manip_pkt(struct sk_buff *skb, - unsigned int iphdroff, - const struct nf_conntrack_tuple *target, - enum nf_nat_manip_type maniptype) -{ - struct ipv6hdr *ipv6h; - __be16 frag_off; - int hdroff; - u8 nexthdr; - - if (!skb_make_writable(skb, iphdroff + sizeof(*ipv6h))) - return false; - - ipv6h = (void *)skb->data + iphdroff; - nexthdr = ipv6h->nexthdr; - hdroff = ipv6_skip_exthdr(skb, iphdroff + sizeof(*ipv6h), - &nexthdr, &frag_off); - if (hdroff < 0) - goto manip_addr; - - if ((frag_off & htons(~0x7)) == 0 && - !nf_nat_l4proto_manip_pkt(skb, &nf_nat_l3proto_ipv6, iphdroff, hdroff, - target, maniptype)) - return false; - - /* must reload, offset might have changed */ - ipv6h = (void *)skb->data + iphdroff; - -manip_addr: - if (maniptype == NF_NAT_MANIP_SRC) - ipv6h->saddr = target->src.u3.in6; - else - ipv6h->daddr = target->dst.u3.in6; - - return true; -} - -static void nf_nat_ipv6_csum_update(struct sk_buff *skb, - unsigned int iphdroff, __sum16 *check, - const struct nf_conntrack_tuple *t, - enum nf_nat_manip_type maniptype) -{ - const struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + iphdroff); - const struct in6_addr *oldip, *newip; - - if (maniptype == NF_NAT_MANIP_SRC) { - oldip = &ipv6h->saddr; - newip = &t->src.u3.in6; - } else { - oldip = &ipv6h->daddr; - newip = &t->dst.u3.in6; - } - inet_proto_csum_replace16(check, skb, oldip->s6_addr32, - newip->s6_addr32, true); -} - -static void nf_nat_ipv6_csum_recalc(struct sk_buff *skb, - u8 proto, void *data, __sum16 *check, - int datalen, int oldlen) -{ - if (skb->ip_summed != CHECKSUM_PARTIAL) { - const struct ipv6hdr *ipv6h = ipv6_hdr(skb); - - skb->ip_summed = CHECKSUM_PARTIAL; - skb->csum_start = skb_headroom(skb) + skb_network_offset(skb) + - (data - (void *)skb->data); - skb->csum_offset = (void *)check - data; - *check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr, - datalen, proto, 0); - } else - inet_proto_csum_replace2(check, skb, - htons(oldlen), htons(datalen), true); -} - -#if IS_ENABLED(CONFIG_NF_CT_NETLINK) -static int nf_nat_ipv6_nlattr_to_range(struct nlattr *tb[], - struct nf_nat_range2 *range) -{ - if (tb[CTA_NAT_V6_MINIP]) { - nla_memcpy(&range->min_addr.ip6, tb[CTA_NAT_V6_MINIP], - sizeof(struct in6_addr)); - range->flags |= NF_NAT_RANGE_MAP_IPS; - } - - if (tb[CTA_NAT_V6_MAXIP]) - nla_memcpy(&range->max_addr.ip6, tb[CTA_NAT_V6_MAXIP], - sizeof(struct in6_addr)); - else - range->max_addr = range->min_addr; - - return 0; -} -#endif - -static const struct nf_nat_l3proto nf_nat_l3proto_ipv6 = { - .l3proto = NFPROTO_IPV6, - .manip_pkt = nf_nat_ipv6_manip_pkt, - .csum_update = nf_nat_ipv6_csum_update, - .csum_recalc = nf_nat_ipv6_csum_recalc, -#if IS_ENABLED(CONFIG_NF_CT_NETLINK) - .nlattr_to_range = nf_nat_ipv6_nlattr_to_range, -#endif -#ifdef CONFIG_XFRM - .decode_session = nf_nat_ipv6_decode_session, -#endif -}; - -int nf_nat_icmpv6_reply_translation(struct sk_buff *skb, - struct nf_conn *ct, - enum ip_conntrack_info ctinfo, - unsigned int hooknum, - unsigned int hdrlen) -{ - struct { - struct icmp6hdr icmp6; - struct ipv6hdr ip6; - } *inside; - enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); - enum nf_nat_manip_type manip = HOOK2MANIP(hooknum); - struct nf_conntrack_tuple target; - unsigned long statusbit; - - WARN_ON(ctinfo != IP_CT_RELATED && ctinfo != IP_CT_RELATED_REPLY); - - if (!skb_make_writable(skb, hdrlen + sizeof(*inside))) - return 0; - if (nf_ip6_checksum(skb, hooknum, hdrlen, IPPROTO_ICMPV6)) - return 0; - - inside = (void *)skb->data + hdrlen; - if (inside->icmp6.icmp6_type == NDISC_REDIRECT) { - if ((ct->status & IPS_NAT_DONE_MASK) != IPS_NAT_DONE_MASK) - return 0; - if (ct->status & IPS_NAT_MASK) - return 0; - } - - if (manip == NF_NAT_MANIP_SRC) - statusbit = IPS_SRC_NAT; - else - statusbit = IPS_DST_NAT; - - /* Invert if this is reply direction */ - if (dir == IP_CT_DIR_REPLY) - statusbit ^= IPS_NAT_MASK; - - if (!(ct->status & statusbit)) - return 1; - - if (!nf_nat_ipv6_manip_pkt(skb, hdrlen + sizeof(inside->icmp6), - &ct->tuplehash[!dir].tuple, !manip)) - return 0; - - if (skb->ip_summed != CHECKSUM_PARTIAL) { - struct ipv6hdr *ipv6h = ipv6_hdr(skb); - inside = (void *)skb->data + hdrlen; - inside->icmp6.icmp6_cksum = 0; - inside->icmp6.icmp6_cksum = - csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr, - skb->len - hdrlen, IPPROTO_ICMPV6, - skb_checksum(skb, hdrlen, - skb->len - hdrlen, 0)); - } - - nf_ct_invert_tuplepr(&target, &ct->tuplehash[!dir].tuple); - if (!nf_nat_ipv6_manip_pkt(skb, 0, &target, manip)) - return 0; - - return 1; -} -EXPORT_SYMBOL_GPL(nf_nat_icmpv6_reply_translation); - -static unsigned int -nf_nat_ipv6_fn(void *priv, struct sk_buff *skb, - const struct nf_hook_state *state) -{ - struct nf_conn *ct; - enum ip_conntrack_info ctinfo; - __be16 frag_off; - int hdrlen; - u8 nexthdr; - - ct = nf_ct_get(skb, &ctinfo); - /* Can't track? It's not due to stress, or conntrack would - * have dropped it. Hence it's the user's responsibilty to - * packet filter it out, or implement conntrack/NAT for that - * protocol. 8) --RR - */ - if (!ct) - return NF_ACCEPT; - - if (ctinfo == IP_CT_RELATED || ctinfo == IP_CT_RELATED_REPLY) { - nexthdr = ipv6_hdr(skb)->nexthdr; - hdrlen = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), - &nexthdr, &frag_off); - - if (hdrlen >= 0 && nexthdr == IPPROTO_ICMPV6) { - if (!nf_nat_icmpv6_reply_translation(skb, ct, ctinfo, - state->hook, - hdrlen)) - return NF_DROP; - else - return NF_ACCEPT; - } - } - - return nf_nat_inet_fn(priv, skb, state); -} - -static unsigned int -nf_nat_ipv6_in(void *priv, struct sk_buff *skb, - const struct nf_hook_state *state) -{ - unsigned int ret; - struct in6_addr daddr = ipv6_hdr(skb)->daddr; - - ret = nf_nat_ipv6_fn(priv, skb, state); - if (ret != NF_DROP && ret != NF_STOLEN && - ipv6_addr_cmp(&daddr, &ipv6_hdr(skb)->daddr)) - skb_dst_drop(skb); - - return ret; -} - -static unsigned int -nf_nat_ipv6_out(void *priv, struct sk_buff *skb, - const struct nf_hook_state *state) -{ -#ifdef CONFIG_XFRM - const struct nf_conn *ct; - enum ip_conntrack_info ctinfo; - int err; -#endif - unsigned int ret; - - ret = nf_nat_ipv6_fn(priv, skb, state); -#ifdef CONFIG_XFRM - if (ret != NF_DROP && ret != NF_STOLEN && - !(IP6CB(skb)->flags & IP6SKB_XFRM_TRANSFORMED) && - (ct = nf_ct_get(skb, &ctinfo)) != NULL) { - enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); - - if (!nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.src.u3, - &ct->tuplehash[!dir].tuple.dst.u3) || - (ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMPV6 && - ct->tuplehash[dir].tuple.src.u.all != - ct->tuplehash[!dir].tuple.dst.u.all)) { - err = nf_xfrm_me_harder(state->net, skb, AF_INET6); - if (err < 0) - ret = NF_DROP_ERR(err); - } - } -#endif - return ret; -} - -static unsigned int -nf_nat_ipv6_local_fn(void *priv, struct sk_buff *skb, - const struct nf_hook_state *state) -{ - const struct nf_conn *ct; - enum ip_conntrack_info ctinfo; - unsigned int ret; - int err; - - ret = nf_nat_ipv6_fn(priv, skb, state); - if (ret != NF_DROP && ret != NF_STOLEN && - (ct = nf_ct_get(skb, &ctinfo)) != NULL) { - enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); - - if (!nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.dst.u3, - &ct->tuplehash[!dir].tuple.src.u3)) { - err = ip6_route_me_harder(state->net, skb); - if (err < 0) - ret = NF_DROP_ERR(err); - } -#ifdef CONFIG_XFRM - else if (!(IP6CB(skb)->flags & IP6SKB_XFRM_TRANSFORMED) && - ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMPV6 && - ct->tuplehash[dir].tuple.dst.u.all != - ct->tuplehash[!dir].tuple.src.u.all) { - err = nf_xfrm_me_harder(state->net, skb, AF_INET6); - if (err < 0) - ret = NF_DROP_ERR(err); - } -#endif - } - return ret; -} - -static const struct nf_hook_ops nf_nat_ipv6_ops[] = { - /* Before packet filtering, change destination */ - { - .hook = nf_nat_ipv6_in, - .pf = NFPROTO_IPV6, - .hooknum = NF_INET_PRE_ROUTING, - .priority = NF_IP6_PRI_NAT_DST, - }, - /* After packet filtering, change source */ - { - .hook = nf_nat_ipv6_out, - .pf = NFPROTO_IPV6, - .hooknum = NF_INET_POST_ROUTING, - .priority = NF_IP6_PRI_NAT_SRC, - }, - /* Before packet filtering, change destination */ - { - .hook = nf_nat_ipv6_local_fn, - .pf = NFPROTO_IPV6, - .hooknum = NF_INET_LOCAL_OUT, - .priority = NF_IP6_PRI_NAT_DST, - }, - /* After packet filtering, change source */ - { - .hook = nf_nat_ipv6_fn, - .pf = NFPROTO_IPV6, - .hooknum = NF_INET_LOCAL_IN, - .priority = NF_IP6_PRI_NAT_SRC, - }, -}; - -int nf_nat_l3proto_ipv6_register_fn(struct net *net, const struct nf_hook_ops *ops) -{ - return nf_nat_register_fn(net, ops, nf_nat_ipv6_ops, ARRAY_SIZE(nf_nat_ipv6_ops)); -} -EXPORT_SYMBOL_GPL(nf_nat_l3proto_ipv6_register_fn); - -void nf_nat_l3proto_ipv6_unregister_fn(struct net *net, const struct nf_hook_ops *ops) -{ - nf_nat_unregister_fn(net, ops, ARRAY_SIZE(nf_nat_ipv6_ops)); -} -EXPORT_SYMBOL_GPL(nf_nat_l3proto_ipv6_unregister_fn); - -static int __init nf_nat_l3proto_ipv6_init(void) -{ - return nf_nat_l3proto_register(&nf_nat_l3proto_ipv6); -} - -static void __exit nf_nat_l3proto_ipv6_exit(void) -{ - nf_nat_l3proto_unregister(&nf_nat_l3proto_ipv6); -} - -MODULE_LICENSE("GPL"); -MODULE_ALIAS("nf-nat-" __stringify(AF_INET6)); - -module_init(nf_nat_l3proto_ipv6_init); -module_exit(nf_nat_l3proto_ipv6_exit); diff --git a/net/ipv6/netfilter/nf_nat_masquerade_ipv6.c b/net/ipv6/netfilter/nf_nat_masquerade_ipv6.c deleted file mode 100644 index 0ad0da5a2600..000000000000 --- a/net/ipv6/netfilter/nf_nat_masquerade_ipv6.c +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Copyright (c) 2011 Patrick McHardy <kaber@trash.net> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 2 as - * published by the Free Software Foundation. - * - * Based on Rusty Russell's IPv6 MASQUERADE target. Development of IPv6 - * NAT funded by Astaro. - */ - -#include <linux/kernel.h> -#include <linux/atomic.h> -#include <linux/netdevice.h> -#include <linux/ipv6.h> -#include <linux/netfilter.h> -#include <linux/netfilter_ipv6.h> -#include <net/netfilter/nf_nat.h> -#include <net/addrconf.h> -#include <net/ipv6.h> -#include <net/netfilter/ipv6/nf_nat_masquerade.h> - -#define MAX_WORK_COUNT 16 - -static atomic_t v6_worker_count; - -unsigned int -nf_nat_masquerade_ipv6(struct sk_buff *skb, const struct nf_nat_range2 *range, - const struct net_device *out) -{ - enum ip_conntrack_info ctinfo; - struct nf_conn_nat *nat; - struct in6_addr src; - struct nf_conn *ct; - struct nf_nat_range2 newrange; - - ct = nf_ct_get(skb, &ctinfo); - WARN_ON(!(ct && (ctinfo == IP_CT_NEW || ctinfo == IP_CT_RELATED || - ctinfo == IP_CT_RELATED_REPLY))); - - if (ipv6_dev_get_saddr(nf_ct_net(ct), out, - &ipv6_hdr(skb)->daddr, 0, &src) < 0) - return NF_DROP; - - nat = nf_ct_nat_ext_add(ct); - if (nat) - nat->masq_index = out->ifindex; - - newrange.flags = range->flags | NF_NAT_RANGE_MAP_IPS; - newrange.min_addr.in6 = src; - newrange.max_addr.in6 = src; - newrange.min_proto = range->min_proto; - newrange.max_proto = range->max_proto; - - return nf_nat_setup_info(ct, &newrange, NF_NAT_MANIP_SRC); -} -EXPORT_SYMBOL_GPL(nf_nat_masquerade_ipv6); - -static int device_cmp(struct nf_conn *ct, void *ifindex) -{ - const struct nf_conn_nat *nat = nfct_nat(ct); - - if (!nat) - return 0; - if (nf_ct_l3num(ct) != NFPROTO_IPV6) - return 0; - return nat->masq_index == (int)(long)ifindex; -} - -static int masq_device_event(struct notifier_block *this, - unsigned long event, void *ptr) -{ - const struct net_device *dev = netdev_notifier_info_to_dev(ptr); - struct net *net = dev_net(dev); - - if (event == NETDEV_DOWN) - nf_ct_iterate_cleanup_net(net, device_cmp, - (void *)(long)dev->ifindex, 0, 0); - - return NOTIFY_DONE; -} - -static struct notifier_block masq_dev_notifier = { - .notifier_call = masq_device_event, -}; - -struct masq_dev_work { - struct work_struct work; - struct net *net; - struct in6_addr addr; - int ifindex; -}; - -static int inet_cmp(struct nf_conn *ct, void *work) -{ - struct masq_dev_work *w = (struct masq_dev_work *)work; - struct nf_conntrack_tuple *tuple; - - if (!device_cmp(ct, (void *)(long)w->ifindex)) - return 0; - - tuple = &ct->tuplehash[IP_CT_DIR_REPLY].tuple; - - return ipv6_addr_equal(&w->addr, &tuple->dst.u3.in6); -} - -static void iterate_cleanup_work(struct work_struct *work) -{ - struct masq_dev_work *w; - - w = container_of(work, struct masq_dev_work, work); - - nf_ct_iterate_cleanup_net(w->net, inet_cmp, (void *)w, 0, 0); - - put_net(w->net); - kfree(w); - atomic_dec(&v6_worker_count); - module_put(THIS_MODULE); -} - -/* ipv6 inet notifier is an atomic notifier, i.e. we cannot - * schedule. - * - * Unfortunately, nf_ct_iterate_cleanup_net can run for a long - * time if there are lots of conntracks and the system - * handles high softirq load, so it frequently calls cond_resched - * while iterating the conntrack table. - * - * So we defer nf_ct_iterate_cleanup_net walk to the system workqueue. - * - * As we can have 'a lot' of inet_events (depending on amount - * of ipv6 addresses being deleted), we also need to add an upper - * limit to the number of queued work items. - */ -static int masq_inet6_event(struct notifier_block *this, - unsigned long event, void *ptr) -{ - struct inet6_ifaddr *ifa = ptr; - const struct net_device *dev; - struct masq_dev_work *w; - struct net *net; - - if (event != NETDEV_DOWN || - atomic_read(&v6_worker_count) >= MAX_WORK_COUNT) - return NOTIFY_DONE; - - dev = ifa->idev->dev; - net = maybe_get_net(dev_net(dev)); - if (!net) - return NOTIFY_DONE; - - if (!try_module_get(THIS_MODULE)) - goto err_module; - - w = kmalloc(sizeof(*w), GFP_ATOMIC); - if (w) { - atomic_inc(&v6_worker_count); - - INIT_WORK(&w->work, iterate_cleanup_work); - w->ifindex = dev->ifindex; - w->net = net; - w->addr = ifa->addr; - schedule_work(&w->work); - - return NOTIFY_DONE; - } - - module_put(THIS_MODULE); - err_module: - put_net(net); - return NOTIFY_DONE; -} - -static struct notifier_block masq_inet6_notifier = { - .notifier_call = masq_inet6_event, -}; - -static int masq_refcnt; -static DEFINE_MUTEX(masq_mutex); - -int nf_nat_masquerade_ipv6_register_notifier(void) -{ - int ret = 0; - - mutex_lock(&masq_mutex); - /* check if the notifier is already set */ - if (++masq_refcnt > 1) - goto out_unlock; - - ret = register_netdevice_notifier(&masq_dev_notifier); - if (ret) - goto err_dec; - - ret = register_inet6addr_notifier(&masq_inet6_notifier); - if (ret) - goto err_unregister; - - mutex_unlock(&masq_mutex); - return ret; - -err_unregister: - unregister_netdevice_notifier(&masq_dev_notifier); -err_dec: - masq_refcnt--; -out_unlock: - mutex_unlock(&masq_mutex); - return ret; -} -EXPORT_SYMBOL_GPL(nf_nat_masquerade_ipv6_register_notifier); - -void nf_nat_masquerade_ipv6_unregister_notifier(void) -{ - mutex_lock(&masq_mutex); - /* check if the notifier still has clients */ - if (--masq_refcnt > 0) - goto out_unlock; - - unregister_inet6addr_notifier(&masq_inet6_notifier); - unregister_netdevice_notifier(&masq_dev_notifier); -out_unlock: - mutex_unlock(&masq_mutex); -} -EXPORT_SYMBOL_GPL(nf_nat_masquerade_ipv6_unregister_notifier); diff --git a/net/ipv6/netfilter/nf_reject_ipv6.c b/net/ipv6/netfilter/nf_reject_ipv6.c index b9c8a763c863..02e9228641e0 100644 --- a/net/ipv6/netfilter/nf_reject_ipv6.c +++ b/net/ipv6/netfilter/nf_reject_ipv6.c @@ -233,6 +233,9 @@ static bool reject6_csum_ok(struct sk_buff *skb, int hook) if (thoff < 0 || thoff >= skb->len || (fo & htons(~0x7)) != 0) return false; + if (!nf_reject_verify_csum(proto)) + return true; + return nf_ip6_checksum(skb, hook, thoff, proto) == 0; } diff --git a/net/ipv6/netfilter/nft_chain_nat_ipv6.c b/net/ipv6/netfilter/nft_chain_nat_ipv6.c deleted file mode 100644 index 8a081ad7d5db..000000000000 --- a/net/ipv6/netfilter/nft_chain_nat_ipv6.c +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright (c) 2011 Patrick McHardy <kaber@trash.net> - * Copyright (c) 2012 Intel Corporation - * - * This program is free software; you can redistribute it and/or modify it - * under the terms and conditions of the GNU General Public License, - * version 2, as published by the Free Software Foundation. - * - */ - -#include <linux/module.h> -#include <linux/init.h> -#include <linux/list.h> -#include <linux/skbuff.h> -#include <linux/ip.h> -#include <linux/netfilter.h> -#include <linux/netfilter_ipv6.h> -#include <linux/netfilter/nf_tables.h> -#include <net/netfilter/nf_conntrack.h> -#include <net/netfilter/nf_nat.h> -#include <net/netfilter/nf_nat_core.h> -#include <net/netfilter/nf_tables.h> -#include <net/netfilter/nf_tables_ipv6.h> -#include <net/netfilter/nf_nat_l3proto.h> -#include <net/ipv6.h> - -static unsigned int nft_nat_do_chain(void *priv, - struct sk_buff *skb, - const struct nf_hook_state *state) -{ - struct nft_pktinfo pkt; - - nft_set_pktinfo(&pkt, skb, state); - nft_set_pktinfo_ipv6(&pkt, skb); - - return nft_do_chain(&pkt, priv); -} - -static int nft_nat_ipv6_reg(struct net *net, const struct nf_hook_ops *ops) -{ - return nf_nat_l3proto_ipv6_register_fn(net, ops); -} - -static void nft_nat_ipv6_unreg(struct net *net, const struct nf_hook_ops *ops) -{ - nf_nat_l3proto_ipv6_unregister_fn(net, ops); -} - -static const struct nft_chain_type nft_chain_nat_ipv6 = { - .name = "nat", - .type = NFT_CHAIN_T_NAT, - .family = NFPROTO_IPV6, - .owner = THIS_MODULE, - .hook_mask = (1 << NF_INET_PRE_ROUTING) | - (1 << NF_INET_POST_ROUTING) | - (1 << NF_INET_LOCAL_OUT) | - (1 << NF_INET_LOCAL_IN), - .hooks = { - [NF_INET_PRE_ROUTING] = nft_nat_do_chain, - [NF_INET_POST_ROUTING] = nft_nat_do_chain, - [NF_INET_LOCAL_OUT] = nft_nat_do_chain, - [NF_INET_LOCAL_IN] = nft_nat_do_chain, - }, - .ops_register = nft_nat_ipv6_reg, - .ops_unregister = nft_nat_ipv6_unreg, -}; - -static int __init nft_chain_nat_ipv6_init(void) -{ - nft_register_chain_type(&nft_chain_nat_ipv6); - - return 0; -} - -static void __exit nft_chain_nat_ipv6_exit(void) -{ - nft_unregister_chain_type(&nft_chain_nat_ipv6); -} - -module_init(nft_chain_nat_ipv6_init); -module_exit(nft_chain_nat_ipv6_exit); - -MODULE_LICENSE("GPL"); -MODULE_AUTHOR("Tomasz Bursztyka <tomasz.bursztyka@linux.intel.com>"); -MODULE_ALIAS_NFT_CHAIN(AF_INET6, "nat"); diff --git a/net/ipv6/netfilter/nft_fib_ipv6.c b/net/ipv6/netfilter/nft_fib_ipv6.c index 36be3cf0adef..73cdc0bc63f7 100644 --- a/net/ipv6/netfilter/nft_fib_ipv6.c +++ b/net/ipv6/netfilter/nft_fib_ipv6.c @@ -59,7 +59,6 @@ static u32 __nft_fib6_eval_type(const struct nft_fib *priv, struct ipv6hdr *iph) { const struct net_device *dev = NULL; - const struct nf_ipv6_ops *v6ops; int route_err, addrtype; struct rt6_info *rt; struct flowi6 fl6 = { @@ -68,10 +67,6 @@ static u32 __nft_fib6_eval_type(const struct nft_fib *priv, }; u32 ret = 0; - v6ops = nf_get_ipv6_ops(); - if (!v6ops) - return RTN_UNREACHABLE; - if (priv->flags & NFTA_FIB_F_IIF) dev = nft_in(pkt); else if (priv->flags & NFTA_FIB_F_OIF) @@ -79,10 +74,10 @@ static u32 __nft_fib6_eval_type(const struct nft_fib *priv, nft_fib6_flowi_init(&fl6, priv, pkt, dev, iph); - if (dev && v6ops->chk_addr(nft_net(pkt), &fl6.daddr, dev, true)) + if (dev && nf_ipv6_chk_addr(nft_net(pkt), &fl6.daddr, dev, true)) ret = RTN_LOCAL; - route_err = v6ops->route(nft_net(pkt), (struct dst_entry **)&rt, + route_err = nf_ip6_route(nft_net(pkt), (struct dst_entry **)&rt, flowi6_to_flowi(&fl6), false); if (route_err) goto err; diff --git a/net/ipv6/netfilter/nft_masq_ipv6.c b/net/ipv6/netfilter/nft_masq_ipv6.c deleted file mode 100644 index e06c82e9dfcd..000000000000 --- a/net/ipv6/netfilter/nft_masq_ipv6.c +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (c) 2014 Arturo Borrero Gonzalez <arturo@debian.org> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 2 as - * published by the Free Software Foundation. - */ - -#include <linux/kernel.h> -#include <linux/init.h> -#include <linux/module.h> -#include <linux/netlink.h> -#include <linux/netfilter.h> -#include <linux/netfilter/nf_tables.h> -#include <net/netfilter/nf_tables.h> -#include <net/netfilter/nf_nat.h> -#include <net/netfilter/nft_masq.h> -#include <net/netfilter/ipv6/nf_nat_masquerade.h> - -static void nft_masq_ipv6_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) -{ - struct nft_masq *priv = nft_expr_priv(expr); - struct nf_nat_range2 range; - - memset(&range, 0, sizeof(range)); - range.flags = priv->flags; - if (priv->sreg_proto_min) { - range.min_proto.all = (__force __be16)nft_reg_load16( - ®s->data[priv->sreg_proto_min]); - range.max_proto.all = (__force __be16)nft_reg_load16( - ®s->data[priv->sreg_proto_max]); - } - regs->verdict.code = nf_nat_masquerade_ipv6(pkt->skb, &range, - nft_out(pkt)); -} - -static void -nft_masq_ipv6_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) -{ - nf_ct_netns_put(ctx->net, NFPROTO_IPV6); -} - -static struct nft_expr_type nft_masq_ipv6_type; -static const struct nft_expr_ops nft_masq_ipv6_ops = { - .type = &nft_masq_ipv6_type, - .size = NFT_EXPR_SIZE(sizeof(struct nft_masq)), - .eval = nft_masq_ipv6_eval, - .init = nft_masq_init, - .destroy = nft_masq_ipv6_destroy, - .dump = nft_masq_dump, - .validate = nft_masq_validate, -}; - -static struct nft_expr_type nft_masq_ipv6_type __read_mostly = { - .family = NFPROTO_IPV6, - .name = "masq", - .ops = &nft_masq_ipv6_ops, - .policy = nft_masq_policy, - .maxattr = NFTA_MASQ_MAX, - .owner = THIS_MODULE, -}; - -static int __init nft_masq_ipv6_module_init(void) -{ - int ret; - - ret = nft_register_expr(&nft_masq_ipv6_type); - if (ret < 0) - return ret; - - ret = nf_nat_masquerade_ipv6_register_notifier(); - if (ret) - nft_unregister_expr(&nft_masq_ipv6_type); - - return ret; -} - -static void __exit nft_masq_ipv6_module_exit(void) -{ - nft_unregister_expr(&nft_masq_ipv6_type); - nf_nat_masquerade_ipv6_unregister_notifier(); -} - -module_init(nft_masq_ipv6_module_init); -module_exit(nft_masq_ipv6_module_exit); - -MODULE_LICENSE("GPL"); -MODULE_AUTHOR("Arturo Borrero Gonzalez <arturo@debian.org>"); -MODULE_ALIAS_NFT_AF_EXPR(AF_INET6, "masq"); diff --git a/net/ipv6/netfilter/nft_redir_ipv6.c b/net/ipv6/netfilter/nft_redir_ipv6.c deleted file mode 100644 index 74269865acc8..000000000000 --- a/net/ipv6/netfilter/nft_redir_ipv6.c +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright (c) 2014 Arturo Borrero Gonzalez <arturo@debian.org> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 2 as - * published by the Free Software Foundation. - */ - -#include <linux/kernel.h> -#include <linux/init.h> -#include <linux/module.h> -#include <linux/netlink.h> -#include <linux/netfilter.h> -#include <linux/netfilter/nf_tables.h> -#include <net/netfilter/nf_tables.h> -#include <net/netfilter/nf_nat.h> -#include <net/netfilter/nft_redir.h> -#include <net/netfilter/nf_nat_redirect.h> - -static void nft_redir_ipv6_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) -{ - struct nft_redir *priv = nft_expr_priv(expr); - struct nf_nat_range2 range; - - memset(&range, 0, sizeof(range)); - if (priv->sreg_proto_min) { - range.min_proto.all = (__force __be16)nft_reg_load16( - ®s->data[priv->sreg_proto_min]); - range.max_proto.all = (__force __be16)nft_reg_load16( - ®s->data[priv->sreg_proto_max]); - range.flags |= NF_NAT_RANGE_PROTO_SPECIFIED; - } - - range.flags |= priv->flags; - - regs->verdict.code = - nf_nat_redirect_ipv6(pkt->skb, &range, nft_hook(pkt)); -} - -static void -nft_redir_ipv6_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) -{ - nf_ct_netns_put(ctx->net, NFPROTO_IPV6); -} - -static struct nft_expr_type nft_redir_ipv6_type; -static const struct nft_expr_ops nft_redir_ipv6_ops = { - .type = &nft_redir_ipv6_type, - .size = NFT_EXPR_SIZE(sizeof(struct nft_redir)), - .eval = nft_redir_ipv6_eval, - .init = nft_redir_init, - .destroy = nft_redir_ipv6_destroy, - .dump = nft_redir_dump, - .validate = nft_redir_validate, -}; - -static struct nft_expr_type nft_redir_ipv6_type __read_mostly = { - .family = NFPROTO_IPV6, - .name = "redir", - .ops = &nft_redir_ipv6_ops, - .policy = nft_redir_policy, - .maxattr = NFTA_REDIR_MAX, - .owner = THIS_MODULE, -}; - -static int __init nft_redir_ipv6_module_init(void) -{ - return nft_register_expr(&nft_redir_ipv6_type); -} - -static void __exit nft_redir_ipv6_module_exit(void) -{ - nft_unregister_expr(&nft_redir_ipv6_type); -} - -module_init(nft_redir_ipv6_module_init); -module_exit(nft_redir_ipv6_module_exit); - -MODULE_LICENSE("GPL"); -MODULE_AUTHOR("Arturo Borrero Gonzalez <arturo@debian.org>"); -MODULE_ALIAS_NFT_AF_EXPR(AF_INET6, "redir"); diff --git a/net/ipv6/reassembly.c b/net/ipv6/reassembly.c index 36a3d8dc61f5..1a832f5e190b 100644 --- a/net/ipv6/reassembly.c +++ b/net/ipv6/reassembly.c @@ -69,8 +69,8 @@ static u8 ip6_frag_ecn(const struct ipv6hdr *ipv6h) static struct inet_frags ip6_frags; -static int ip6_frag_reasm(struct frag_queue *fq, struct sk_buff *prev, - struct net_device *dev); +static int ip6_frag_reasm(struct frag_queue *fq, struct sk_buff *skb, + struct sk_buff *prev_tail, struct net_device *dev); static void ip6_frag_expire(struct timer_list *t) { @@ -111,21 +111,26 @@ static int ip6_frag_queue(struct frag_queue *fq, struct sk_buff *skb, struct frag_hdr *fhdr, int nhoff, u32 *prob_offset) { - struct sk_buff *prev, *next; - struct net_device *dev; - int offset, end, fragsize; struct net *net = dev_net(skb_dst(skb)->dev); + int offset, end, fragsize; + struct sk_buff *prev_tail; + struct net_device *dev; + int err = -ENOENT; u8 ecn; if (fq->q.flags & INET_FRAG_COMPLETE) goto err; + err = -EINVAL; offset = ntohs(fhdr->frag_off) & ~0x7; end = offset + (ntohs(ipv6_hdr(skb)->payload_len) - ((u8 *)(fhdr + 1) - (u8 *)(ipv6_hdr(skb) + 1))); if ((unsigned int)end > IPV6_MAXPLEN) { *prob_offset = (u8 *)&fhdr->frag_off - skb_network_header(skb); + /* note that if prob_offset is set, the skb is freed elsewhere, + * we do not free it here. + */ return -1; } @@ -170,62 +175,27 @@ static int ip6_frag_queue(struct frag_queue *fq, struct sk_buff *skb, if (end == offset) goto discard_fq; + err = -ENOMEM; /* Point into the IP datagram 'data' part. */ if (!pskb_pull(skb, (u8 *) (fhdr + 1) - skb->data)) goto discard_fq; - if (pskb_trim_rcsum(skb, end - offset)) + err = pskb_trim_rcsum(skb, end - offset); + if (err) goto discard_fq; - /* Find out which fragments are in front and at the back of us - * in the chain of fragments so far. We must know where to put - * this fragment, right? - */ - prev = fq->q.fragments_tail; - if (!prev || prev->ip_defrag_offset < offset) { - next = NULL; - goto found; - } - prev = NULL; - for (next = fq->q.fragments; next != NULL; next = next->next) { - if (next->ip_defrag_offset >= offset) - break; /* bingo! */ - prev = next; - } - -found: - /* RFC5722, Section 4, amended by Errata ID : 3089 - * When reassembling an IPv6 datagram, if - * one or more its constituent fragments is determined to be an - * overlapping fragment, the entire datagram (and any constituent - * fragments) MUST be silently discarded. - */ - - /* Check for overlap with preceding fragment. */ - if (prev && - (prev->ip_defrag_offset + prev->len) > offset) - goto discard_fq; - - /* Look for overlap with succeeding segment. */ - if (next && next->ip_defrag_offset < end) - goto discard_fq; - - /* Note : skb->ip_defrag_offset and skb->sk share the same location */ + /* Note : skb->rbnode and skb->dev share the same location. */ dev = skb->dev; - if (dev) - fq->iif = dev->ifindex; /* Makes sure compiler wont do silly aliasing games */ barrier(); - skb->ip_defrag_offset = offset; - /* Insert this fragment in the chain of fragments. */ - skb->next = next; - if (!next) - fq->q.fragments_tail = skb; - if (prev) - prev->next = skb; - else - fq->q.fragments = skb; + prev_tail = fq->q.fragments_tail; + err = inet_frag_queue_insert(&fq->q, skb, offset, end); + if (err) + goto insert_error; + + if (dev) + fq->iif = dev->ifindex; fq->q.stamp = skb->tstamp; fq->q.meat += skb->len; @@ -246,44 +216,48 @@ found: if (fq->q.flags == (INET_FRAG_FIRST_IN | INET_FRAG_LAST_IN) && fq->q.meat == fq->q.len) { - int res; unsigned long orefdst = skb->_skb_refdst; skb->_skb_refdst = 0UL; - res = ip6_frag_reasm(fq, prev, dev); + err = ip6_frag_reasm(fq, skb, prev_tail, dev); skb->_skb_refdst = orefdst; - return res; + return err; } skb_dst_drop(skb); - return -1; + return -EINPROGRESS; +insert_error: + if (err == IPFRAG_DUP) { + kfree_skb(skb); + return -EINVAL; + } + err = -EINVAL; + __IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), + IPSTATS_MIB_REASM_OVERLAPS); discard_fq: inet_frag_kill(&fq->q); -err: __IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), IPSTATS_MIB_REASMFAILS); +err: kfree_skb(skb); - return -1; + return err; } /* * Check if this packet is complete. - * Returns NULL on failure by any reason, and pointer - * to current nexthdr field in reassembled frame. * * It is called with locked fq, and caller must check that * queue is eligible for reassembly i.e. it is not COMPLETE, * the last and the first frames arrived and all the bits are here. */ -static int ip6_frag_reasm(struct frag_queue *fq, struct sk_buff *prev, - struct net_device *dev) +static int ip6_frag_reasm(struct frag_queue *fq, struct sk_buff *skb, + struct sk_buff *prev_tail, struct net_device *dev) { struct net *net = container_of(fq->q.net, struct net, ipv6.frags); - struct sk_buff *fp, *head = fq->q.fragments; - int payload_len, delta; unsigned int nhoff; - int sum_truesize; + void *reasm_data; + int payload_len; u8 ecn; inet_frag_kill(&fq->q); @@ -292,128 +266,47 @@ static int ip6_frag_reasm(struct frag_queue *fq, struct sk_buff *prev, if (unlikely(ecn == 0xff)) goto out_fail; - /* Make the one we just received the head. */ - if (prev) { - head = prev->next; - fp = skb_clone(head, GFP_ATOMIC); - - if (!fp) - goto out_oom; - - fp->next = head->next; - if (!fp->next) - fq->q.fragments_tail = fp; - prev->next = fp; - - skb_morph(head, fq->q.fragments); - head->next = fq->q.fragments->next; - - consume_skb(fq->q.fragments); - fq->q.fragments = head; - } - - WARN_ON(head == NULL); - WARN_ON(head->ip_defrag_offset != 0); + reasm_data = inet_frag_reasm_prepare(&fq->q, skb, prev_tail); + if (!reasm_data) + goto out_oom; - /* Unfragmented part is taken from the first segment. */ - payload_len = ((head->data - skb_network_header(head)) - + payload_len = ((skb->data - skb_network_header(skb)) - sizeof(struct ipv6hdr) + fq->q.len - sizeof(struct frag_hdr)); if (payload_len > IPV6_MAXPLEN) goto out_oversize; - delta = - head->truesize; - - /* Head of list must not be cloned. */ - if (skb_unclone(head, GFP_ATOMIC)) - goto out_oom; - - delta += head->truesize; - if (delta) - add_frag_mem_limit(fq->q.net, delta); - - /* If the first fragment is fragmented itself, we split - * it to two chunks: the first with data and paged part - * and the second, holding only fragments. */ - if (skb_has_frag_list(head)) { - struct sk_buff *clone; - int i, plen = 0; - - clone = alloc_skb(0, GFP_ATOMIC); - if (!clone) - goto out_oom; - clone->next = head->next; - head->next = clone; - skb_shinfo(clone)->frag_list = skb_shinfo(head)->frag_list; - skb_frag_list_init(head); - for (i = 0; i < skb_shinfo(head)->nr_frags; i++) - plen += skb_frag_size(&skb_shinfo(head)->frags[i]); - clone->len = clone->data_len = head->data_len - plen; - head->data_len -= clone->len; - head->len -= clone->len; - clone->csum = 0; - clone->ip_summed = head->ip_summed; - add_frag_mem_limit(fq->q.net, clone->truesize); - } - /* We have to remove fragment header from datagram and to relocate * header in order to calculate ICV correctly. */ nhoff = fq->nhoffset; - skb_network_header(head)[nhoff] = skb_transport_header(head)[0]; - memmove(head->head + sizeof(struct frag_hdr), head->head, - (head->data - head->head) - sizeof(struct frag_hdr)); - if (skb_mac_header_was_set(head)) - head->mac_header += sizeof(struct frag_hdr); - head->network_header += sizeof(struct frag_hdr); - - skb_reset_transport_header(head); - skb_push(head, head->data - skb_network_header(head)); - - sum_truesize = head->truesize; - for (fp = head->next; fp;) { - bool headstolen; - int delta; - struct sk_buff *next = fp->next; - - sum_truesize += fp->truesize; - if (head->ip_summed != fp->ip_summed) - head->ip_summed = CHECKSUM_NONE; - else if (head->ip_summed == CHECKSUM_COMPLETE) - head->csum = csum_add(head->csum, fp->csum); - - if (skb_try_coalesce(head, fp, &headstolen, &delta)) { - kfree_skb_partial(fp, headstolen); - } else { - fp->sk = NULL; - if (!skb_shinfo(head)->frag_list) - skb_shinfo(head)->frag_list = fp; - head->data_len += fp->len; - head->len += fp->len; - head->truesize += fp->truesize; - } - fp = next; - } - sub_frag_mem_limit(fq->q.net, sum_truesize); + skb_network_header(skb)[nhoff] = skb_transport_header(skb)[0]; + memmove(skb->head + sizeof(struct frag_hdr), skb->head, + (skb->data - skb->head) - sizeof(struct frag_hdr)); + if (skb_mac_header_was_set(skb)) + skb->mac_header += sizeof(struct frag_hdr); + skb->network_header += sizeof(struct frag_hdr); + + skb_reset_transport_header(skb); + + inet_frag_reasm_finish(&fq->q, skb, reasm_data); - skb_mark_not_on_list(head); - head->dev = dev; - head->tstamp = fq->q.stamp; - ipv6_hdr(head)->payload_len = htons(payload_len); - ipv6_change_dsfield(ipv6_hdr(head), 0xff, ecn); - IP6CB(head)->nhoff = nhoff; - IP6CB(head)->flags |= IP6SKB_FRAGMENTED; - IP6CB(head)->frag_max_size = fq->q.max_size; + skb->dev = dev; + ipv6_hdr(skb)->payload_len = htons(payload_len); + ipv6_change_dsfield(ipv6_hdr(skb), 0xff, ecn); + IP6CB(skb)->nhoff = nhoff; + IP6CB(skb)->flags |= IP6SKB_FRAGMENTED; + IP6CB(skb)->frag_max_size = fq->q.max_size; /* Yes, and fold redundant checksum back. 8) */ - skb_postpush_rcsum(head, skb_network_header(head), - skb_network_header_len(head)); + skb_postpush_rcsum(skb, skb_network_header(skb), + skb_network_header_len(skb)); rcu_read_lock(); __IP6_INC_STATS(net, __in6_dev_get(dev), IPSTATS_MIB_REASMOKS); rcu_read_unlock(); - fq->q.fragments = NULL; fq->q.rb_fragments = RB_ROOT; fq->q.fragments_tail = NULL; + fq->q.last_run_head = NULL; return 1; out_oversize: @@ -464,10 +357,6 @@ static int ipv6_frag_rcv(struct sk_buff *skb) return 1; } - if (skb->len - skb_network_offset(skb) < IPV6_MIN_MTU && - fhdr->frag_off & htons(IP6_MF)) - goto fail_hdr; - iif = skb->dev ? skb->dev->ifindex : 0; fq = fq_find(net, fhdr->identification, hdr, iif); if (fq) { @@ -485,6 +374,7 @@ static int ipv6_frag_rcv(struct sk_buff *skb) if (prob_offset) { __IP6_INC_STATS(net, __in6_dev_get_safely(skb->dev), IPSTATS_MIB_INHDRERRORS); + /* icmpv6_param_prob() calls kfree_skb(skb) */ icmpv6_param_prob(skb, ICMPV6_HDR_FIELD, prob_offset); } return ret; diff --git a/net/ipv6/route.c b/net/ipv6/route.c index 964491cf3672..4ef4bbdb49d4 100644 --- a/net/ipv6/route.c +++ b/net/ipv6/route.c @@ -1274,18 +1274,29 @@ static DEFINE_SPINLOCK(rt6_exception_lock); static void rt6_remove_exception(struct rt6_exception_bucket *bucket, struct rt6_exception *rt6_ex) { + struct fib6_info *from; struct net *net; if (!bucket || !rt6_ex) return; net = dev_net(rt6_ex->rt6i->dst.dev); + net->ipv6.rt6_stats->fib_rt_cache--; + + /* purge completely the exception to allow releasing the held resources: + * some [sk] cache may keep the dst around for unlimited time + */ + from = rcu_dereference_protected(rt6_ex->rt6i->from, + lockdep_is_held(&rt6_exception_lock)); + rcu_assign_pointer(rt6_ex->rt6i->from, NULL); + fib6_info_release(from); + dst_dev_put(&rt6_ex->rt6i->dst); + hlist_del_rcu(&rt6_ex->hlist); dst_release(&rt6_ex->rt6i->dst); kfree_rcu(rt6_ex, rcu); WARN_ON_ONCE(!bucket->depth); bucket->depth--; - net->ipv6.rt6_stats->fib_rt_cache--; } /* Remove oldest rt6_ex in bucket and free the memory @@ -1599,15 +1610,15 @@ static int rt6_remove_exception_rt(struct rt6_info *rt) static void rt6_update_exception_stamp_rt(struct rt6_info *rt) { struct rt6_exception_bucket *bucket; - struct fib6_info *from = rt->from; struct in6_addr *src_key = NULL; struct rt6_exception *rt6_ex; - - if (!from || - !(rt->rt6i_flags & RTF_CACHE)) - return; + struct fib6_info *from; rcu_read_lock(); + from = rcu_dereference(rt->from); + if (!from || !(rt->rt6i_flags & RTF_CACHE)) + goto unlock; + bucket = rcu_dereference(from->rt6i_exception_bucket); #ifdef CONFIG_IPV6_SUBTREES @@ -1626,6 +1637,7 @@ static void rt6_update_exception_stamp_rt(struct rt6_info *rt) if (rt6_ex) rt6_ex->stamp = jiffies; +unlock: rcu_read_unlock(); } @@ -2277,14 +2289,8 @@ static void rt6_do_update_pmtu(struct rt6_info *rt, u32 mtu) static bool rt6_cache_allowed_for_pmtu(const struct rt6_info *rt) { - bool from_set; - - rcu_read_lock(); - from_set = !!rcu_dereference(rt->from); - rcu_read_unlock(); - return !(rt->rt6i_flags & RTF_CACHE) && - (rt->rt6i_flags & RTF_PCPU || from_set); + (rt->rt6i_flags & RTF_PCPU || rcu_access_pointer(rt->from)); } static void __ip6_rt_update_pmtu(struct dst_entry *dst, const struct sock *sk, @@ -2742,20 +2748,24 @@ static int ip6_route_check_nh_onlink(struct net *net, u32 tbid = l3mdev_fib_table(dev) ? : RT_TABLE_MAIN; const struct in6_addr *gw_addr = &cfg->fc_gateway; u32 flags = RTF_LOCAL | RTF_ANYCAST | RTF_REJECT; + struct fib6_info *from; struct rt6_info *grt; int err; err = 0; grt = ip6_nh_lookup_table(net, cfg, gw_addr, tbid, 0); if (grt) { + rcu_read_lock(); + from = rcu_dereference(grt->from); if (!grt->dst.error && /* ignore match if it is the default route */ - grt->from && !ipv6_addr_any(&grt->from->fib6_dst.addr) && + from && !ipv6_addr_any(&from->fib6_dst.addr) && (grt->rt6i_flags & flags || dev != grt->dst.dev)) { NL_SET_ERR_MSG(extack, "Nexthop has invalid gateway or device mismatch"); err = -EINVAL; } + rcu_read_unlock(); ip6_rt_put(grt); } @@ -4166,6 +4176,10 @@ static int rtm_to_fib6_config(struct sk_buff *skb, struct nlmsghdr *nlh, cfg->fc_gateway = nla_get_in6_addr(tb[RTA_GATEWAY]); cfg->fc_flags |= RTF_GATEWAY; } + if (tb[RTA_VIA]) { + NL_SET_ERR_MSG(extack, "IPv6 does not support RTA_VIA attribute"); + goto errout; + } if (tb[RTA_DST]) { int plen = (rtm->rtm_dst_len + 7) >> 3; @@ -4649,7 +4663,7 @@ static int rt6_fill_node(struct net *net, struct sk_buff *skb, table = rt->fib6_table->tb6_id; else table = RT6_TABLE_UNSPEC; - rtm->rtm_table = table; + rtm->rtm_table = table < 256 ? table : RT_TABLE_COMPAT; if (nla_put_u32(skb, RTA_TABLE, table)) goto nla_put_failure; @@ -4812,6 +4826,73 @@ int rt6_dump_route(struct fib6_info *rt, void *p_arg) arg->cb->nlh->nlmsg_seq, flags); } +static int inet6_rtm_valid_getroute_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct rtmsg *rtm; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + NL_SET_ERR_MSG_MOD(extack, + "Invalid header for get route request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv6_policy, extack); + + rtm = nlmsg_data(nlh); + if ((rtm->rtm_src_len && rtm->rtm_src_len != 128) || + (rtm->rtm_dst_len && rtm->rtm_dst_len != 128) || + rtm->rtm_table || rtm->rtm_protocol || rtm->rtm_scope || + rtm->rtm_type) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for get route request"); + return -EINVAL; + } + if (rtm->rtm_flags & ~RTM_F_FIB_MATCH) { + NL_SET_ERR_MSG_MOD(extack, + "Invalid flags for get route request"); + return -EINVAL; + } + + err = nlmsg_parse_strict(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv6_policy, extack); + if (err) + return err; + + if ((tb[RTA_SRC] && !rtm->rtm_src_len) || + (tb[RTA_DST] && !rtm->rtm_dst_len)) { + NL_SET_ERR_MSG_MOD(extack, "rtm_src_len and rtm_dst_len must be 128 for IPv6"); + return -EINVAL; + } + + for (i = 0; i <= RTA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case RTA_SRC: + case RTA_DST: + case RTA_IIF: + case RTA_OIF: + case RTA_MARK: + case RTA_UID: + case RTA_SPORT: + case RTA_DPORT: + case RTA_IP_PROTO: + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in get route request"); + return -EINVAL; + } + } + + return 0; +} + static int inet6_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { @@ -4826,8 +4907,7 @@ static int inet6_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct flowi6 fl6 = {}; bool fibmatch; - err = nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, rtm_ipv6_policy, - extack); + err = inet6_rtm_valid_getroute_req(in_skb, nlh, tb, extack); if (err < 0) goto errout; @@ -4873,7 +4953,8 @@ static int inet6_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, if (tb[RTA_IP_PROTO]) { err = rtm_getroute_parse_ip_proto(tb[RTA_IP_PROTO], - &fl6.flowi6_proto, extack); + &fl6.flowi6_proto, AF_INET6, + extack); if (err) goto errout; } diff --git a/net/ipv6/seg6.c b/net/ipv6/seg6.c index 8d0ba757a46c..9b2f272ca164 100644 --- a/net/ipv6/seg6.c +++ b/net/ipv6/seg6.c @@ -221,9 +221,7 @@ static int seg6_genl_get_tunsrc(struct sk_buff *skb, struct genl_info *info) rcu_read_unlock(); genlmsg_end(msg, hdr); - genlmsg_reply(msg, info); - - return 0; + return genlmsg_reply(msg, info); nla_put_failure: rcu_read_unlock(); diff --git a/net/ipv6/seg6_iptunnel.c b/net/ipv6/seg6_iptunnel.c index 8181ee7e1e27..ee5403cbe655 100644 --- a/net/ipv6/seg6_iptunnel.c +++ b/net/ipv6/seg6_iptunnel.c @@ -146,6 +146,8 @@ int seg6_do_srh_encap(struct sk_buff *skb, struct ipv6_sr_hdr *osrh, int proto) } else { ip6_flow_hdr(hdr, 0, flowlabel); hdr->hop_limit = ip6_dst_hoplimit(skb_dst(skb)); + + memset(IP6CB(skb), 0, sizeof(*IP6CB(skb))); } hdr->nexthdr = NEXTHDR_ROUTING; diff --git a/net/ipv6/sit.c b/net/ipv6/sit.c index 1e03305c0549..07e21a82ce4c 100644 --- a/net/ipv6/sit.c +++ b/net/ipv6/sit.c @@ -546,7 +546,8 @@ static int ipip6_err(struct sk_buff *skb, u32 info) } err = 0; - if (!ip6_err_gen_icmpv6_unreach(skb, iph->ihl * 4, type, data_len)) + if (__in6_dev_get(skb->dev) && + !ip6_err_gen_icmpv6_unreach(skb, iph->ihl * 4, type, data_len)) goto out; if (t->parms.iph.daddr == 0) @@ -777,8 +778,9 @@ static bool check_6rd(struct ip_tunnel *tunnel, const struct in6_addr *v6dst, pbw0 = tunnel->ip6rd.prefixlen >> 5; pbi0 = tunnel->ip6rd.prefixlen & 0x1f; - d = (ntohl(v6dst->s6_addr32[pbw0]) << pbi0) >> - tunnel->ip6rd.relay_prefixlen; + d = tunnel->ip6rd.relay_prefixlen < 32 ? + (ntohl(v6dst->s6_addr32[pbw0]) << pbi0) >> + tunnel->ip6rd.relay_prefixlen : 0; pbi1 = pbi0 - tunnel->ip6rd.relay_prefixlen; if (pbi1 > 0) @@ -1872,6 +1874,7 @@ static int __net_init sit_init_net(struct net *net) err_reg_dev: ipip6_dev_free(sitn->fb_tunnel_dev); + free_netdev(sitn->fb_tunnel_dev); err_alloc_dev: return err; } diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c index b81eb7cb815e..57ef69a10889 100644 --- a/net/ipv6/tcp_ipv6.c +++ b/net/ipv6/tcp_ipv6.c @@ -220,8 +220,6 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr, u32 exthdrlen = icsk->icsk_ext_hdr_len; struct sockaddr_in sin; - SOCK_DEBUG(sk, "connect: ipv4 mapped\n"); - if (__ipv6_only_sock(sk)) return -ENETUNREACH; @@ -1864,7 +1862,7 @@ static void get_tcp6_sock(struct seq_file *seq, struct sock *sp, int i) refcount_read(&sp->sk_refcnt), sp, jiffies_to_clock_t(icsk->icsk_rto), jiffies_to_clock_t(icsk->icsk_ack.ato), - (icsk->icsk_ack.quick << 1) | icsk->icsk_ack.pingpong, + (icsk->icsk_ack.quick << 1) | inet_csk_in_pingpong_mode(sp), tp->snd_cwnd, state == TCP_LISTEN ? fastopenq->max_qlen : diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index 2596ffdeebea..b444483cdb2b 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -288,8 +288,8 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int peeked, peeking, off; int err; int is_udplite = IS_UDPLITE(sk); + struct udp_mib __percpu *mib; bool checksum_valid = false; - struct udp_mib *mib; int is_udp4; if (flags & MSG_ERRQUEUE) @@ -420,17 +420,19 @@ EXPORT_SYMBOL(udpv6_encap_enable); */ static int __udp6_lib_err_encap_no_sk(struct sk_buff *skb, struct inet6_skb_parm *opt, - u8 type, u8 code, int offset, u32 info) + u8 type, u8 code, int offset, __be32 info) { int i; for (i = 0; i < MAX_IPTUN_ENCAP_OPS; i++) { int (*handler)(struct sk_buff *skb, struct inet6_skb_parm *opt, - u8 type, u8 code, int offset, u32 info); + u8 type, u8 code, int offset, __be32 info); + const struct ip6_tnl_encap_ops *encap; - if (!ip6tun_encaps[i]) + encap = rcu_dereference(ip6tun_encaps[i]); + if (!encap) continue; - handler = rcu_dereference(ip6tun_encaps[i]->err_handler); + handler = encap->err_handler; if (handler && !handler(skb, opt, type, code, offset, info)) return 0; } diff --git a/net/ipv6/xfrm6_tunnel.c b/net/ipv6/xfrm6_tunnel.c index f5b4febeaa25..bc65db782bfb 100644 --- a/net/ipv6/xfrm6_tunnel.c +++ b/net/ipv6/xfrm6_tunnel.c @@ -344,8 +344,8 @@ static void __net_exit xfrm6_tunnel_net_exit(struct net *net) struct xfrm6_tunnel_net *xfrm6_tn = xfrm6_tunnel_pernet(net); unsigned int i; - xfrm_state_flush(net, IPSEC_PROTO_ANY, false); xfrm_flush_gc(); + xfrm_state_flush(net, IPSEC_PROTO_ANY, false, true); for (i = 0; i < XFRM6_TUNNEL_SPI_BYADDR_HSIZE; i++) WARN_ON_ONCE(!hlist_empty(&xfrm6_tn->spi_byaddr[i])); diff --git a/net/kcm/kcmsock.c b/net/kcm/kcmsock.c index 571d824e4e24..c5c5ab6c5a1c 100644 --- a/net/kcm/kcmsock.c +++ b/net/kcm/kcmsock.c @@ -2036,13 +2036,13 @@ static int __init kcm_init(void) kcm_muxp = kmem_cache_create("kcm_mux_cache", sizeof(struct kcm_mux), 0, - SLAB_HWCACHE_ALIGN | SLAB_PANIC, NULL); + SLAB_HWCACHE_ALIGN, NULL); if (!kcm_muxp) goto fail; kcm_psockp = kmem_cache_create("kcm_psock_cache", sizeof(struct kcm_psock), 0, - SLAB_HWCACHE_ALIGN | SLAB_PANIC, NULL); + SLAB_HWCACHE_ALIGN, NULL); if (!kcm_psockp) goto fail; diff --git a/net/key/af_key.c b/net/key/af_key.c index 655c787f9d54..5651c29cb5bd 100644 --- a/net/key/af_key.c +++ b/net/key/af_key.c @@ -196,30 +196,22 @@ static int pfkey_release(struct socket *sock) return 0; } -static int pfkey_broadcast_one(struct sk_buff *skb, struct sk_buff **skb2, - gfp_t allocation, struct sock *sk) +static int pfkey_broadcast_one(struct sk_buff *skb, gfp_t allocation, + struct sock *sk) { int err = -ENOBUFS; - sock_hold(sk); - if (*skb2 == NULL) { - if (refcount_read(&skb->users) != 1) { - *skb2 = skb_clone(skb, allocation); - } else { - *skb2 = skb; - refcount_inc(&skb->users); - } - } - if (*skb2 != NULL) { - if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) { - skb_set_owner_r(*skb2, sk); - skb_queue_tail(&sk->sk_receive_queue, *skb2); - sk->sk_data_ready(sk); - *skb2 = NULL; - err = 0; - } + if (atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf) + return err; + + skb = skb_clone(skb, allocation); + + if (skb) { + skb_set_owner_r(skb, sk); + skb_queue_tail(&sk->sk_receive_queue, skb); + sk->sk_data_ready(sk); + err = 0; } - sock_put(sk); return err; } @@ -234,7 +226,6 @@ static int pfkey_broadcast(struct sk_buff *skb, gfp_t allocation, { struct netns_pfkey *net_pfkey = net_generic(net, pfkey_net_id); struct sock *sk; - struct sk_buff *skb2 = NULL; int err = -ESRCH; /* XXX Do we need something like netlink_overrun? I think @@ -253,7 +244,7 @@ static int pfkey_broadcast(struct sk_buff *skb, gfp_t allocation, * socket. */ if (pfk->promisc) - pfkey_broadcast_one(skb, &skb2, GFP_ATOMIC, sk); + pfkey_broadcast_one(skb, GFP_ATOMIC, sk); /* the exact target will be processed later */ if (sk == one_sk) @@ -268,7 +259,7 @@ static int pfkey_broadcast(struct sk_buff *skb, gfp_t allocation, continue; } - err2 = pfkey_broadcast_one(skb, &skb2, GFP_ATOMIC, sk); + err2 = pfkey_broadcast_one(skb, GFP_ATOMIC, sk); /* Error is cleared after successful sending to at least one * registered KM */ @@ -278,9 +269,8 @@ static int pfkey_broadcast(struct sk_buff *skb, gfp_t allocation, rcu_read_unlock(); if (one_sk != NULL) - err = pfkey_broadcast_one(skb, &skb2, allocation, one_sk); + err = pfkey_broadcast_one(skb, allocation, one_sk); - kfree_skb(skb2); kfree_skb(skb); return err; } @@ -1783,7 +1773,7 @@ static int pfkey_flush(struct sock *sk, struct sk_buff *skb, const struct sadb_m if (proto == 0) return -EINVAL; - err = xfrm_state_flush(net, proto, true); + err = xfrm_state_flush(net, proto, true, false); err2 = unicast_flush_resp(sk, hdr); if (err || err2) { if (err == -ESRCH) /* empty table - go quietly */ diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c index 26f1d435696a..fed6becc5daf 100644 --- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -83,8 +83,7 @@ #define L2TP_SLFLAG_S 0x40000000 #define L2TP_SL_SEQ_MASK 0x00ffffff -#define L2TP_HDR_SIZE_SEQ 10 -#define L2TP_HDR_SIZE_NOSEQ 6 +#define L2TP_HDR_SIZE_MAX 14 /* Default trace flags */ #define L2TP_DEFAULT_DEBUG_FLAGS 0 @@ -808,7 +807,7 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb) __skb_pull(skb, sizeof(struct udphdr)); /* Short packet? */ - if (!pskb_may_pull(skb, L2TP_HDR_SIZE_SEQ)) { + if (!pskb_may_pull(skb, L2TP_HDR_SIZE_MAX)) { l2tp_info(tunnel, L2TP_MSG_DATA, "%s: recv short packet (len=%d)\n", tunnel->name, skb->len); @@ -884,6 +883,10 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb) goto error; } + if (tunnel->version == L2TP_HDR_VER_3 && + l2tp_v3_ensure_opt_in_linear(session, skb, &ptr, &optr)) + goto error; + l2tp_recv_common(session, skb, ptr, optr, hdrflags, length); l2tp_session_dec_refcount(session); diff --git a/net/l2tp/l2tp_core.h b/net/l2tp/l2tp_core.h index 9c9afe94d389..b2ce90260c35 100644 --- a/net/l2tp/l2tp_core.h +++ b/net/l2tp/l2tp_core.h @@ -301,6 +301,26 @@ static inline bool l2tp_tunnel_uses_xfrm(const struct l2tp_tunnel *tunnel) } #endif +static inline int l2tp_v3_ensure_opt_in_linear(struct l2tp_session *session, struct sk_buff *skb, + unsigned char **ptr, unsigned char **optr) +{ + int opt_len = session->peer_cookie_len + l2tp_get_l2specific_len(session); + + if (opt_len > 0) { + int off = *ptr - *optr; + + if (!pskb_may_pull(skb, off + opt_len)) + return -1; + + if (skb->data != *optr) { + *optr = skb->data; + *ptr = skb->data + off; + } + } + + return 0; +} + #define l2tp_printk(ptr, type, func, fmt, ...) \ do { \ if (((ptr)->debug) & (type)) \ diff --git a/net/l2tp/l2tp_ip.c b/net/l2tp/l2tp_ip.c index 35f6f86d4dcc..d4c60523c549 100644 --- a/net/l2tp/l2tp_ip.c +++ b/net/l2tp/l2tp_ip.c @@ -165,6 +165,9 @@ static int l2tp_ip_recv(struct sk_buff *skb) print_hex_dump_bytes("", DUMP_PREFIX_OFFSET, ptr, length); } + if (l2tp_v3_ensure_opt_in_linear(session, skb, &ptr, &optr)) + goto discard_sess; + l2tp_recv_common(session, skb, ptr, optr, 0, skb->len); l2tp_session_dec_refcount(session); diff --git a/net/l2tp/l2tp_ip6.c b/net/l2tp/l2tp_ip6.c index 237f1a4a0b0c..37a69df17cab 100644 --- a/net/l2tp/l2tp_ip6.c +++ b/net/l2tp/l2tp_ip6.c @@ -178,6 +178,9 @@ static int l2tp_ip6_recv(struct sk_buff *skb) print_hex_dump_bytes("", DUMP_PREFIX_OFFSET, ptr, length); } + if (l2tp_v3_ensure_opt_in_linear(session, skb, &ptr, &optr)) + goto discard_sess; + l2tp_recv_common(session, skb, ptr, optr, 0, skb->len); l2tp_session_dec_refcount(session); @@ -671,9 +674,6 @@ static int l2tp_ip6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, if (flags & MSG_OOB) goto out; - if (addr_len) - *addr_len = sizeof(*lsa); - if (flags & MSG_ERRQUEUE) return ipv6_recv_error(sk, msg, len, addr_len); @@ -703,6 +703,7 @@ static int l2tp_ip6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, lsa->l2tp_conn_id = 0; if (ipv6_addr_type(&lsa->l2tp_addr) & IPV6_ADDR_LINKLOCAL) lsa->l2tp_scope_id = inet6_iif(skb); + *addr_len = sizeof(*lsa); } if (np->rxopt.all) diff --git a/net/mac80211/agg-tx.c b/net/mac80211/agg-tx.c index 69e831bc317b..2c4cd4183bf9 100644 --- a/net/mac80211/agg-tx.c +++ b/net/mac80211/agg-tx.c @@ -8,7 +8,7 @@ * Copyright 2007, Michael Wu <flamingice@sourmilk.net> * Copyright 2007-2010, Intel Corporation * Copyright(c) 2015-2017 Intel Deutschland GmbH - * Copyright (C) 2018 Intel Corporation + * Copyright (C) 2018 - 2019 Intel Corporation * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 as @@ -229,7 +229,7 @@ ieee80211_agg_start_txq(struct sta_info *sta, int tid, bool enable) clear_bit(IEEE80211_TXQ_STOP, &txqi->flags); local_bh_disable(); rcu_read_lock(); - drv_wake_tx_queue(sta->sdata->local, txqi); + schedule_and_wake_txq(sta->sdata->local, txqi); rcu_read_unlock(); local_bh_enable(); } @@ -366,6 +366,8 @@ int ___ieee80211_stop_tx_ba_session(struct sta_info *sta, u16 tid, set_bit(HT_AGG_STATE_STOPPING, &tid_tx->state); + ieee80211_agg_stop_txq(sta, tid); + spin_unlock_bh(&sta->lock); ht_dbg(sta->sdata, "Tx BA session stop requested for %pM tid %u\n", diff --git a/net/mac80211/cfg.c b/net/mac80211/cfg.c index de65fe3ed9cc..09dd1c2860fc 100644 --- a/net/mac80211/cfg.c +++ b/net/mac80211/cfg.c @@ -941,6 +941,7 @@ static int ieee80211_start_ap(struct wiphy *wiphy, struct net_device *dev, BSS_CHANGED_P2P_PS | BSS_CHANGED_TXPOWER; int err; + int prev_beacon_int; old = sdata_dereference(sdata->u.ap.beacon, sdata); if (old) @@ -963,6 +964,7 @@ static int ieee80211_start_ap(struct wiphy *wiphy, struct net_device *dev, sdata->needed_rx_chains = sdata->local->rx_chains; + prev_beacon_int = sdata->vif.bss_conf.beacon_int; sdata->vif.bss_conf.beacon_int = params->beacon_interval; if (params->he_cap) @@ -974,8 +976,10 @@ static int ieee80211_start_ap(struct wiphy *wiphy, struct net_device *dev, if (!err) ieee80211_vif_copy_chanctx_to_vlans(sdata, false); mutex_unlock(&local->mtx); - if (err) + if (err) { + sdata->vif.bss_conf.beacon_int = prev_beacon_int; return err; + } /* * Apply control port protocol, this allows us to @@ -1231,6 +1235,11 @@ static void sta_apply_mesh_params(struct ieee80211_local *local, ieee80211_mps_sta_status_update(sta); changed |= ieee80211_mps_set_sta_local_pm(sta, sdata->u.mesh.mshcfg.power_mode); + + ewma_mesh_tx_rate_avg_init(&sta->mesh->tx_rate_avg); + /* init at low value */ + ewma_mesh_tx_rate_avg_add(&sta->mesh->tx_rate_avg, 10); + break; case NL80211_PLINK_LISTEN: case NL80211_PLINK_BLOCKED: @@ -1447,6 +1456,9 @@ static int sta_apply_parameters(struct ieee80211_local *local, if (ieee80211_vif_is_mesh(&sdata->vif)) sta_apply_mesh_params(local, sta, params); + if (params->airtime_weight) + sta->airtime_weight = params->airtime_weight; + /* set the STA state after all sta info from usermode has been set */ if (test_sta_flag(sta, WLAN_STA_TDLS_PEER) || set & BIT(NL80211_STA_FLAG_ASSOCIATED)) { @@ -1490,6 +1502,10 @@ static int ieee80211_add_station(struct wiphy *wiphy, struct net_device *dev, if (params->sta_flags_set & BIT(NL80211_STA_FLAG_TDLS_PEER)) sta->sta.tdls = true; + if (sta->sta.tdls && sdata->vif.type == NL80211_IFTYPE_STATION && + !sdata->u.mgd.associated) + return -EINVAL; + err = sta_apply_parameters(local, sta, params); if (err) { sta_info_free(local, sta); @@ -1742,7 +1758,9 @@ static void mpath_set_pinfo(struct mesh_path *mpath, u8 *next_hop, MPATH_INFO_EXPTIME | MPATH_INFO_DISCOVERY_TIMEOUT | MPATH_INFO_DISCOVERY_RETRIES | - MPATH_INFO_FLAGS; + MPATH_INFO_FLAGS | + MPATH_INFO_HOP_COUNT | + MPATH_INFO_PATH_CHANGE; pinfo->frame_qlen = mpath->frame_queue.qlen; pinfo->sn = mpath->sn; @@ -1762,6 +1780,8 @@ static void mpath_set_pinfo(struct mesh_path *mpath, u8 *next_hop, pinfo->flags |= NL80211_MPATH_FLAG_FIXED; if (mpath->flags & MESH_PATH_RESOLVED) pinfo->flags |= NL80211_MPATH_FLAG_RESOLVED; + pinfo->hop_count = mpath->hop_count; + pinfo->path_change_count = mpath->path_change_count; } static int ieee80211_get_mpath(struct wiphy *wiphy, struct net_device *dev, diff --git a/net/mac80211/debugfs.c b/net/mac80211/debugfs.c index 3fe541e358f3..2d43bc127043 100644 --- a/net/mac80211/debugfs.c +++ b/net/mac80211/debugfs.c @@ -3,7 +3,7 @@ * * Copyright 2007 Johannes Berg <johannes@sipsolutions.net> * Copyright 2013-2014 Intel Mobile Communications GmbH - * Copyright (C) 2018 Intel Corporation + * Copyright (C) 2018 - 2019 Intel Corporation * * GPLv2 * @@ -218,6 +218,9 @@ static const char *hw_flag_names[] = { FLAG(BUFF_MMPDU_TXQ), FLAG(SUPPORTS_VHT_EXT_NSS_BW), FLAG(STA_MMPDU_TXQ), + FLAG(TX_STATUS_NO_AMPDU_LEN), + FLAG(SUPPORTS_MULTI_BSSID), + FLAG(SUPPORTS_ONLY_HE_MULTI_BSSID), #undef FLAG }; @@ -383,6 +386,9 @@ void debugfs_hw_add(struct ieee80211_local *local) if (local->ops->wake_tx_queue) DEBUGFS_ADD_MODE(aqm, 0600); + debugfs_create_u16("airtime_flags", 0600, + phyd, &local->airtime_flags); + statsd = debugfs_create_dir("statistics", phyd); /* if the dir failed, don't put all the other things into the root! */ diff --git a/net/mac80211/debugfs_sta.c b/net/mac80211/debugfs_sta.c index b753194710ad..8e921281e0d5 100644 --- a/net/mac80211/debugfs_sta.c +++ b/net/mac80211/debugfs_sta.c @@ -4,7 +4,7 @@ * Copyright 2007 Johannes Berg <johannes@sipsolutions.net> * Copyright 2013-2014 Intel Mobile Communications GmbH * Copyright(c) 2016 Intel Deutschland GmbH - * Copyright (C) 2018 Intel Corporation + * Copyright (C) 2018 - 2019 Intel Corporation * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 as @@ -181,9 +181,9 @@ static ssize_t sta_aqm_read(struct file *file, char __user *userbuf, txqi->tin.tx_bytes, txqi->tin.tx_packets, txqi->flags, - txqi->flags & (1<<IEEE80211_TXQ_STOP) ? "STOP" : "RUN", - txqi->flags & (1<<IEEE80211_TXQ_AMPDU) ? " AMPDU" : "", - txqi->flags & (1<<IEEE80211_TXQ_NO_AMSDU) ? " NO-AMSDU" : ""); + test_bit(IEEE80211_TXQ_STOP, &txqi->flags) ? "STOP" : "RUN", + test_bit(IEEE80211_TXQ_AMPDU, &txqi->flags) ? " AMPDU" : "", + test_bit(IEEE80211_TXQ_NO_AMSDU, &txqi->flags) ? " NO-AMSDU" : ""); } rcu_read_unlock(); @@ -195,6 +195,64 @@ static ssize_t sta_aqm_read(struct file *file, char __user *userbuf, } STA_OPS(aqm); +static ssize_t sta_airtime_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct sta_info *sta = file->private_data; + struct ieee80211_local *local = sta->sdata->local; + size_t bufsz = 200; + char *buf = kzalloc(bufsz, GFP_KERNEL), *p = buf; + u64 rx_airtime = 0, tx_airtime = 0; + s64 deficit[IEEE80211_NUM_ACS]; + ssize_t rv; + int ac; + + if (!buf) + return -ENOMEM; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + spin_lock_bh(&local->active_txq_lock[ac]); + rx_airtime += sta->airtime[ac].rx_airtime; + tx_airtime += sta->airtime[ac].tx_airtime; + deficit[ac] = sta->airtime[ac].deficit; + spin_unlock_bh(&local->active_txq_lock[ac]); + } + + p += scnprintf(p, bufsz + buf - p, + "RX: %llu us\nTX: %llu us\nWeight: %u\n" + "Deficit: VO: %lld us VI: %lld us BE: %lld us BK: %lld us\n", + rx_airtime, + tx_airtime, + sta->airtime_weight, + deficit[0], + deficit[1], + deficit[2], + deficit[3]); + + rv = simple_read_from_buffer(userbuf, count, ppos, buf, p - buf); + kfree(buf); + return rv; +} + +static ssize_t sta_airtime_write(struct file *file, const char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct sta_info *sta = file->private_data; + struct ieee80211_local *local = sta->sdata->local; + int ac; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + spin_lock_bh(&local->active_txq_lock[ac]); + sta->airtime[ac].rx_airtime = 0; + sta->airtime[ac].tx_airtime = 0; + sta->airtime[ac].deficit = sta->airtime_weight; + spin_unlock_bh(&local->active_txq_lock[ac]); + } + + return count; +} +STA_OPS_RW(airtime); + static ssize_t sta_agg_status_read(struct file *file, char __user *userbuf, size_t count, loff_t *ppos) { @@ -627,6 +685,9 @@ static ssize_t sta_he_capa_read(struct file *file, char __user *userbuf, "SUBCHAN-SELECVITE-TRANSMISSION"); PFLAG(MAC, 5, UL_2x996_TONE_RU, "UL-2x996-TONE-RU"); PFLAG(MAC, 5, OM_CTRL_UL_MU_DATA_DIS_RX, "OM-CTRL-UL-MU-DATA-DIS-RX"); + PFLAG(MAC, 5, HE_DYNAMIC_SM_PS, "HE-DYNAMIC-SM-PS"); + PFLAG(MAC, 5, PUNCTURED_SOUNDING, "PUNCTURED-SOUNDING"); + PFLAG(MAC, 5, HT_VHT_TRIG_FRAME_RX, "HT-VHT-TRIG-FRAME-RX"); cap = hec->he_cap_elem.phy_cap_info; p += scnprintf(p, buf_sz + buf - p, @@ -761,18 +822,18 @@ static ssize_t sta_he_capa_read(struct file *file, char __user *userbuf, PFLAG(PHY, 8, MIDAMBLE_RX_TX_2X_AND_1XLTF, "MIDAMBLE-RX-TX-2X-AND-1XLTF"); - switch (cap[8] & IEEE80211_HE_PHY_CAP8_DCM_MAX_BW_MASK) { - case IEEE80211_HE_PHY_CAP8_DCM_MAX_BW_20MHZ: - PRINT("DDCM-MAX-BW-20MHZ"); + switch (cap[8] & IEEE80211_HE_PHY_CAP8_DCM_MAX_RU_MASK) { + case IEEE80211_HE_PHY_CAP8_DCM_MAX_RU_242: + PRINT("DCM-MAX-RU-242"); break; - case IEEE80211_HE_PHY_CAP8_DCM_MAX_BW_40MHZ: - PRINT("DCM-MAX-BW-40MHZ"); + case IEEE80211_HE_PHY_CAP8_DCM_MAX_RU_484: + PRINT("DCM-MAX-RU-484"); break; - case IEEE80211_HE_PHY_CAP8_DCM_MAX_BW_80MHZ: - PRINT("DCM-MAX-BW-80MHZ"); + case IEEE80211_HE_PHY_CAP8_DCM_MAX_RU_996: + PRINT("DCM-MAX-RU-996"); break; - case IEEE80211_HE_PHY_CAP8_DCM_MAX_BW_160_OR_80P80_MHZ: - PRINT("DCM-MAX-BW-160-OR-80P80-MHZ"); + case IEEE80211_HE_PHY_CAP8_DCM_MAX_RU_2x996: + PRINT("DCM-MAX-RU-2x996"); break; } @@ -789,6 +850,18 @@ static ssize_t sta_he_capa_read(struct file *file, char __user *userbuf, PFLAG(PHY, 9, RX_FULL_BW_SU_USING_MU_WITH_NON_COMP_SIGB, "RX-FULL-BW-SU-USING-MU-WITH-NON-COMP-SIGB"); + switch (cap[9] & IEEE80211_HE_PHY_CAP9_NOMIMAL_PKT_PADDING_MASK) { + case IEEE80211_HE_PHY_CAP9_NOMIMAL_PKT_PADDING_0US: + PRINT("NOMINAL-PACKET-PADDING-0US"); + break; + case IEEE80211_HE_PHY_CAP9_NOMIMAL_PKT_PADDING_8US: + PRINT("NOMINAL-PACKET-PADDING-8US"); + break; + case IEEE80211_HE_PHY_CAP9_NOMIMAL_PKT_PADDING_16US: + PRINT("NOMINAL-PACKET-PADDING-16US"); + break; + } + #undef PFLAG_RANGE_DEFAULT #undef PFLAG_RANGE #undef PFLAG @@ -906,6 +979,10 @@ void ieee80211_sta_debugfs_add(struct sta_info *sta) if (local->ops->wake_tx_queue) DEBUGFS_ADD(aqm); + if (wiphy_ext_feature_isset(local->hw.wiphy, + NL80211_EXT_FEATURE_AIRTIME_FAIRNESS)) + DEBUGFS_ADD(airtime); + if (sizeof(sta->driver_buffered_tids) == sizeof(u32)) debugfs_create_x32("driver_buffered_tids", 0400, sta->debugfs_dir, diff --git a/net/mac80211/driver-ops.h b/net/mac80211/driver-ops.h index 3e0d5922a440..28d022a3eee3 100644 --- a/net/mac80211/driver-ops.h +++ b/net/mac80211/driver-ops.h @@ -2,7 +2,7 @@ /* * Portions of this file * Copyright(c) 2016 Intel Deutschland GmbH -* Copyright (C) 2018 Intel Corporation +* Copyright (C) 2018 - 2019 Intel Corporation */ #ifndef __MAC80211_DRIVER_OPS @@ -1052,6 +1052,35 @@ drv_post_channel_switch(struct ieee80211_sub_if_data *sdata) return ret; } +static inline void +drv_abort_channel_switch(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_abort_channel_switch(local, sdata); + + if (local->ops->abort_channel_switch) + local->ops->abort_channel_switch(&local->hw, &sdata->vif); +} + +static inline void +drv_channel_switch_rx_beacon(struct ieee80211_sub_if_data *sdata, + struct ieee80211_channel_switch *ch_switch) +{ + struct ieee80211_local *local = sdata->local; + + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_channel_switch_rx_beacon(local, sdata, ch_switch); + if (local->ops->channel_switch_rx_beacon) + local->ops->channel_switch_rx_beacon(&local->hw, &sdata->vif, + ch_switch); +} + static inline int drv_join_ibss(struct ieee80211_local *local, struct ieee80211_sub_if_data *sdata) { @@ -1173,6 +1202,13 @@ static inline void drv_wake_tx_queue(struct ieee80211_local *local, local->ops->wake_tx_queue(&local->hw, &txq->txq); } +static inline void schedule_and_wake_txq(struct ieee80211_local *local, + struct txq_info *txqi) +{ + ieee80211_schedule_txq(&local->hw, &txqi->txq); + drv_wake_tx_queue(local, txqi); +} + static inline int drv_can_aggregate_in_amsdu(struct ieee80211_local *local, struct sk_buff *head, struct sk_buff *skb) diff --git a/net/mac80211/ht.c b/net/mac80211/ht.c index f849ea814993..e03c46ac8e4d 100644 --- a/net/mac80211/ht.c +++ b/net/mac80211/ht.c @@ -107,6 +107,14 @@ void ieee80211_apply_htcap_overrides(struct ieee80211_sub_if_data *sdata, __check_htcap_enable(ht_capa, ht_capa_mask, ht_cap, IEEE80211_HT_CAP_40MHZ_INTOLERANT); + /* Allow user to enable TX STBC bit */ + __check_htcap_enable(ht_capa, ht_capa_mask, ht_cap, + IEEE80211_HT_CAP_TX_STBC); + + /* Allow user to configure RX STBC bits */ + if (ht_capa_mask->cap_info & IEEE80211_HT_CAP_RX_STBC) + ht_cap->cap |= ht_capa->cap_info & IEEE80211_HT_CAP_RX_STBC; + /* Allow user to decrease AMPDU factor */ if (ht_capa_mask->ampdu_params_info & IEEE80211_HT_AMPDU_PARM_FACTOR) { diff --git a/net/mac80211/ibss.c b/net/mac80211/ibss.c index 0d704e8d7078..4e4507115cf3 100644 --- a/net/mac80211/ibss.c +++ b/net/mac80211/ibss.c @@ -8,6 +8,7 @@ * Copyright 2009, Johannes Berg <johannes@sipsolutions.net> * Copyright 2013-2014 Intel Mobile Communications GmbH * Copyright(c) 2016 Intel Deutschland GmbH + * Copyright(c) 2018-2019 Intel Corporation * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 as @@ -1124,8 +1125,7 @@ static void ieee80211_rx_bss_info(struct ieee80211_sub_if_data *sdata, ieee80211_update_sta_info(sdata, mgmt, len, rx_status, elems, channel); - bss = ieee80211_bss_info_update(local, rx_status, mgmt, len, elems, - channel); + bss = ieee80211_bss_info_update(local, rx_status, mgmt, len, channel); if (!bss) return; @@ -1604,7 +1604,7 @@ void ieee80211_rx_mgmt_probe_beacon(struct ieee80211_sub_if_data *sdata, return; ieee802_11_parse_elems(mgmt->u.probe_resp.variable, len - baselen, - false, &elems); + false, &elems, mgmt->bssid, NULL); ieee80211_rx_bss_info(sdata, mgmt, len, rx_status, &elems); } @@ -1654,7 +1654,7 @@ void ieee80211_ibss_rx_queued_mgmt(struct ieee80211_sub_if_data *sdata, ieee802_11_parse_elems( mgmt->u.action.u.chan_switch.variable, - ies_len, true, &elems); + ies_len, true, &elems, mgmt->bssid, NULL); if (elems.parse_error) break; diff --git a/net/mac80211/ieee80211_i.h b/net/mac80211/ieee80211_i.h index 7dfb4e2f98b2..e170f986d226 100644 --- a/net/mac80211/ieee80211_i.h +++ b/net/mac80211/ieee80211_i.h @@ -4,7 +4,7 @@ * Copyright 2006-2007 Jiri Benc <jbenc@suse.cz> * Copyright 2007-2010 Johannes Berg <johannes@sipsolutions.net> * Copyright 2013-2015 Intel Mobile Communications GmbH - * Copyright (C) 2018 Intel Corporation + * Copyright (C) 2018-2019 Intel Corporation * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 as @@ -556,6 +556,12 @@ struct ieee80211_if_managed { * get stuck in a downgraded situation and flush takes forever. */ struct delayed_work tx_tspec_wk; + + /* Information elements from the last transmitted (Re)Association + * Request frame. + */ + u8 *assoc_req_ies; + size_t assoc_req_ies_len; }; struct ieee80211_if_ibss { @@ -831,6 +837,8 @@ enum txq_info_flags { * a fq_flow which is already owned by a different tin * @def_cvars: codel vars for @def_flow * @frags: used to keep fragments created after dequeue + * @schedule_order: used with ieee80211_local->active_txqs + * @schedule_round: counter to prevent infinite loops on TXQ scheduling */ struct txq_info { struct fq_tin tin; @@ -838,6 +846,8 @@ struct txq_info { struct codel_vars def_cvars; struct codel_stats cstats; struct sk_buff_head frags; + struct list_head schedule_order; + u16 schedule_round; unsigned long flags; /* keep last! */ @@ -1129,6 +1139,13 @@ struct ieee80211_local { struct codel_vars *cvars; struct codel_params cparams; + /* protects active_txqs and txqi->schedule_order */ + spinlock_t active_txq_lock[IEEE80211_NUM_ACS]; + struct list_head active_txqs[IEEE80211_NUM_ACS]; + u16 schedule_round[IEEE80211_NUM_ACS]; + + u16 airtime_flags; + const struct ieee80211_ops *ops; /* @@ -1436,6 +1453,7 @@ struct ieee80211_csa_ie { u8 ttl; u16 pre_value; u16 reason_code; + u32 max_switch_time; }; /* Parsed Information Elements */ @@ -1476,6 +1494,7 @@ struct ieee802_11_elems { const struct ieee80211_channel_sw_ie *ch_switch_ie; const struct ieee80211_ext_chansw_ie *ext_chansw_ie; const struct ieee80211_wide_bw_chansw_ie *wide_bw_chansw_ie; + const u8 *max_channel_switch_time; const u8 *country_elem; const u8 *pwr_constr_elem; const u8 *cisco_dtpc_elem; @@ -1484,6 +1503,12 @@ struct ieee802_11_elems { const struct ieee80211_sec_chan_offs_ie *sec_chan_offs; struct ieee80211_mesh_chansw_params_ie *mesh_chansw_params_ie; const struct ieee80211_bss_max_idle_period_ie *max_idle_period_ie; + const struct ieee80211_multiple_bssid_configuration *mbssid_config_ie; + const struct ieee80211_bssid_index *bssid_index; + const u8 *nontransmitted_bssid_profile; + u8 max_bssid_indicator; + u8 dtim_count; + u8 dtim_period; /* length of them, respectively */ u8 ext_capab_len; @@ -1502,6 +1527,7 @@ struct ieee802_11_elems { u8 prep_len; u8 perr_len; u8 country_elem_len; + u8 bssid_index_len; /* whether a parse error occurred while retrieving these elements */ bool parse_error; @@ -1661,7 +1687,6 @@ ieee80211_bss_info_update(struct ieee80211_local *local, struct ieee80211_rx_status *rx_status, struct ieee80211_mgmt *mgmt, size_t len, - struct ieee802_11_elems *elems, struct ieee80211_channel *channel); void ieee80211_rx_bss_put(struct ieee80211_local *local, struct ieee80211_bss *bss); @@ -1945,12 +1970,16 @@ static inline void ieee80211_tx_skb(struct ieee80211_sub_if_data *sdata, u32 ieee802_11_parse_elems_crc(const u8 *start, size_t len, bool action, struct ieee802_11_elems *elems, - u64 filter, u32 crc); + u64 filter, u32 crc, u8 *transmitter_bssid, + u8 *bss_bssid); static inline void ieee802_11_parse_elems(const u8 *start, size_t len, bool action, - struct ieee802_11_elems *elems) + struct ieee802_11_elems *elems, + u8 *transmitter_bssid, + u8 *bss_bssid) { - ieee802_11_parse_elems_crc(start, len, action, elems, 0, 0); + ieee802_11_parse_elems_crc(start, len, action, elems, 0, 0, + transmitter_bssid, bss_bssid); } diff --git a/net/mac80211/main.c b/net/mac80211/main.c index 87a729926734..800e67615e2a 100644 --- a/net/mac80211/main.c +++ b/net/mac80211/main.c @@ -4,7 +4,7 @@ * Copyright 2006-2007 Jiri Benc <jbenc@suse.cz> * Copyright 2013-2014 Intel Mobile Communications GmbH * Copyright (C) 2017 Intel Deutschland GmbH - * Copyright (C) 2018 Intel Corporation + * Copyright (C) 2018 - 2019 Intel Corporation * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 as @@ -478,6 +478,8 @@ static const struct ieee80211_ht_cap mac80211_ht_capa_mod_mask = { IEEE80211_HT_CAP_MAX_AMSDU | IEEE80211_HT_CAP_SGI_20 | IEEE80211_HT_CAP_SGI_40 | + IEEE80211_HT_CAP_TX_STBC | + IEEE80211_HT_CAP_RX_STBC | IEEE80211_HT_CAP_LDPC_CODING | IEEE80211_HT_CAP_40MHZ_INTOLERANT), .mcs = { @@ -615,13 +617,13 @@ struct ieee80211_hw *ieee80211_alloc_hw_nm(size_t priv_data_len, * We need a bit of data queued to build aggregates properly, so * instruct the TCP stack to allow more than a single ms of data * to be queued in the stack. The value is a bit-shift of 1 - * second, so 8 is ~4ms of queued data. Only affects local TCP + * second, so 7 is ~8ms of queued data. Only affects local TCP * sockets. * This is the default, anyhow - drivers may need to override it * for local reasons (longer buffers, longer completion time, or * similar). */ - local->hw.tx_sk_pacing_shift = 8; + local->hw.tx_sk_pacing_shift = 7; /* set up some defaults */ local->hw.queues = 1; @@ -663,6 +665,12 @@ struct ieee80211_hw *ieee80211_alloc_hw_nm(size_t priv_data_len, spin_lock_init(&local->rx_path_lock); spin_lock_init(&local->queue_stop_reason_lock); + for (i = 0; i < IEEE80211_NUM_ACS; i++) { + INIT_LIST_HEAD(&local->active_txqs[i]); + spin_lock_init(&local->active_txq_lock[i]); + } + local->airtime_flags = AIRTIME_USE_TX | AIRTIME_USE_RX; + INIT_LIST_HEAD(&local->chanctx_list); mutex_init(&local->chanctx_mtx); @@ -1104,6 +1112,17 @@ int ieee80211_register_hw(struct ieee80211_hw *hw) if (ieee80211_hw_check(&local->hw, CHANCTX_STA_CSA)) local->ext_capa[0] |= WLAN_EXT_CAPA1_EXT_CHANNEL_SWITCHING; + /* mac80211 supports multi BSSID, if the driver supports it */ + if (ieee80211_hw_check(&local->hw, SUPPORTS_MULTI_BSSID)) { + local->hw.wiphy->support_mbssid = true; + if (ieee80211_hw_check(&local->hw, + SUPPORTS_ONLY_HE_MULTI_BSSID)) + local->hw.wiphy->support_only_he_mbssid = true; + else + local->ext_capa[2] |= + WLAN_EXT_CAPA3_MULTI_BSSID_SUPPORT; + } + local->hw.wiphy->max_num_csa_counters = IEEE80211_MAX_CSA_COUNTERS_NUM; result = wiphy_register(local->hw.wiphy); @@ -1148,6 +1167,9 @@ int ieee80211_register_hw(struct ieee80211_hw *hw) if (!local->hw.max_nan_de_entries) local->hw.max_nan_de_entries = IEEE80211_MAX_NAN_INSTANCE_ID; + if (!local->hw.weight_multiplier) + local->hw.weight_multiplier = 1; + result = ieee80211_wep_init(local); if (result < 0) wiphy_debug(local->hw.wiphy, "Failed to initialize wep: %d\n", diff --git a/net/mac80211/mesh.c b/net/mac80211/mesh.c index c90452aa0c42..766e5e5bab8a 100644 --- a/net/mac80211/mesh.c +++ b/net/mac80211/mesh.c @@ -1,6 +1,6 @@ /* * Copyright (c) 2008, 2009 open80211s Ltd. - * Copyright (C) 2018 Intel Corporation + * Copyright (C) 2018 - 2019 Intel Corporation * Authors: Luis Carlos Cobo <luisca@cozybit.com> * Javier Cardona <javier@cozybit.com> * @@ -1106,7 +1106,8 @@ ieee80211_mesh_rx_probe_req(struct ieee80211_sub_if_data *sdata, if (baselen > len) return; - ieee802_11_parse_elems(pos, len - baselen, false, &elems); + ieee802_11_parse_elems(pos, len - baselen, false, &elems, mgmt->bssid, + NULL); if (!elems.mesh_id) return; @@ -1170,7 +1171,7 @@ static void ieee80211_mesh_rx_bcn_presp(struct ieee80211_sub_if_data *sdata, return; ieee802_11_parse_elems(mgmt->u.probe_resp.variable, len - baselen, - false, &elems); + false, &elems, mgmt->bssid, NULL); /* ignore non-mesh or secure / unsecure mismatch */ if ((!elems.mesh_id || !elems.mesh_config) || @@ -1306,7 +1307,8 @@ static void mesh_rx_csa_frame(struct ieee80211_sub_if_data *sdata, pos = mgmt->u.action.u.chan_switch.variable; baselen = offsetof(struct ieee80211_mgmt, u.action.u.chan_switch.variable); - ieee802_11_parse_elems(pos, len - baselen, true, &elems); + ieee802_11_parse_elems(pos, len - baselen, true, &elems, + mgmt->bssid, NULL); ifmsh->chsw_ttl = elems.mesh_chansw_params_ie->mesh_ttl; if (!--ifmsh->chsw_ttl) diff --git a/net/mac80211/mesh.h b/net/mac80211/mesh.h index cad6592c52a1..574c3891c4b2 100644 --- a/net/mac80211/mesh.h +++ b/net/mac80211/mesh.h @@ -70,6 +70,7 @@ enum mesh_deferred_task_flags { * @dst: mesh path destination mac address * @mpp: mesh proxy mac address * @rhash: rhashtable list pointer + * @walk_list: linked list containing all mesh_path objects. * @gate_list: list pointer for known gates list * @sdata: mesh subif * @next_hop: mesh neighbor to which frames for this destination will be @@ -94,6 +95,7 @@ enum mesh_deferred_task_flags { * @last_preq_to_root: Timestamp of last PREQ sent to root * @is_root: the destination station of this path is a root node * @is_gate: the destination station of this path is a mesh gate + * @path_change_count: the number of path changes to destination * * * The dst address is unique in the mesh path table. Since the mesh_path is @@ -105,6 +107,7 @@ struct mesh_path { u8 dst[ETH_ALEN]; u8 mpp[ETH_ALEN]; /* used for MPP or MAP */ struct rhash_head rhash; + struct hlist_node walk_list; struct hlist_node gate_list; struct ieee80211_sub_if_data *sdata; struct sta_info __rcu *next_hop; @@ -124,6 +127,7 @@ struct mesh_path { unsigned long last_preq_to_root; bool is_root; bool is_gate; + u32 path_change_count; }; /** @@ -133,12 +137,16 @@ struct mesh_path { * gate's mpath may or may not be resolved and active. * @gates_lock: protects updates to known_gates * @rhead: the rhashtable containing struct mesh_paths, keyed by dest addr + * @walk_head: linked list containging all mesh_path objects + * @walk_lock: lock protecting walk_head * @entries: number of entries in the table */ struct mesh_table { struct hlist_head known_gates; spinlock_t gates_lock; struct rhashtable rhead; + struct hlist_head walk_head; + spinlock_t walk_lock; atomic_t entries; /* Up to MAX_MESH_NEIGHBOURS */ }; diff --git a/net/mac80211/mesh_hwmp.c b/net/mac80211/mesh_hwmp.c index 6950cd0bf594..f7517668e77a 100644 --- a/net/mac80211/mesh_hwmp.c +++ b/net/mac80211/mesh_hwmp.c @@ -1,5 +1,6 @@ /* * Copyright (c) 2008, 2009 open80211s Ltd. + * Copyright (C) 2019 Intel Corporation * Author: Luis Carlos Cobo <luisca@cozybit.com> * * This program is free software; you can redistribute it and/or modify @@ -300,6 +301,7 @@ void ieee80211s_update_metric(struct ieee80211_local *local, { struct ieee80211_tx_info *txinfo = st->info; int failed; + struct rate_info rinfo; failed = !(txinfo->flags & IEEE80211_TX_STAT_ACK); @@ -310,12 +312,15 @@ void ieee80211s_update_metric(struct ieee80211_local *local, if (ewma_mesh_fail_avg_read(&sta->mesh->fail_avg) > LINK_FAIL_THRESH) mesh_plink_broken(sta); + + sta_set_rate_info_tx(sta, &sta->tx_stats.last_rate, &rinfo); + ewma_mesh_tx_rate_avg_add(&sta->mesh->tx_rate_avg, + cfg80211_calculate_bitrate(&rinfo)); } static u32 airtime_link_metric_get(struct ieee80211_local *local, struct sta_info *sta) { - struct rate_info rinfo; /* This should be adjusted for each device */ int device_constant = 1 << ARITH_SHIFT; int test_frame_len = TEST_FRAME_LEN << ARITH_SHIFT; @@ -339,8 +344,7 @@ static u32 airtime_link_metric_get(struct ieee80211_local *local, if (fail_avg > LINK_FAIL_THRESH) return MAX_METRIC; - sta_set_rate_info_tx(sta, &sta->tx_stats.last_rate, &rinfo); - rate = cfg80211_calculate_bitrate(&rinfo); + rate = ewma_mesh_tx_rate_avg_read(&sta->mesh->tx_rate_avg); if (WARN_ON(!rate)) return MAX_METRIC; @@ -386,6 +390,7 @@ static u32 hwmp_route_info_get(struct ieee80211_sub_if_data *sdata, unsigned long orig_lifetime, exp_time; u32 last_hop_metric, new_metric; bool process = true; + u8 hopcount; rcu_read_lock(); sta = sta_info_get(sdata, mgmt->sa); @@ -404,6 +409,7 @@ static u32 hwmp_route_info_get(struct ieee80211_sub_if_data *sdata, orig_sn = PREQ_IE_ORIG_SN(hwmp_ie); orig_lifetime = PREQ_IE_LIFETIME(hwmp_ie); orig_metric = PREQ_IE_METRIC(hwmp_ie); + hopcount = PREQ_IE_HOPCOUNT(hwmp_ie) + 1; break; case MPATH_PREP: /* Originator here refers to the MP that was the target in the @@ -415,6 +421,7 @@ static u32 hwmp_route_info_get(struct ieee80211_sub_if_data *sdata, orig_sn = PREP_IE_TARGET_SN(hwmp_ie); orig_lifetime = PREP_IE_LIFETIME(hwmp_ie); orig_metric = PREP_IE_METRIC(hwmp_ie); + hopcount = PREP_IE_HOPCOUNT(hwmp_ie) + 1; break; default: rcu_read_unlock(); @@ -441,7 +448,10 @@ static u32 hwmp_route_info_get(struct ieee80211_sub_if_data *sdata, (mpath->flags & MESH_PATH_SN_VALID)) { if (SN_GT(mpath->sn, orig_sn) || (mpath->sn == orig_sn && - new_metric >= mpath->metric)) { + (rcu_access_pointer(mpath->next_hop) != + sta ? + mult_frac(new_metric, 10, 9) : + new_metric) >= mpath->metric)) { process = false; fresh_info = false; } @@ -476,12 +486,15 @@ static u32 hwmp_route_info_get(struct ieee80211_sub_if_data *sdata, } if (fresh_info) { + if (rcu_access_pointer(mpath->next_hop) != sta) + mpath->path_change_count++; mesh_path_assign_nexthop(mpath, sta); mpath->flags |= MESH_PATH_SN_VALID; mpath->metric = new_metric; mpath->sn = orig_sn; mpath->exp_time = time_after(mpath->exp_time, exp_time) ? mpath->exp_time : exp_time; + mpath->hop_count = hopcount; mesh_path_activate(mpath); spin_unlock_bh(&mpath->state_lock); ewma_mesh_fail_avg_init(&sta->mesh->fail_avg); @@ -506,8 +519,10 @@ static u32 hwmp_route_info_get(struct ieee80211_sub_if_data *sdata, if (mpath) { spin_lock_bh(&mpath->state_lock); if ((mpath->flags & MESH_PATH_FIXED) || - ((mpath->flags & MESH_PATH_ACTIVE) && - (last_hop_metric > mpath->metric))) + ((mpath->flags & MESH_PATH_ACTIVE) && + ((rcu_access_pointer(mpath->next_hop) != sta ? + mult_frac(last_hop_metric, 10, 9) : + last_hop_metric) > mpath->metric))) fresh_info = false; } else { mpath = mesh_path_add(sdata, ta); @@ -519,10 +534,13 @@ static u32 hwmp_route_info_get(struct ieee80211_sub_if_data *sdata, } if (fresh_info) { + if (rcu_access_pointer(mpath->next_hop) != sta) + mpath->path_change_count++; mesh_path_assign_nexthop(mpath, sta); mpath->metric = last_hop_metric; mpath->exp_time = time_after(mpath->exp_time, exp_time) ? mpath->exp_time : exp_time; + mpath->hop_count = 1; mesh_path_activate(mpath); spin_unlock_bh(&mpath->state_lock); ewma_mesh_fail_avg_init(&sta->mesh->fail_avg); @@ -909,7 +927,7 @@ void mesh_rx_path_sel_frame(struct ieee80211_sub_if_data *sdata, baselen = (u8 *) mgmt->u.action.u.mesh_action.variable - (u8 *) mgmt; ieee802_11_parse_elems(mgmt->u.action.u.mesh_action.variable, - len - baselen, false, &elems); + len - baselen, false, &elems, mgmt->bssid, NULL); if (elems.preq) { if (elems.preq_len != 37) diff --git a/net/mac80211/mesh_pathtbl.c b/net/mac80211/mesh_pathtbl.c index a5125624a76d..95eb5064fa91 100644 --- a/net/mac80211/mesh_pathtbl.c +++ b/net/mac80211/mesh_pathtbl.c @@ -59,8 +59,10 @@ static struct mesh_table *mesh_table_alloc(void) return NULL; INIT_HLIST_HEAD(&newtbl->known_gates); + INIT_HLIST_HEAD(&newtbl->walk_head); atomic_set(&newtbl->entries, 0); spin_lock_init(&newtbl->gates_lock); + spin_lock_init(&newtbl->walk_lock); return newtbl; } @@ -249,28 +251,15 @@ mpp_path_lookup(struct ieee80211_sub_if_data *sdata, const u8 *dst) static struct mesh_path * __mesh_path_lookup_by_idx(struct mesh_table *tbl, int idx) { - int i = 0, ret; - struct mesh_path *mpath = NULL; - struct rhashtable_iter iter; - - ret = rhashtable_walk_init(&tbl->rhead, &iter, GFP_ATOMIC); - if (ret) - return NULL; - - rhashtable_walk_start(&iter); + int i = 0; + struct mesh_path *mpath; - while ((mpath = rhashtable_walk_next(&iter))) { - if (IS_ERR(mpath) && PTR_ERR(mpath) == -EAGAIN) - continue; - if (IS_ERR(mpath)) - break; + hlist_for_each_entry_rcu(mpath, &tbl->walk_head, walk_list) { if (i++ == idx) break; } - rhashtable_walk_stop(&iter); - rhashtable_walk_exit(&iter); - if (IS_ERR(mpath) || !mpath) + if (!mpath) return NULL; if (mpath_expired(mpath)) { @@ -415,7 +404,6 @@ struct mesh_path *mesh_path_add(struct ieee80211_sub_if_data *sdata, { struct mesh_table *tbl; struct mesh_path *mpath, *new_mpath; - int ret; if (ether_addr_equal(dst, sdata->vif.addr)) /* never add ourselves as neighbours */ @@ -432,29 +420,23 @@ struct mesh_path *mesh_path_add(struct ieee80211_sub_if_data *sdata, return ERR_PTR(-ENOMEM); tbl = sdata->u.mesh.mesh_paths; - do { - ret = rhashtable_lookup_insert_fast(&tbl->rhead, - &new_mpath->rhash, - mesh_rht_params); - - if (ret == -EEXIST) - mpath = rhashtable_lookup_fast(&tbl->rhead, - dst, - mesh_rht_params); - - } while (unlikely(ret == -EEXIST && !mpath)); - - if (ret && ret != -EEXIST) - return ERR_PTR(ret); - - /* At this point either new_mpath was added, or we found a - * matching entry already in the table; in the latter case - * free the unnecessary new entry. - */ - if (ret == -EEXIST) { + spin_lock_bh(&tbl->walk_lock); + mpath = rhashtable_lookup_get_insert_fast(&tbl->rhead, + &new_mpath->rhash, + mesh_rht_params); + if (!mpath) + hlist_add_head(&new_mpath->walk_list, &tbl->walk_head); + spin_unlock_bh(&tbl->walk_lock); + + if (mpath) { kfree(new_mpath); + + if (IS_ERR(mpath)) + return mpath; + new_mpath = mpath; } + sdata->u.mesh.mesh_paths_generation++; return new_mpath; } @@ -480,9 +462,17 @@ int mpp_path_add(struct ieee80211_sub_if_data *sdata, memcpy(new_mpath->mpp, mpp, ETH_ALEN); tbl = sdata->u.mesh.mpp_paths; + + spin_lock_bh(&tbl->walk_lock); ret = rhashtable_lookup_insert_fast(&tbl->rhead, &new_mpath->rhash, mesh_rht_params); + if (!ret) + hlist_add_head_rcu(&new_mpath->walk_list, &tbl->walk_head); + spin_unlock_bh(&tbl->walk_lock); + + if (ret) + kfree(new_mpath); sdata->u.mesh.mpp_paths_generation++; return ret; @@ -503,20 +493,9 @@ void mesh_plink_broken(struct sta_info *sta) struct mesh_table *tbl = sdata->u.mesh.mesh_paths; static const u8 bcast[ETH_ALEN] = {0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; struct mesh_path *mpath; - struct rhashtable_iter iter; - int ret; - - ret = rhashtable_walk_init(&tbl->rhead, &iter, GFP_ATOMIC); - if (ret) - return; - - rhashtable_walk_start(&iter); - while ((mpath = rhashtable_walk_next(&iter))) { - if (IS_ERR(mpath) && PTR_ERR(mpath) == -EAGAIN) - continue; - if (IS_ERR(mpath)) - break; + rcu_read_lock(); + hlist_for_each_entry_rcu(mpath, &tbl->walk_head, walk_list) { if (rcu_access_pointer(mpath->next_hop) == sta && mpath->flags & MESH_PATH_ACTIVE && !(mpath->flags & MESH_PATH_FIXED)) { @@ -530,8 +509,7 @@ void mesh_plink_broken(struct sta_info *sta) WLAN_REASON_MESH_PATH_DEST_UNREACHABLE, bcast); } } - rhashtable_walk_stop(&iter); - rhashtable_walk_exit(&iter); + rcu_read_unlock(); } static void mesh_path_free_rcu(struct mesh_table *tbl, @@ -551,6 +529,7 @@ static void mesh_path_free_rcu(struct mesh_table *tbl, static void __mesh_path_del(struct mesh_table *tbl, struct mesh_path *mpath) { + hlist_del_rcu(&mpath->walk_list); rhashtable_remove_fast(&tbl->rhead, &mpath->rhash, mesh_rht_params); mesh_path_free_rcu(tbl, mpath); } @@ -571,27 +550,14 @@ void mesh_path_flush_by_nexthop(struct sta_info *sta) struct ieee80211_sub_if_data *sdata = sta->sdata; struct mesh_table *tbl = sdata->u.mesh.mesh_paths; struct mesh_path *mpath; - struct rhashtable_iter iter; - int ret; - - ret = rhashtable_walk_init(&tbl->rhead, &iter, GFP_ATOMIC); - if (ret) - return; - - rhashtable_walk_start(&iter); - - while ((mpath = rhashtable_walk_next(&iter))) { - if (IS_ERR(mpath) && PTR_ERR(mpath) == -EAGAIN) - continue; - if (IS_ERR(mpath)) - break; + struct hlist_node *n; + spin_lock_bh(&tbl->walk_lock); + hlist_for_each_entry_safe(mpath, n, &tbl->walk_head, walk_list) { if (rcu_access_pointer(mpath->next_hop) == sta) __mesh_path_del(tbl, mpath); } - - rhashtable_walk_stop(&iter); - rhashtable_walk_exit(&iter); + spin_unlock_bh(&tbl->walk_lock); } static void mpp_flush_by_proxy(struct ieee80211_sub_if_data *sdata, @@ -599,51 +565,26 @@ static void mpp_flush_by_proxy(struct ieee80211_sub_if_data *sdata, { struct mesh_table *tbl = sdata->u.mesh.mpp_paths; struct mesh_path *mpath; - struct rhashtable_iter iter; - int ret; - - ret = rhashtable_walk_init(&tbl->rhead, &iter, GFP_ATOMIC); - if (ret) - return; - - rhashtable_walk_start(&iter); - - while ((mpath = rhashtable_walk_next(&iter))) { - if (IS_ERR(mpath) && PTR_ERR(mpath) == -EAGAIN) - continue; - if (IS_ERR(mpath)) - break; + struct hlist_node *n; + spin_lock_bh(&tbl->walk_lock); + hlist_for_each_entry_safe(mpath, n, &tbl->walk_head, walk_list) { if (ether_addr_equal(mpath->mpp, proxy)) __mesh_path_del(tbl, mpath); } - - rhashtable_walk_stop(&iter); - rhashtable_walk_exit(&iter); + spin_unlock_bh(&tbl->walk_lock); } static void table_flush_by_iface(struct mesh_table *tbl) { struct mesh_path *mpath; - struct rhashtable_iter iter; - int ret; - - ret = rhashtable_walk_init(&tbl->rhead, &iter, GFP_ATOMIC); - if (ret) - return; + struct hlist_node *n; - rhashtable_walk_start(&iter); - - while ((mpath = rhashtable_walk_next(&iter))) { - if (IS_ERR(mpath) && PTR_ERR(mpath) == -EAGAIN) - continue; - if (IS_ERR(mpath)) - break; + spin_lock_bh(&tbl->walk_lock); + hlist_for_each_entry_safe(mpath, n, &tbl->walk_head, walk_list) { __mesh_path_del(tbl, mpath); } - - rhashtable_walk_stop(&iter); - rhashtable_walk_exit(&iter); + spin_unlock_bh(&tbl->walk_lock); } /** @@ -675,15 +616,15 @@ static int table_path_del(struct mesh_table *tbl, { struct mesh_path *mpath; - rcu_read_lock(); + spin_lock_bh(&tbl->walk_lock); mpath = rhashtable_lookup_fast(&tbl->rhead, addr, mesh_rht_params); if (!mpath) { - rcu_read_unlock(); + spin_unlock_bh(&tbl->walk_lock); return -ENXIO; } __mesh_path_del(tbl, mpath); - rcu_read_unlock(); + spin_unlock_bh(&tbl->walk_lock); return 0; } @@ -854,28 +795,16 @@ void mesh_path_tbl_expire(struct ieee80211_sub_if_data *sdata, struct mesh_table *tbl) { struct mesh_path *mpath; - struct rhashtable_iter iter; - int ret; + struct hlist_node *n; - ret = rhashtable_walk_init(&tbl->rhead, &iter, GFP_KERNEL); - if (ret) - return; - - rhashtable_walk_start(&iter); - - while ((mpath = rhashtable_walk_next(&iter))) { - if (IS_ERR(mpath) && PTR_ERR(mpath) == -EAGAIN) - continue; - if (IS_ERR(mpath)) - break; + spin_lock_bh(&tbl->walk_lock); + hlist_for_each_entry_safe(mpath, n, &tbl->walk_head, walk_list) { if ((!(mpath->flags & MESH_PATH_RESOLVING)) && (!(mpath->flags & MESH_PATH_FIXED)) && time_after(jiffies, mpath->exp_time + MESH_PATH_EXPIRE)) __mesh_path_del(tbl, mpath); } - - rhashtable_walk_stop(&iter); - rhashtable_walk_exit(&iter); + spin_unlock_bh(&tbl->walk_lock); } void mesh_path_expire(struct ieee80211_sub_if_data *sdata) diff --git a/net/mac80211/mesh_plink.c b/net/mac80211/mesh_plink.c index 33055c8ed37e..8afd0ece94c9 100644 --- a/net/mac80211/mesh_plink.c +++ b/net/mac80211/mesh_plink.c @@ -1,5 +1,6 @@ /* * Copyright (c) 2008, 2009 open80211s Ltd. + * Copyright (C) 2019 Intel Corporation * Author: Luis Carlos Cobo <luisca@cozybit.com> * * This program is free software; you can redistribute it and/or modify @@ -1214,6 +1215,7 @@ void mesh_rx_plink_frame(struct ieee80211_sub_if_data *sdata, if (baselen > len) return; } - ieee802_11_parse_elems(baseaddr, len - baselen, true, &elems); + ieee802_11_parse_elems(baseaddr, len - baselen, true, &elems, + mgmt->bssid, NULL); mesh_process_plink_frame(sdata, mgmt, &elems, rx_status); } diff --git a/net/mac80211/mlme.c b/net/mac80211/mlme.c index 687821567287..2dbcf5d5512e 100644 --- a/net/mac80211/mlme.c +++ b/net/mac80211/mlme.c @@ -7,7 +7,7 @@ * Copyright 2007, Michael Wu <flamingice@sourmilk.net> * Copyright 2013-2014 Intel Mobile Communications GmbH * Copyright (C) 2015 - 2017 Intel Deutschland GmbH - * Copyright (C) 2018 Intel Corporation + * Copyright (C) 2018 - 2019 Intel Corporation * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 as @@ -644,7 +644,7 @@ static void ieee80211_send_assoc(struct ieee80211_sub_if_data *sdata) struct ieee80211_mgd_assoc_data *assoc_data = ifmgd->assoc_data; struct sk_buff *skb; struct ieee80211_mgmt *mgmt; - u8 *pos, qos_info; + u8 *pos, qos_info, *ie_start; size_t offset = 0, noffset; int i, count, rates_len, supp_rates_len, shift; u16 capab; @@ -752,6 +752,7 @@ static void ieee80211_send_assoc(struct ieee80211_sub_if_data *sdata) /* SSID */ pos = skb_put(skb, 2 + assoc_data->ssid_len); + ie_start = pos; *pos++ = WLAN_EID_SSID; *pos++ = assoc_data->ssid_len; memcpy(pos, assoc_data->ssid, assoc_data->ssid_len); @@ -813,6 +814,21 @@ static void ieee80211_send_assoc(struct ieee80211_sub_if_data *sdata) } } + /* Set MBSSID support for HE AP if needed */ + if (ieee80211_hw_check(&local->hw, SUPPORTS_ONLY_HE_MULTI_BSSID) && + !(ifmgd->flags & IEEE80211_STA_DISABLE_HE) && assoc_data->ie_len) { + struct element *elem; + + /* we know it's writable, cast away the const */ + elem = (void *)cfg80211_find_elem(WLAN_EID_EXT_CAPABILITY, + assoc_data->ie, + assoc_data->ie_len); + + /* We can probably assume both always true */ + if (elem && elem->datalen >= 3) + elem->data[2] |= WLAN_EXT_CAPA3_MULTI_BSSID_SUPPORT; + } + /* if present, add any custom IEs that go before HT */ if (assoc_data->ie_len) { static const u8 before_ht[] = { @@ -961,6 +977,11 @@ static void ieee80211_send_assoc(struct ieee80211_sub_if_data *sdata) return; } + pos = skb_tail_pointer(skb); + kfree(ifmgd->assoc_req_ies); + ifmgd->assoc_req_ies = kmemdup(ie_start, pos - ie_start, GFP_ATOMIC); + ifmgd->assoc_req_ies_len = pos - ie_start; + drv_mgd_prepare_tx(local, sdata, 0); IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT; @@ -1238,6 +1259,32 @@ static void ieee80211_chswitch_timer(struct timer_list *t) } static void +ieee80211_sta_abort_chanswitch(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + + if (!local->ops->abort_channel_switch) + return; + + mutex_lock(&local->mtx); + + mutex_lock(&local->chanctx_mtx); + ieee80211_vif_unreserve_chanctx(sdata); + mutex_unlock(&local->chanctx_mtx); + + if (sdata->csa_block_tx) + ieee80211_wake_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_CSA); + + sdata->csa_block_tx = false; + sdata->vif.csa_active = false; + + mutex_unlock(&local->mtx); + + drv_abort_channel_switch(sdata); +} + +static void ieee80211_sta_process_chanswitch(struct ieee80211_sub_if_data *sdata, u64 timestamp, u32 device_timestamp, struct ieee802_11_elems *elems, @@ -1261,19 +1308,36 @@ ieee80211_sta_process_chanswitch(struct ieee80211_sub_if_data *sdata, if (local->scanning) return; - /* disregard subsequent announcements if we are already processing */ - if (sdata->vif.csa_active) - return; - current_band = cbss->channel->band; res = ieee80211_parse_ch_switch_ie(sdata, elems, current_band, ifmgd->flags, ifmgd->associated->bssid, &csa_ie); - if (res < 0) + + if (!res) { + ch_switch.timestamp = timestamp; + ch_switch.device_timestamp = device_timestamp; + ch_switch.block_tx = beacon ? csa_ie.mode : 0; + ch_switch.chandef = csa_ie.chandef; + ch_switch.count = csa_ie.count; + ch_switch.delay = csa_ie.max_switch_time; + } + + if (res < 0) { ieee80211_queue_work(&local->hw, &ifmgd->csa_connection_drop_work); - if (res) return; + } + + if (beacon && sdata->vif.csa_active && !ifmgd->csa_waiting_bcn) { + if (res) + ieee80211_sta_abort_chanswitch(sdata); + else + drv_channel_switch_rx_beacon(sdata, &ch_switch); + return; + } else if (sdata->vif.csa_active || res) { + /* disregard subsequent announcements if already processing */ + return; + } if (!cfg80211_chandef_usable(local->hw.wiphy, &csa_ie.chandef, IEEE80211_CHAN_DISABLED)) { @@ -1289,7 +1353,8 @@ ieee80211_sta_process_chanswitch(struct ieee80211_sub_if_data *sdata, } if (cfg80211_chandef_identical(&csa_ie.chandef, - &sdata->vif.bss_conf.chandef)) { + &sdata->vif.bss_conf.chandef) && + (!csa_ie.mode || !beacon)) { if (ifmgd->csa_ignored_same_chan) return; sdata_info(sdata, @@ -1326,12 +1391,6 @@ ieee80211_sta_process_chanswitch(struct ieee80211_sub_if_data *sdata, goto drop_connection; } - ch_switch.timestamp = timestamp; - ch_switch.device_timestamp = device_timestamp; - ch_switch.block_tx = csa_ie.mode; - ch_switch.chandef = csa_ie.chandef; - ch_switch.count = csa_ie.count; - if (drv_pre_channel_switch(sdata, &ch_switch)) { sdata_info(sdata, "preparing for channel switch failed, disconnecting\n"); @@ -1350,7 +1409,7 @@ ieee80211_sta_process_chanswitch(struct ieee80211_sub_if_data *sdata, sdata->vif.csa_active = true; sdata->csa_chandef = csa_ie.chandef; - sdata->csa_block_tx = csa_ie.mode; + sdata->csa_block_tx = ch_switch.block_tx; ifmgd->csa_ignored_same_chan = false; if (sdata->csa_block_tx) @@ -1384,7 +1443,7 @@ ieee80211_sta_process_chanswitch(struct ieee80211_sub_if_data *sdata, * reset when the disconnection worker runs. */ sdata->vif.csa_active = true; - sdata->csa_block_tx = csa_ie.mode; + sdata->csa_block_tx = ch_switch.block_tx; ieee80211_queue_work(&local->hw, &ifmgd->csa_connection_drop_work); mutex_unlock(&local->chanctx_mtx); @@ -2762,7 +2821,8 @@ static void ieee80211_auth_challenge(struct ieee80211_sub_if_data *sdata, u32 tx_flags = 0; pos = mgmt->u.auth.variable; - ieee802_11_parse_elems(pos, len - (pos - (u8 *) mgmt), false, &elems); + ieee802_11_parse_elems(pos, len - (pos - (u8 *)mgmt), false, &elems, + mgmt->bssid, auth_data->bss->bssid); if (!elems.challenge) return; auth_data->expected_transaction = 4; @@ -3130,7 +3190,8 @@ static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata, } pos = mgmt->u.assoc_resp.variable; - ieee802_11_parse_elems(pos, len - (pos - (u8 *) mgmt), false, &elems); + ieee802_11_parse_elems(pos, len - (pos - (u8 *)mgmt), false, &elems, + mgmt->bssid, assoc_data->bss->bssid); if (!elems.supp_rates) { sdata_info(sdata, "no SuppRates element in AssocResp\n"); @@ -3167,7 +3228,9 @@ static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata, return false; ieee802_11_parse_elems(bss_ies->data, bss_ies->len, - false, &bss_elems); + false, &bss_elems, + mgmt->bssid, + assoc_data->bss->bssid); if (assoc_data->wmm && !elems.wmm_param && bss_elems.wmm_param) { elems.wmm_param = bss_elems.wmm_param; @@ -3304,6 +3367,14 @@ static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata, /* TODO: OPEN: what happens if BSS color disable is set? */ } + if (cbss->transmitted_bss) { + bss_conf->nontransmitted = true; + ether_addr_copy(bss_conf->transmitter_bssid, + cbss->transmitted_bss->bssid); + bss_conf->bssid_indicator = cbss->max_bssid_indicator; + bss_conf->bssid_index = cbss->bssid_index; + } + /* * Some APs, e.g. Netgear WNDR3700, report invalid HT operation data * in their association response, so ignore that data for our own @@ -3464,7 +3535,8 @@ static void ieee80211_rx_mgmt_assoc_resp(struct ieee80211_sub_if_data *sdata, return; pos = mgmt->u.assoc_resp.variable; - ieee802_11_parse_elems(pos, len - (pos - (u8 *) mgmt), false, &elems); + ieee802_11_parse_elems(pos, len - (pos - (u8 *)mgmt), false, &elems, + mgmt->bssid, assoc_data->bss->bssid); if (status_code == WLAN_STATUS_ASSOC_REJECTED_TEMPORARILY && elems.timeout_int && @@ -3516,13 +3588,13 @@ static void ieee80211_rx_mgmt_assoc_resp(struct ieee80211_sub_if_data *sdata, uapsd_queues |= ieee80211_ac_to_qos_mask[ac]; } - cfg80211_rx_assoc_resp(sdata->dev, bss, (u8 *)mgmt, len, uapsd_queues); + cfg80211_rx_assoc_resp(sdata->dev, bss, (u8 *)mgmt, len, uapsd_queues, + ifmgd->assoc_req_ies, ifmgd->assoc_req_ies_len); } static void ieee80211_rx_bss_info(struct ieee80211_sub_if_data *sdata, struct ieee80211_mgmt *mgmt, size_t len, - struct ieee80211_rx_status *rx_status, - struct ieee802_11_elems *elems) + struct ieee80211_rx_status *rx_status) { struct ieee80211_local *local = sdata->local; struct ieee80211_bss *bss; @@ -3534,8 +3606,7 @@ static void ieee80211_rx_bss_info(struct ieee80211_sub_if_data *sdata, if (!channel) return; - bss = ieee80211_bss_info_update(local, rx_status, mgmt, len, elems, - channel); + bss = ieee80211_bss_info_update(local, rx_status, mgmt, len, channel); if (bss) { sdata->vif.bss_conf.beacon_rate = bss->beacon_rate; ieee80211_rx_bss_put(local, bss); @@ -3550,7 +3621,6 @@ static void ieee80211_rx_mgmt_probe_resp(struct ieee80211_sub_if_data *sdata, struct ieee80211_if_managed *ifmgd; struct ieee80211_rx_status *rx_status = (void *) skb->cb; size_t baselen, len = skb->len; - struct ieee802_11_elems elems; ifmgd = &sdata->u.mgd; @@ -3563,10 +3633,7 @@ static void ieee80211_rx_mgmt_probe_resp(struct ieee80211_sub_if_data *sdata, if (baselen > len) return; - ieee802_11_parse_elems(mgmt->u.probe_resp.variable, len - baselen, - false, &elems); - - ieee80211_rx_bss_info(sdata, mgmt, len, rx_status, &elems); + ieee80211_rx_bss_info(sdata, mgmt, len, rx_status); if (ifmgd->associated && ether_addr_equal(mgmt->bssid, ifmgd->associated->bssid)) @@ -3693,6 +3760,16 @@ static void ieee80211_handle_beacon_sig(struct ieee80211_sub_if_data *sdata, } } +static bool ieee80211_rx_our_beacon(const u8 *tx_bssid, + struct cfg80211_bss *bss) +{ + if (ether_addr_equal(tx_bssid, bss->bssid)) + return true; + if (!bss->transmitted_bss) + return false; + return ether_addr_equal(tx_bssid, bss->transmitted_bss->bssid); +} + static void ieee80211_rx_mgmt_beacon(struct ieee80211_sub_if_data *sdata, struct ieee80211_mgmt *mgmt, size_t len, struct ieee80211_rx_status *rx_status) @@ -3734,15 +3811,16 @@ static void ieee80211_rx_mgmt_beacon(struct ieee80211_sub_if_data *sdata, rcu_read_unlock(); if (ifmgd->assoc_data && ifmgd->assoc_data->need_beacon && - ether_addr_equal(mgmt->bssid, ifmgd->assoc_data->bss->bssid)) { + ieee80211_rx_our_beacon(mgmt->bssid, ifmgd->assoc_data->bss)) { ieee802_11_parse_elems(mgmt->u.beacon.variable, - len - baselen, false, &elems); + len - baselen, false, &elems, + mgmt->bssid, + ifmgd->assoc_data->bss->bssid); - ieee80211_rx_bss_info(sdata, mgmt, len, rx_status, &elems); - if (elems.tim && !elems.parse_error) { - const struct ieee80211_tim_ie *tim_ie = elems.tim; - ifmgd->dtim_period = tim_ie->dtim_period; - } + ieee80211_rx_bss_info(sdata, mgmt, len, rx_status); + + if (elems.dtim_period) + ifmgd->dtim_period = elems.dtim_period; ifmgd->have_beacon = true; ifmgd->assoc_data->need_beacon = false; if (ieee80211_hw_check(&local->hw, TIMING_BEACON_ONLY)) { @@ -3750,12 +3828,17 @@ static void ieee80211_rx_mgmt_beacon(struct ieee80211_sub_if_data *sdata, le64_to_cpu(mgmt->u.beacon.timestamp); sdata->vif.bss_conf.sync_device_ts = rx_status->device_timestamp; - if (elems.tim) - sdata->vif.bss_conf.sync_dtim_count = - elems.tim->dtim_count; - else - sdata->vif.bss_conf.sync_dtim_count = 0; + sdata->vif.bss_conf.sync_dtim_count = elems.dtim_count; } + + if (elems.mbssid_config_ie) + bss_conf->profile_periodicity = + elems.mbssid_config_ie->profile_periodicity; + + if (elems.ext_capab_len >= 11 && + (elems.ext_capab[10] & WLAN_EXT_CAPA11_EMA_SUPPORT)) + bss_conf->ema_ap = true; + /* continue assoc process */ ifmgd->assoc_data->timeout = jiffies; ifmgd->assoc_data->timeout_started = true; @@ -3764,7 +3847,7 @@ static void ieee80211_rx_mgmt_beacon(struct ieee80211_sub_if_data *sdata, } if (!ifmgd->associated || - !ether_addr_equal(mgmt->bssid, ifmgd->associated->bssid)) + !ieee80211_rx_our_beacon(mgmt->bssid, ifmgd->associated)) return; bssid = ifmgd->associated->bssid; @@ -3787,7 +3870,8 @@ static void ieee80211_rx_mgmt_beacon(struct ieee80211_sub_if_data *sdata, ncrc = crc32_be(0, (void *)&mgmt->u.beacon.beacon_int, 4); ncrc = ieee802_11_parse_elems_crc(mgmt->u.beacon.variable, len - baselen, false, &elems, - care_about_ies, ncrc); + care_about_ies, ncrc, + mgmt->bssid, bssid); if (ieee80211_hw_check(&local->hw, PS_NULLFUNC_STACK) && ieee80211_check_tim(elems.tim, elems.tim_len, ifmgd->aid)) { @@ -3859,11 +3943,7 @@ static void ieee80211_rx_mgmt_beacon(struct ieee80211_sub_if_data *sdata, le64_to_cpu(mgmt->u.beacon.timestamp); sdata->vif.bss_conf.sync_device_ts = rx_status->device_timestamp; - if (elems.tim) - sdata->vif.bss_conf.sync_dtim_count = - elems.tim->dtim_count; - else - sdata->vif.bss_conf.sync_dtim_count = 0; + sdata->vif.bss_conf.sync_dtim_count = elems.dtim_count; } if (ncrc == ifmgd->beacon_crc && ifmgd->beacon_crc_valid) @@ -3871,7 +3951,7 @@ static void ieee80211_rx_mgmt_beacon(struct ieee80211_sub_if_data *sdata, ifmgd->beacon_crc = ncrc; ifmgd->beacon_crc_valid = true; - ieee80211_rx_bss_info(sdata, mgmt, len, rx_status, &elems); + ieee80211_rx_bss_info(sdata, mgmt, len, rx_status); ieee80211_sta_process_chanswitch(sdata, rx_status->mactime, rx_status->device_timestamp, @@ -3889,10 +3969,7 @@ static void ieee80211_rx_mgmt_beacon(struct ieee80211_sub_if_data *sdata, */ if (!ifmgd->have_beacon) { /* a few bogus AP send dtim_period = 0 or no TIM IE */ - if (elems.tim) - bss_conf->dtim_period = elems.tim->dtim_period ?: 1; - else - bss_conf->dtim_period = 1; + bss_conf->dtim_period = elems.dtim_period ?: 1; changed |= BSS_CHANGED_BEACON_INFO; ifmgd->have_beacon = true; @@ -3992,9 +4069,10 @@ void ieee80211_sta_rx_queued_mgmt(struct ieee80211_sub_if_data *sdata, if (ies_len < 0) break; + /* CSA IE cannot be overridden, no need for BSSID */ ieee802_11_parse_elems( mgmt->u.action.u.chan_switch.variable, - ies_len, true, &elems); + ies_len, true, &elems, mgmt->bssid, NULL); if (elems.parse_error) break; @@ -4011,9 +4089,13 @@ void ieee80211_sta_rx_queued_mgmt(struct ieee80211_sub_if_data *sdata, if (ies_len < 0) break; + /* + * extended CSA IE can't be overridden, no need for + * BSSID + */ ieee802_11_parse_elems( mgmt->u.action.u.ext_chan_switch.variable, - ies_len, true, &elems); + ies_len, true, &elems, mgmt->bssid, NULL); if (elems.parse_error) break; @@ -4754,6 +4836,40 @@ static int ieee80211_prep_channel(struct ieee80211_sub_if_data *sdata, return ret; } +static bool ieee80211_get_dtim(const struct cfg80211_bss_ies *ies, + u8 *dtim_count, u8 *dtim_period) +{ + const u8 *tim_ie = cfg80211_find_ie(WLAN_EID_TIM, ies->data, ies->len); + const u8 *idx_ie = cfg80211_find_ie(WLAN_EID_MULTI_BSSID_IDX, ies->data, + ies->len); + const struct ieee80211_tim_ie *tim = NULL; + const struct ieee80211_bssid_index *idx; + bool valid = tim_ie && tim_ie[1] >= 2; + + if (valid) + tim = (void *)(tim_ie + 2); + + if (dtim_count) + *dtim_count = valid ? tim->dtim_count : 0; + + if (dtim_period) + *dtim_period = valid ? tim->dtim_period : 0; + + /* Check if value is overridden by non-transmitted profile */ + if (!idx_ie || idx_ie[1] < 3) + return valid; + + idx = (void *)(idx_ie + 2); + + if (dtim_count) + *dtim_count = idx->dtim_count; + + if (dtim_period) + *dtim_period = idx->dtim_period; + + return true; +} + static int ieee80211_prep_connection(struct ieee80211_sub_if_data *sdata, struct cfg80211_bss *cbss, bool assoc, bool override) @@ -4845,17 +4961,13 @@ static int ieee80211_prep_connection(struct ieee80211_sub_if_data *sdata, rcu_read_lock(); ies = rcu_dereference(cbss->beacon_ies); if (ies) { - const u8 *tim_ie; - sdata->vif.bss_conf.sync_tsf = ies->tsf; sdata->vif.bss_conf.sync_device_ts = bss->device_ts_beacon; - tim_ie = cfg80211_find_ie(WLAN_EID_TIM, - ies->data, ies->len); - if (tim_ie && tim_ie[1] >= 2) - sdata->vif.bss_conf.sync_dtim_count = tim_ie[2]; - else - sdata->vif.bss_conf.sync_dtim_count = 0; + + ieee80211_get_dtim(ies, + &sdata->vif.bss_conf.sync_dtim_count, + NULL); } else if (!ieee80211_hw_check(&sdata->local->hw, TIMING_BEACON_ONLY)) { ies = rcu_dereference(cbss->proberesp_ies); @@ -5325,17 +5437,12 @@ int ieee80211_mgd_assoc(struct ieee80211_sub_if_data *sdata, assoc_data->timeout_started = true; assoc_data->need_beacon = true; } else if (beacon_ies) { - const u8 *tim_ie = cfg80211_find_ie(WLAN_EID_TIM, - beacon_ies->data, - beacon_ies->len); + const u8 *ie; u8 dtim_count = 0; - if (tim_ie && tim_ie[1] >= sizeof(struct ieee80211_tim_ie)) { - const struct ieee80211_tim_ie *tim; - tim = (void *)(tim_ie + 2); - ifmgd->dtim_period = tim->dtim_period; - dtim_count = tim->dtim_count; - } + ieee80211_get_dtim(beacon_ies, &dtim_count, + &ifmgd->dtim_period); + ifmgd->have_beacon = true; assoc_data->timeout = jiffies; assoc_data->timeout_started = true; @@ -5346,6 +5453,17 @@ int ieee80211_mgd_assoc(struct ieee80211_sub_if_data *sdata, bss->device_ts_beacon; sdata->vif.bss_conf.sync_dtim_count = dtim_count; } + + ie = cfg80211_find_ext_ie(WLAN_EID_EXT_MULTIPLE_BSSID_CONFIGURATION, + beacon_ies->data, beacon_ies->len); + if (ie && ie[1] >= 3) + sdata->vif.bss_conf.profile_periodicity = ie[4]; + + ie = cfg80211_find_ie(WLAN_EID_EXT_CAPABILITY, + beacon_ies->data, beacon_ies->len); + if (ie && ie[1] >= 11 && + (ie[10] & WLAN_EXT_CAPA11_EMA_SUPPORT)) + sdata->vif.bss_conf.ema_ap = true; } else { assoc_data->timeout = jiffies; assoc_data->timeout_started = true; @@ -5503,6 +5621,9 @@ void ieee80211_mgd_stop(struct ieee80211_sub_if_data *sdata) ifmgd->teardown_skb = NULL; ifmgd->orig_teardown_skb = NULL; } + kfree(ifmgd->assoc_req_ies); + ifmgd->assoc_req_ies = NULL; + ifmgd->assoc_req_ies_len = 0; spin_unlock_bh(&ifmgd->teardown_lock); del_timer_sync(&ifmgd->timer); sdata_unlock(sdata); diff --git a/net/mac80211/rc80211_minstrel_ht.c b/net/mac80211/rc80211_minstrel_ht.c index f466ec37d161..ccaf951e4e31 100644 --- a/net/mac80211/rc80211_minstrel_ht.c +++ b/net/mac80211/rc80211_minstrel_ht.c @@ -294,6 +294,15 @@ minstrel_get_ratestats(struct minstrel_ht_sta *mi, int index) return &mi->groups[index / MCS_GROUP_RATES].rates[index % MCS_GROUP_RATES]; } +static unsigned int +minstrel_ht_avg_ampdu_len(struct minstrel_ht_sta *mi) +{ + if (!mi->avg_ampdu_len) + return AVG_AMPDU_SIZE; + + return MINSTREL_TRUNC(mi->avg_ampdu_len); +} + /* * Return current throughput based on the average A-MPDU length, taking into * account the expected number of retransmissions and their expected length @@ -309,7 +318,7 @@ minstrel_ht_get_tp_avg(struct minstrel_ht_sta *mi, int group, int rate, return 0; if (group != MINSTREL_CCK_GROUP) - nsecs = 1000 * mi->overhead / MINSTREL_TRUNC(mi->avg_ampdu_len); + nsecs = 1000 * mi->overhead / minstrel_ht_avg_ampdu_len(mi); nsecs += minstrel_mcs_groups[group].duration[rate] << minstrel_mcs_groups[group].shift; @@ -503,8 +512,12 @@ minstrel_ht_update_stats(struct minstrel_priv *mp, struct minstrel_ht_sta *mi) u16 tmp_cck_tp_rate[MAX_THR_RATES], index; if (mi->ampdu_packets > 0) { - mi->avg_ampdu_len = minstrel_ewma(mi->avg_ampdu_len, - MINSTREL_FRAC(mi->ampdu_len, mi->ampdu_packets), EWMA_LEVEL); + if (!ieee80211_hw_check(mp->hw, TX_STATUS_NO_AMPDU_LEN)) + mi->avg_ampdu_len = minstrel_ewma(mi->avg_ampdu_len, + MINSTREL_FRAC(mi->ampdu_len, mi->ampdu_packets), + EWMA_LEVEL); + else + mi->avg_ampdu_len = 0; mi->ampdu_len = 0; mi->ampdu_packets = 0; } @@ -709,7 +722,9 @@ minstrel_ht_tx_status(void *priv, struct ieee80211_supported_band *sband, mi->ampdu_len += info->status.ampdu_len; if (!mi->sample_wait && !mi->sample_tries && mi->sample_count > 0) { - mi->sample_wait = 16 + 2 * MINSTREL_TRUNC(mi->avg_ampdu_len); + int avg_ampdu_len = minstrel_ht_avg_ampdu_len(mi); + + mi->sample_wait = 16 + 2 * avg_ampdu_len; mi->sample_tries = 1; mi->sample_count--; } @@ -777,7 +792,7 @@ minstrel_calc_retransmit(struct minstrel_priv *mp, struct minstrel_ht_sta *mi, unsigned int cw = mp->cw_min; unsigned int ctime = 0; unsigned int t_slot = 9; /* FIXME */ - unsigned int ampdu_len = MINSTREL_TRUNC(mi->avg_ampdu_len); + unsigned int ampdu_len = minstrel_ht_avg_ampdu_len(mi); unsigned int overhead = 0, overhead_rtscts = 0; mrs = minstrel_get_ratestats(mi, index); diff --git a/net/mac80211/rc80211_minstrel_ht_debugfs.c b/net/mac80211/rc80211_minstrel_ht_debugfs.c index 57820a5f2c16..31641d0b0f5c 100644 --- a/net/mac80211/rc80211_minstrel_ht_debugfs.c +++ b/net/mac80211/rc80211_minstrel_ht_debugfs.c @@ -160,9 +160,10 @@ minstrel_ht_stats_open(struct inode *inode, struct file *file) "lookaround %d\n", max(0, (int) mi->total_packets - (int) mi->sample_packets), mi->sample_packets); - p += sprintf(p, "Average # of aggregated frames per A-MPDU: %d.%d\n", - MINSTREL_TRUNC(mi->avg_ampdu_len), - MINSTREL_TRUNC(mi->avg_ampdu_len * 10) % 10); + if (mi->avg_ampdu_len) + p += sprintf(p, "Average # of aggregated frames per A-MPDU: %d.%d\n", + MINSTREL_TRUNC(mi->avg_ampdu_len), + MINSTREL_TRUNC(mi->avg_ampdu_len * 10) % 10); ms->len = p - ms->buf; WARN_ON(ms->len + sizeof(*ms) > 32768); diff --git a/net/mac80211/rx.c b/net/mac80211/rx.c index 45aad3d3108c..7f8d93401ce0 100644 --- a/net/mac80211/rx.c +++ b/net/mac80211/rx.c @@ -5,7 +5,7 @@ * Copyright 2007-2010 Johannes Berg <johannes@sipsolutions.net> * Copyright 2013-2014 Intel Mobile Communications GmbH * Copyright(c) 2015 - 2017 Intel Deutschland GmbH - * Copyright (C) 2018 Intel Corporation + * Copyright (C) 2018-2019 Intel Corporation * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 as @@ -208,7 +208,24 @@ ieee80211_rx_radiotap_hdrlen(struct ieee80211_local *local, } if (status->flag & RX_FLAG_RADIOTAP_VENDOR_DATA) { - struct ieee80211_vendor_radiotap *rtap = (void *)skb->data; + struct ieee80211_vendor_radiotap *rtap; + int vendor_data_offset = 0; + + /* + * The position to look at depends on the existence (or non- + * existence) of other elements, so take that into account... + */ + if (status->flag & RX_FLAG_RADIOTAP_HE) + vendor_data_offset += + sizeof(struct ieee80211_radiotap_he); + if (status->flag & RX_FLAG_RADIOTAP_HE_MU) + vendor_data_offset += + sizeof(struct ieee80211_radiotap_he_mu); + if (status->flag & RX_FLAG_RADIOTAP_LSIG) + vendor_data_offset += + sizeof(struct ieee80211_radiotap_lsig); + + rtap = (void *)&skb->data[vendor_data_offset]; /* alignment for fixed 6-byte vendor data header */ len = ALIGN(len, 2); @@ -231,7 +248,7 @@ static void ieee80211_handle_mu_mimo_mon(struct ieee80211_sub_if_data *sdata, struct ieee80211_hdr_3addr hdr; u8 category; u8 action_code; - } __packed action; + } __packed __aligned(2) action; if (!sdata) return; @@ -2644,6 +2661,7 @@ ieee80211_rx_h_mesh_fwding(struct ieee80211_rx_data *rx) struct ieee80211_sub_if_data *sdata = rx->sdata; struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; u16 ac, q, hdrlen; + int tailroom = 0; hdr = (struct ieee80211_hdr *) skb->data; hdrlen = ieee80211_hdrlen(hdr->frame_control); @@ -2723,15 +2741,21 @@ ieee80211_rx_h_mesh_fwding(struct ieee80211_rx_data *rx) skb_set_queue_mapping(skb, q); if (!--mesh_hdr->ttl) { - IEEE80211_IFSTA_MESH_CTR_INC(ifmsh, dropped_frames_ttl); + if (!is_multicast_ether_addr(hdr->addr1)) + IEEE80211_IFSTA_MESH_CTR_INC(ifmsh, + dropped_frames_ttl); goto out; } if (!ifmsh->mshcfg.dot11MeshForwarding) goto out; + if (sdata->crypto_tx_tailroom_needed_cnt) + tailroom = IEEE80211_ENCRYPT_TAILROOM; + fwd_skb = skb_copy_expand(skb, local->tx_headroom + - sdata->encrypt_headroom, 0, GFP_ATOMIC); + sdata->encrypt_headroom, + tailroom, GFP_ATOMIC); if (!fwd_skb) goto out; diff --git a/net/mac80211/scan.c b/net/mac80211/scan.c index 95413413f98c..0cf066700623 100644 --- a/net/mac80211/scan.c +++ b/net/mac80211/scan.c @@ -8,6 +8,7 @@ * Copyright 2007, Michael Wu <flamingice@sourmilk.net> * Copyright 2013-2015 Intel Mobile Communications GmbH * Copyright 2016-2017 Intel Deutschland GmbH + * Copyright (C) 2018-2019 Intel Corporation * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 as @@ -57,62 +58,14 @@ static bool is_uapsd_supported(struct ieee802_11_elems *elems) return qos_info & IEEE80211_WMM_IE_AP_QOSINFO_UAPSD; } -struct ieee80211_bss * -ieee80211_bss_info_update(struct ieee80211_local *local, - struct ieee80211_rx_status *rx_status, - struct ieee80211_mgmt *mgmt, size_t len, - struct ieee802_11_elems *elems, - struct ieee80211_channel *channel) +static void +ieee80211_update_bss_from_elems(struct ieee80211_local *local, + struct ieee80211_bss *bss, + struct ieee802_11_elems *elems, + struct ieee80211_rx_status *rx_status, + bool beacon) { - bool beacon = ieee80211_is_beacon(mgmt->frame_control); - struct cfg80211_bss *cbss; - struct ieee80211_bss *bss; int clen, srlen; - struct cfg80211_inform_bss bss_meta = { - .boottime_ns = rx_status->boottime_ns, - }; - bool signal_valid; - struct ieee80211_sub_if_data *scan_sdata; - - if (rx_status->flag & RX_FLAG_NO_SIGNAL_VAL) - bss_meta.signal = 0; /* invalid signal indication */ - else if (ieee80211_hw_check(&local->hw, SIGNAL_DBM)) - bss_meta.signal = rx_status->signal * 100; - else if (ieee80211_hw_check(&local->hw, SIGNAL_UNSPEC)) - bss_meta.signal = (rx_status->signal * 100) / local->hw.max_signal; - - bss_meta.scan_width = NL80211_BSS_CHAN_WIDTH_20; - if (rx_status->bw == RATE_INFO_BW_5) - bss_meta.scan_width = NL80211_BSS_CHAN_WIDTH_5; - else if (rx_status->bw == RATE_INFO_BW_10) - bss_meta.scan_width = NL80211_BSS_CHAN_WIDTH_10; - - bss_meta.chan = channel; - - rcu_read_lock(); - scan_sdata = rcu_dereference(local->scan_sdata); - if (scan_sdata && scan_sdata->vif.type == NL80211_IFTYPE_STATION && - scan_sdata->vif.bss_conf.assoc && - ieee80211_have_rx_timestamp(rx_status)) { - bss_meta.parent_tsf = - ieee80211_calculate_rx_timestamp(local, rx_status, - len + FCS_LEN, 24); - ether_addr_copy(bss_meta.parent_bssid, - scan_sdata->vif.bss_conf.bssid); - } - rcu_read_unlock(); - - cbss = cfg80211_inform_bss_frame_data(local->hw.wiphy, &bss_meta, - mgmt, len, GFP_ATOMIC); - if (!cbss) - return NULL; - /* In case the signal is invalid update the status */ - signal_valid = abs(channel->center_freq - cbss->channel->center_freq) - <= local->hw.wiphy->max_adj_channel_rssi_comp; - if (!signal_valid) - rx_status->flag |= RX_FLAG_NO_SIGNAL_VAL; - - bss = (void *)cbss->priv; if (beacon) bss->device_ts_beacon = rx_status->device_timestamp; @@ -182,6 +135,89 @@ ieee80211_bss_info_update(struct ieee80211_local *local, bss->beacon_rate = &sband->bitrates[rx_status->rate_idx]; } +} + +struct ieee80211_bss * +ieee80211_bss_info_update(struct ieee80211_local *local, + struct ieee80211_rx_status *rx_status, + struct ieee80211_mgmt *mgmt, size_t len, + struct ieee80211_channel *channel) +{ + bool beacon = ieee80211_is_beacon(mgmt->frame_control); + struct cfg80211_bss *cbss, *non_tx_cbss; + struct ieee80211_bss *bss, *non_tx_bss; + struct cfg80211_inform_bss bss_meta = { + .boottime_ns = rx_status->boottime_ns, + }; + bool signal_valid; + struct ieee80211_sub_if_data *scan_sdata; + struct ieee802_11_elems elems; + size_t baselen; + u8 *elements; + + if (rx_status->flag & RX_FLAG_NO_SIGNAL_VAL) + bss_meta.signal = 0; /* invalid signal indication */ + else if (ieee80211_hw_check(&local->hw, SIGNAL_DBM)) + bss_meta.signal = rx_status->signal * 100; + else if (ieee80211_hw_check(&local->hw, SIGNAL_UNSPEC)) + bss_meta.signal = (rx_status->signal * 100) / local->hw.max_signal; + + bss_meta.scan_width = NL80211_BSS_CHAN_WIDTH_20; + if (rx_status->bw == RATE_INFO_BW_5) + bss_meta.scan_width = NL80211_BSS_CHAN_WIDTH_5; + else if (rx_status->bw == RATE_INFO_BW_10) + bss_meta.scan_width = NL80211_BSS_CHAN_WIDTH_10; + + bss_meta.chan = channel; + + rcu_read_lock(); + scan_sdata = rcu_dereference(local->scan_sdata); + if (scan_sdata && scan_sdata->vif.type == NL80211_IFTYPE_STATION && + scan_sdata->vif.bss_conf.assoc && + ieee80211_have_rx_timestamp(rx_status)) { + bss_meta.parent_tsf = + ieee80211_calculate_rx_timestamp(local, rx_status, + len + FCS_LEN, 24); + ether_addr_copy(bss_meta.parent_bssid, + scan_sdata->vif.bss_conf.bssid); + } + rcu_read_unlock(); + + cbss = cfg80211_inform_bss_frame_data(local->hw.wiphy, &bss_meta, + mgmt, len, GFP_ATOMIC); + if (!cbss) + return NULL; + + if (ieee80211_is_probe_resp(mgmt->frame_control)) { + elements = mgmt->u.probe_resp.variable; + baselen = offsetof(struct ieee80211_mgmt, + u.probe_resp.variable); + } else { + baselen = offsetof(struct ieee80211_mgmt, u.beacon.variable); + elements = mgmt->u.beacon.variable; + } + + if (baselen > len) + return NULL; + + ieee802_11_parse_elems(elements, len - baselen, false, &elems, + mgmt->bssid, cbss->bssid); + + /* In case the signal is invalid update the status */ + signal_valid = abs(channel->center_freq - cbss->channel->center_freq) + <= local->hw.wiphy->max_adj_channel_rssi_comp; + if (!signal_valid) + rx_status->flag |= RX_FLAG_NO_SIGNAL_VAL; + + bss = (void *)cbss->priv; + ieee80211_update_bss_from_elems(local, bss, &elems, rx_status, beacon); + + list_for_each_entry(non_tx_cbss, &cbss->nontrans_list, nontrans_list) { + non_tx_bss = (void *)non_tx_cbss->priv; + + ieee80211_update_bss_from_elems(local, non_tx_bss, &elems, + rx_status, beacon); + } return bss; } @@ -206,10 +242,7 @@ void ieee80211_scan_rx(struct ieee80211_local *local, struct sk_buff *skb) struct ieee80211_sub_if_data *sdata1, *sdata2; struct ieee80211_mgmt *mgmt = (void *)skb->data; struct ieee80211_bss *bss; - u8 *elements; struct ieee80211_channel *channel; - size_t baselen; - struct ieee802_11_elems elems; if (skb->len < 24 || (!ieee80211_is_probe_resp(mgmt->frame_control) && @@ -244,26 +277,15 @@ void ieee80211_scan_rx(struct ieee80211_local *local, struct sk_buff *skb) !ieee80211_scan_accept_presp(sdata2, sched_scan_req_flags, mgmt->da)) return; - - elements = mgmt->u.probe_resp.variable; - baselen = offsetof(struct ieee80211_mgmt, u.probe_resp.variable); - } else { - baselen = offsetof(struct ieee80211_mgmt, u.beacon.variable); - elements = mgmt->u.beacon.variable; } - if (baselen > skb->len) - return; - - ieee802_11_parse_elems(elements, skb->len - baselen, false, &elems); - channel = ieee80211_get_channel(local->hw.wiphy, rx_status->freq); if (!channel || channel->flags & IEEE80211_CHAN_DISABLED) return; bss = ieee80211_bss_info_update(local, rx_status, - mgmt, skb->len, &elems, + mgmt, skb->len, channel); if (bss) ieee80211_rx_bss_put(local, bss); diff --git a/net/mac80211/spectmgmt.c b/net/mac80211/spectmgmt.c index 4e4902bdbef8..3c644f14dd59 100644 --- a/net/mac80211/spectmgmt.c +++ b/net/mac80211/spectmgmt.c @@ -177,6 +177,12 @@ int ieee80211_parse_ch_switch_ie(struct ieee80211_sub_if_data *sdata, csa_ie->chandef = new_vht_chandef; } + if (elems->max_channel_switch_time) + csa_ie->max_switch_time = + (elems->max_channel_switch_time[0] << 0) | + (elems->max_channel_switch_time[1] << 8) | + (elems->max_channel_switch_time[2] << 16); + return 0; } diff --git a/net/mac80211/sta_info.c b/net/mac80211/sta_info.c index c4a8f115ed33..11f058987a54 100644 --- a/net/mac80211/sta_info.c +++ b/net/mac80211/sta_info.c @@ -90,7 +90,6 @@ static void __cleanup_single_sta(struct sta_info *sta) struct tid_ampdu_tx *tid_tx; struct ieee80211_sub_if_data *sdata = sta->sdata; struct ieee80211_local *local = sdata->local; - struct fq *fq = &local->fq; struct ps_data *ps; if (test_sta_flag(sta, WLAN_STA_PS_STA) || @@ -120,9 +119,7 @@ static void __cleanup_single_sta(struct sta_info *sta) txqi = to_txq_info(sta->sta.txq[i]); - spin_lock_bh(&fq->lock); ieee80211_txq_purge(local, txqi); - spin_unlock_bh(&fq->lock); } } @@ -387,9 +384,12 @@ struct sta_info *sta_info_alloc(struct ieee80211_sub_if_data *sdata, if (sta_prepare_rate_control(local, sta, gfp)) goto free_txq; + sta->airtime_weight = IEEE80211_DEFAULT_AIRTIME_WEIGHT; + for (i = 0; i < IEEE80211_NUM_ACS; i++) { skb_queue_head_init(&sta->ps_tx_buf[i]); skb_queue_head_init(&sta->tx_filtered[i]); + sta->airtime[i].deficit = sta->airtime_weight; } for (i = 0; i < IEEE80211_NUM_TIDS; i++) @@ -1249,7 +1249,7 @@ void ieee80211_sta_ps_deliver_wakeup(struct sta_info *sta) if (!sta->sta.txq[i] || !txq_has_queue(sta->sta.txq[i])) continue; - drv_wake_tx_queue(local, to_txq_info(sta->sta.txq[i])); + schedule_and_wake_txq(local, to_txq_info(sta->sta.txq[i])); } skb_queue_head_init(&pending); @@ -1826,6 +1826,27 @@ void ieee80211_sta_set_buffered(struct ieee80211_sta *pubsta, } EXPORT_SYMBOL(ieee80211_sta_set_buffered); +void ieee80211_sta_register_airtime(struct ieee80211_sta *pubsta, u8 tid, + u32 tx_airtime, u32 rx_airtime) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + struct ieee80211_local *local = sta->sdata->local; + u8 ac = ieee80211_ac_from_tid(tid); + u32 airtime = 0; + + if (sta->local->airtime_flags & AIRTIME_USE_TX) + airtime += tx_airtime; + if (sta->local->airtime_flags & AIRTIME_USE_RX) + airtime += rx_airtime; + + spin_lock_bh(&local->active_txq_lock[ac]); + sta->airtime[ac].tx_airtime += tx_airtime; + sta->airtime[ac].rx_airtime += rx_airtime; + sta->airtime[ac].deficit -= airtime; + spin_unlock_bh(&local->active_txq_lock[ac]); +} +EXPORT_SYMBOL(ieee80211_sta_register_airtime); + int sta_info_move_state(struct sta_info *sta, enum ieee80211_sta_state new_state) { @@ -2188,6 +2209,23 @@ void sta_set_sinfo(struct sta_info *sta, struct station_info *sinfo, sinfo->filled |= BIT_ULL(NL80211_STA_INFO_TX_FAILED); } + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_RX_DURATION))) { + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) + sinfo->rx_duration += sta->airtime[ac].rx_airtime; + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_RX_DURATION); + } + + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_TX_DURATION))) { + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) + sinfo->tx_duration += sta->airtime[ac].tx_airtime; + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_TX_DURATION); + } + + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_AIRTIME_WEIGHT))) { + sinfo->airtime_weight = sta->airtime_weight; + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_AIRTIME_WEIGHT); + } + sinfo->rx_dropped_misc = sta->rx_stats.dropped; if (sta->pcpu_rx_stats) { for_each_possible_cpu(cpu) { diff --git a/net/mac80211/sta_info.h b/net/mac80211/sta_info.h index 8eb29041be54..71f7e4973329 100644 --- a/net/mac80211/sta_info.h +++ b/net/mac80211/sta_info.h @@ -127,6 +127,16 @@ enum ieee80211_agg_stop_reason { AGG_STOP_DESTROY_STA, }; +/* Debugfs flags to enable/disable use of RX/TX airtime in scheduler */ +#define AIRTIME_USE_TX BIT(0) +#define AIRTIME_USE_RX BIT(1) + +struct airtime_info { + u64 rx_airtime; + u64 tx_airtime; + s64 deficit; +}; + struct sta_info; /** @@ -343,6 +353,7 @@ struct ieee80211_fast_rx { /* we use only values in the range 0-100, so pick a large precision */ DECLARE_EWMA(mesh_fail_avg, 20, 8) +DECLARE_EWMA(mesh_tx_rate_avg, 8, 16) /** * struct mesh_sta - mesh STA information @@ -366,6 +377,7 @@ DECLARE_EWMA(mesh_fail_avg, 20, 8) * processed * @connected_to_gate: true if mesh STA has a path to a mesh gate * @fail_avg: moving percentage of failed MSDUs + * @tx_rate_avg: moving average of tx bitrate */ struct mesh_sta { struct timer_list plink_timer; @@ -394,6 +406,8 @@ struct mesh_sta { /* moving percentage of failed MSDUs */ struct ewma_mesh_fail_avg fail_avg; + /* moving average of tx bitrate */ + struct ewma_mesh_tx_rate_avg tx_rate_avg; }; DECLARE_EWMA(signal, 10, 8) @@ -459,6 +473,9 @@ struct ieee80211_sta_rx_stats { * @last_seq_ctrl: last received seq/frag number from this STA (per TID * plus one for non-QoS frames) * @tid_seq: per-TID sequence numbers for sending to this STA + * @airtime: per-AC struct airtime_info describing airtime statistics for this + * station + * @airtime_weight: station weight for airtime fairness calculation purposes * @ampdu_mlme: A-MPDU state machine state * @mesh: mesh STA information * @debugfs_dir: debug filesystem directory dentry @@ -480,10 +497,28 @@ struct ieee80211_sta_rx_stats { * @tdls_chandef: a TDLS peer can have a wider chandef that is compatible to * the BSS one. * @tx_stats: TX statistics + * @tx_stats.packets: # of packets transmitted + * @tx_stats.bytes: # of bytes in all packets transmitted + * @tx_stats.last_rate: last TX rate + * @tx_stats.msdu: # of transmitted MSDUs per TID * @rx_stats: RX statistics + * @rx_stats_avg: averaged RX statistics + * @rx_stats_avg.signal: averaged signal + * @rx_stats_avg.chain_signal: averaged per-chain signal * @pcpu_rx_stats: per-CPU RX statistics, assigned only if the driver needs * this (by advertising the USES_RSS hw flag) * @status_stats: TX status statistics + * @status_stats.filtered: # of filtered frames + * @status_stats.retry_failed: # of frames that failed after retry + * @status_stats.retry_count: # of retries attempted + * @status_stats.lost_packets: # of lost packets + * @status_stats.last_tdls_pkt_time: timestamp of last TDLS packet + * @status_stats.msdu_retries: # of MSDU retries + * @status_stats.msdu_failed: # of failed MSDUs + * @status_stats.last_ack: last ack timestamp (jiffies) + * @status_stats.last_ack_signal: last ACK signal + * @status_stats.ack_signal_filled: last ACK signal validity + * @status_stats.avg_ack_signal: average ACK signal */ struct sta_info { /* General information, mostly static */ @@ -565,6 +600,9 @@ struct sta_info { } tx_stats; u16 tid_seq[IEEE80211_QOS_CTL_TID_MASK + 1]; + struct airtime_info airtime[IEEE80211_NUM_ACS]; + u16 airtime_weight; + /* * Aggregation information, locked with lock. */ diff --git a/net/mac80211/status.c b/net/mac80211/status.c index 3f0b96e1e02f..5b9952b1caf3 100644 --- a/net/mac80211/status.c +++ b/net/mac80211/status.c @@ -823,6 +823,12 @@ static void __ieee80211_tx_status(struct ieee80211_hw *hw, ieee80211_sta_tx_notify(sta->sdata, (void *) skb->data, acked, info->status.tx_time); + if (info->status.tx_time && + wiphy_ext_feature_isset(local->hw.wiphy, + NL80211_EXT_FEATURE_AIRTIME_FAIRNESS)) + ieee80211_sta_register_airtime(&sta->sta, tid, + info->status.tx_time, 0); + if (ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)) { if (info->flags & IEEE80211_TX_STAT_ACK) { if (sta->status_stats.lost_packets) diff --git a/net/mac80211/tdls.c b/net/mac80211/tdls.c index 6c647f425e05..d30690d79a58 100644 --- a/net/mac80211/tdls.c +++ b/net/mac80211/tdls.c @@ -5,6 +5,7 @@ * Copyright 2014, Intel Corporation * Copyright 2014 Intel Mobile Communications GmbH * Copyright 2015 - 2016 Intel Deutschland GmbH + * Copyright (C) 2019 Intel Corporation * * This file is GPLv2 as found in COPYING. */ @@ -1716,7 +1717,8 @@ ieee80211_process_tdls_channel_switch_resp(struct ieee80211_sub_if_data *sdata, } ieee802_11_parse_elems(tf->u.chan_switch_resp.variable, - skb->len - baselen, false, &elems); + skb->len - baselen, false, &elems, + NULL, NULL); if (elems.parse_error) { tdls_dbg(sdata, "Invalid IEs in TDLS channel switch resp\n"); ret = -EINVAL; @@ -1828,7 +1830,7 @@ ieee80211_process_tdls_channel_switch_req(struct ieee80211_sub_if_data *sdata, } ieee802_11_parse_elems(tf->u.chan_switch_req.variable, - skb->len - baselen, false, &elems); + skb->len - baselen, false, &elems, NULL, NULL); if (elems.parse_error) { tdls_dbg(sdata, "Invalid IEs in TDLS channel switch req\n"); return -EINVAL; diff --git a/net/mac80211/trace.h b/net/mac80211/trace.h index 35ea0dcb55e6..8ba70d26b82e 100644 --- a/net/mac80211/trace.h +++ b/net/mac80211/trace.h @@ -1,8 +1,8 @@ /* SPDX-License-Identifier: GPL-2.0 */ /* * Portions of this file -* Copyright(c) 2016 Intel Deutschland GmbH -* Copyright (C) 2018 Intel Corporation +* Copyright(c) 2016-2017 Intel Deutschland GmbH +* Copyright (C) 2018 - 2019 Intel Corporation */ #if !defined(__MAC80211_DRIVER_TRACE) || defined(TRACE_HEADER_MULTI_READ) @@ -2452,6 +2452,48 @@ DEFINE_EVENT(local_sdata_evt, drv_post_channel_switch, TP_ARGS(local, sdata) ); +DEFINE_EVENT(local_sdata_evt, drv_abort_channel_switch, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +TRACE_EVENT(drv_channel_switch_rx_beacon, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_channel_switch *ch_switch), + + TP_ARGS(local, sdata, ch_switch), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + CHANDEF_ENTRY + __field(u64, timestamp) + __field(u32, device_timestamp) + __field(bool, block_tx) + __field(u8, count) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + CHANDEF_ASSIGN(&ch_switch->chandef) + __entry->timestamp = ch_switch->timestamp; + __entry->device_timestamp = ch_switch->device_timestamp; + __entry->block_tx = ch_switch->block_tx; + __entry->count = ch_switch->count; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT + " received a channel switch beacon to " + CHANDEF_PR_FMT " count:%d block_tx:%d timestamp:%llu", + LOCAL_PR_ARG, VIF_PR_ARG, CHANDEF_PR_ARG, __entry->count, + __entry->block_tx, __entry->timestamp + ) +); + TRACE_EVENT(drv_get_txpower, TP_PROTO(struct ieee80211_local *local, struct ieee80211_sub_if_data *sdata, diff --git a/net/mac80211/tx.c b/net/mac80211/tx.c index f170d6c6629a..8a49a74c0a37 100644 --- a/net/mac80211/tx.c +++ b/net/mac80211/tx.c @@ -1449,6 +1449,7 @@ void ieee80211_txq_init(struct ieee80211_sub_if_data *sdata, codel_vars_init(&txqi->def_cvars); codel_stats_init(&txqi->cstats); __skb_queue_head_init(&txqi->frags); + INIT_LIST_HEAD(&txqi->schedule_order); txqi->txq.vif = &sdata->vif; @@ -1487,8 +1488,14 @@ void ieee80211_txq_purge(struct ieee80211_local *local, struct fq *fq = &local->fq; struct fq_tin *tin = &txqi->tin; + spin_lock_bh(&fq->lock); fq_tin_reset(fq, tin, fq_skb_free_func); ieee80211_purge_tx_queue(&local->hw, &txqi->frags); + spin_unlock_bh(&fq->lock); + + spin_lock_bh(&local->active_txq_lock[txqi->txq.ac]); + list_del_init(&txqi->schedule_order); + spin_unlock_bh(&local->active_txq_lock[txqi->txq.ac]); } void ieee80211_txq_set_params(struct ieee80211_local *local) @@ -1605,7 +1612,7 @@ static bool ieee80211_queue_skb(struct ieee80211_local *local, ieee80211_txq_enqueue(local, txqi, skb); spin_unlock_bh(&fq->lock); - drv_wake_tx_queue(local, txqi); + schedule_and_wake_txq(local, txqi); return true; } @@ -1938,9 +1945,16 @@ static int ieee80211_skb_resize(struct ieee80211_sub_if_data *sdata, int head_need, bool may_encrypt) { struct ieee80211_local *local = sdata->local; + struct ieee80211_hdr *hdr; + bool enc_tailroom; int tail_need = 0; - if (may_encrypt && sdata->crypto_tx_tailroom_needed_cnt) { + hdr = (struct ieee80211_hdr *) skb->data; + enc_tailroom = may_encrypt && + (sdata->crypto_tx_tailroom_needed_cnt || + ieee80211_is_mgmt(hdr->frame_control)); + + if (enc_tailroom) { tail_need = IEEE80211_ENCRYPT_TAILROOM; tail_need -= skb_tailroom(skb); tail_need = max_t(int, tail_need, 0); @@ -1948,8 +1962,7 @@ static int ieee80211_skb_resize(struct ieee80211_sub_if_data *sdata, if (skb_cloned(skb) && (!ieee80211_hw_check(&local->hw, SUPPORTS_CLONED_SKBS) || - !skb_clone_writable(skb, ETH_HLEN) || - (may_encrypt && sdata->crypto_tx_tailroom_needed_cnt))) + !skb_clone_writable(skb, ETH_HLEN) || enc_tailroom)) I802_DEBUG_INC(local->tx_expand_skb_head_cloned); else if (head_need || tail_need) I802_DEBUG_INC(local->tx_expand_skb_head); @@ -3630,6 +3643,151 @@ out: } EXPORT_SYMBOL(ieee80211_tx_dequeue); +struct ieee80211_txq *ieee80211_next_txq(struct ieee80211_hw *hw, u8 ac) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct txq_info *txqi = NULL; + + lockdep_assert_held(&local->active_txq_lock[ac]); + + begin: + txqi = list_first_entry_or_null(&local->active_txqs[ac], + struct txq_info, + schedule_order); + if (!txqi) + return NULL; + + if (txqi->txq.sta) { + struct sta_info *sta = container_of(txqi->txq.sta, + struct sta_info, sta); + + if (sta->airtime[txqi->txq.ac].deficit < 0) { + sta->airtime[txqi->txq.ac].deficit += + sta->airtime_weight; + list_move_tail(&txqi->schedule_order, + &local->active_txqs[txqi->txq.ac]); + goto begin; + } + } + + + if (txqi->schedule_round == local->schedule_round[ac]) + return NULL; + + list_del_init(&txqi->schedule_order); + txqi->schedule_round = local->schedule_round[ac]; + return &txqi->txq; +} +EXPORT_SYMBOL(ieee80211_next_txq); + +void ieee80211_return_txq(struct ieee80211_hw *hw, + struct ieee80211_txq *txq) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct txq_info *txqi = to_txq_info(txq); + + lockdep_assert_held(&local->active_txq_lock[txq->ac]); + + if (list_empty(&txqi->schedule_order) && + (!skb_queue_empty(&txqi->frags) || txqi->tin.backlog_packets)) { + /* If airtime accounting is active, always enqueue STAs at the + * head of the list to ensure that they only get moved to the + * back by the airtime DRR scheduler once they have a negative + * deficit. A station that already has a negative deficit will + * get immediately moved to the back of the list on the next + * call to ieee80211_next_txq(). + */ + if (txqi->txq.sta && + wiphy_ext_feature_isset(local->hw.wiphy, + NL80211_EXT_FEATURE_AIRTIME_FAIRNESS)) + list_add(&txqi->schedule_order, + &local->active_txqs[txq->ac]); + else + list_add_tail(&txqi->schedule_order, + &local->active_txqs[txq->ac]); + } +} +EXPORT_SYMBOL(ieee80211_return_txq); + +void ieee80211_schedule_txq(struct ieee80211_hw *hw, + struct ieee80211_txq *txq) + __acquires(txq_lock) __releases(txq_lock) +{ + struct ieee80211_local *local = hw_to_local(hw); + + spin_lock_bh(&local->active_txq_lock[txq->ac]); + ieee80211_return_txq(hw, txq); + spin_unlock_bh(&local->active_txq_lock[txq->ac]); +} +EXPORT_SYMBOL(ieee80211_schedule_txq); + +bool ieee80211_txq_may_transmit(struct ieee80211_hw *hw, + struct ieee80211_txq *txq) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct txq_info *iter, *tmp, *txqi = to_txq_info(txq); + struct sta_info *sta; + u8 ac = txq->ac; + + lockdep_assert_held(&local->active_txq_lock[ac]); + + if (!txqi->txq.sta) + goto out; + + if (list_empty(&txqi->schedule_order)) + goto out; + + list_for_each_entry_safe(iter, tmp, &local->active_txqs[ac], + schedule_order) { + if (iter == txqi) + break; + + if (!iter->txq.sta) { + list_move_tail(&iter->schedule_order, + &local->active_txqs[ac]); + continue; + } + sta = container_of(iter->txq.sta, struct sta_info, sta); + if (sta->airtime[ac].deficit < 0) + sta->airtime[ac].deficit += sta->airtime_weight; + list_move_tail(&iter->schedule_order, &local->active_txqs[ac]); + } + + sta = container_of(txqi->txq.sta, struct sta_info, sta); + if (sta->airtime[ac].deficit >= 0) + goto out; + + sta->airtime[ac].deficit += sta->airtime_weight; + list_move_tail(&txqi->schedule_order, &local->active_txqs[ac]); + + return false; +out: + if (!list_empty(&txqi->schedule_order)) + list_del_init(&txqi->schedule_order); + + return true; +} +EXPORT_SYMBOL(ieee80211_txq_may_transmit); + +void ieee80211_txq_schedule_start(struct ieee80211_hw *hw, u8 ac) + __acquires(txq_lock) +{ + struct ieee80211_local *local = hw_to_local(hw); + + spin_lock_bh(&local->active_txq_lock[ac]); + local->schedule_round[ac]++; +} +EXPORT_SYMBOL(ieee80211_txq_schedule_start); + +void ieee80211_txq_schedule_end(struct ieee80211_hw *hw, u8 ac) + __releases(txq_lock) +{ + struct ieee80211_local *local = hw_to_local(hw); + + spin_unlock_bh(&local->active_txq_lock[ac]); +} +EXPORT_SYMBOL(ieee80211_txq_schedule_end); + void __ieee80211_subif_start_xmit(struct sk_buff *skb, struct net_device *dev, u32 info_flags) diff --git a/net/mac80211/util.c b/net/mac80211/util.c index d0eb38b890aa..4c1655972565 100644 --- a/net/mac80211/util.c +++ b/net/mac80211/util.c @@ -5,7 +5,7 @@ * Copyright 2007 Johannes Berg <johannes@sipsolutions.net> * Copyright 2013-2014 Intel Mobile Communications GmbH * Copyright (C) 2015-2017 Intel Deutschland GmbH - * Copyright (C) 2018 Intel Corporation + * Copyright (C) 2018-2019 Intel Corporation * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 as @@ -891,33 +891,24 @@ void ieee80211_queue_delayed_work(struct ieee80211_hw *hw, } EXPORT_SYMBOL(ieee80211_queue_delayed_work); -u32 ieee802_11_parse_elems_crc(const u8 *start, size_t len, bool action, - struct ieee802_11_elems *elems, - u64 filter, u32 crc) +static u32 +_ieee802_11_parse_elems_crc(const u8 *start, size_t len, bool action, + struct ieee802_11_elems *elems, + u64 filter, u32 crc, u8 *transmitter_bssid, + u8 *bss_bssid) { - size_t left = len; - const u8 *pos = start; + const struct element *elem, *sub; bool calc_crc = filter != 0; DECLARE_BITMAP(seen_elems, 256); const u8 *ie; bitmap_zero(seen_elems, 256); - memset(elems, 0, sizeof(*elems)); - elems->ie_start = start; - elems->total_len = len; - while (left >= 2) { - u8 id, elen; + for_each_element(elem, start, len) { bool elem_parse_failed; - - id = *pos++; - elen = *pos++; - left -= 2; - - if (elen > left) { - elems->parse_error = true; - break; - } + u8 id = elem->id; + u8 elen = elem->datalen; + const u8 *pos = elem->data; switch (id) { case WLAN_EID_SSID: @@ -960,8 +951,6 @@ u32 ieee802_11_parse_elems_crc(const u8 *start, size_t len, bool action, */ if (test_bit(id, seen_elems)) { elems->parse_error = true; - left -= elen; - pos += elen; continue; } break; @@ -1219,6 +1208,57 @@ u32 ieee802_11_parse_elems_crc(const u8 *start, size_t len, bool action, if (elen >= sizeof(*elems->max_idle_period_ie)) elems->max_idle_period_ie = (void *)pos; break; + case WLAN_EID_MULTIPLE_BSSID: + if (!bss_bssid || !transmitter_bssid || elen < 4) + break; + + elems->max_bssid_indicator = pos[0]; + + for_each_element(sub, pos + 1, elen - 1) { + u8 sub_len = sub->datalen; + u8 new_bssid[ETH_ALEN]; + const u8 *index; + + /* + * we only expect the "non-transmitted BSSID + * profile" subelement (subelement id 0) + */ + if (sub->id != 0 || sub->datalen < 4) { + /* not a valid BSS profile */ + continue; + } + + if (sub->data[0] != WLAN_EID_NON_TX_BSSID_CAP || + sub->data[1] != 2) { + /* The first element of the + * Nontransmitted BSSID Profile is not + * the Nontransmitted BSSID Capability + * element. + */ + continue; + } + + /* found a Nontransmitted BSSID Profile */ + index = cfg80211_find_ie(WLAN_EID_MULTI_BSSID_IDX, + sub->data, sub_len); + if (!index || index[1] < 1 || index[2] == 0) { + /* Invalid MBSSID Index element */ + continue; + } + + cfg80211_gen_new_bssid(transmitter_bssid, + pos[0], + index[2], + new_bssid); + if (ether_addr_equal(new_bssid, bss_bssid)) { + elems->nontransmitted_bssid_profile = + (void *)sub; + elems->bssid_index_len = index[1]; + elems->bssid_index = (void *)&index[2]; + break; + } + } + break; case WLAN_EID_EXTENSION: if (pos[0] == WLAN_EID_EXT_HE_MU_EDCA && elen >= (sizeof(*elems->mu_edca_param_set) + 1)) { @@ -1234,6 +1274,14 @@ u32 ieee802_11_parse_elems_crc(const u8 *start, size_t len, bool action, elems->he_operation = (void *)&pos[1]; } else if (pos[0] == WLAN_EID_EXT_UORA && elen >= 1) { elems->uora_element = (void *)&pos[1]; + } else if (pos[0] == + WLAN_EID_EXT_MAX_CHANNEL_SWITCH_TIME && + elen == 4) { + elems->max_channel_switch_time = pos + 1; + } else if (pos[0] == + WLAN_EID_EXT_MULTIPLE_BSSID_CONFIGURATION && + elen == 3) { + elems->mbssid_config_ie = (void *)&pos[1]; } break; default: @@ -1244,17 +1292,56 @@ u32 ieee802_11_parse_elems_crc(const u8 *start, size_t len, bool action, elems->parse_error = true; else __set_bit(id, seen_elems); - - left -= elen; - pos += elen; } - if (left != 0) + if (!for_each_element_completed(elem, start, len)) elems->parse_error = true; return crc; } +u32 ieee802_11_parse_elems_crc(const u8 *start, size_t len, bool action, + struct ieee802_11_elems *elems, + u64 filter, u32 crc, u8 *transmitter_bssid, + u8 *bss_bssid) +{ + memset(elems, 0, sizeof(*elems)); + elems->ie_start = start; + elems->total_len = len; + + crc = _ieee802_11_parse_elems_crc(start, len, action, elems, filter, + crc, transmitter_bssid, bss_bssid); + + /* Override with nontransmitted profile, if found */ + if (transmitter_bssid && elems->nontransmitted_bssid_profile) { + const u8 *profile = elems->nontransmitted_bssid_profile; + + _ieee802_11_parse_elems_crc(&profile[2], profile[1], + action, elems, 0, 0, + transmitter_bssid, bss_bssid); + } + + if (elems->tim && !elems->parse_error) { + const struct ieee80211_tim_ie *tim_ie = elems->tim; + + elems->dtim_period = tim_ie->dtim_period; + elems->dtim_count = tim_ie->dtim_count; + } + + /* Override DTIM period and count if needed */ + if (elems->bssid_index && + elems->bssid_index_len >= + offsetofend(struct ieee80211_bssid_index, dtim_period)) + elems->dtim_period = elems->bssid_index->dtim_period; + + if (elems->bssid_index && + elems->bssid_index_len >= + offsetofend(struct ieee80211_bssid_index, dtim_count)) + elems->dtim_count = elems->bssid_index->dtim_count; + + return crc; +} + void ieee80211_regulatory_limit_wmm_params(struct ieee80211_sub_if_data *sdata, struct ieee80211_tx_queue_params *qparam, int ac) @@ -2146,6 +2233,10 @@ int ieee80211_reconfig(struct ieee80211_local *local) case NL80211_IFTYPE_AP_VLAN: case NL80211_IFTYPE_MONITOR: break; + case NL80211_IFTYPE_ADHOC: + if (sdata->vif.bss_conf.ibss_joined) + WARN_ON(drv_join_ibss(local, sdata)); + /* fall through */ default: ieee80211_reconfig_stations(sdata); /* fall through */ diff --git a/net/mpls/af_mpls.c b/net/mpls/af_mpls.c index 7d55d4c04088..f7c544592ec8 100644 --- a/net/mpls/af_mpls.c +++ b/net/mpls/af_mpls.c @@ -1209,21 +1209,57 @@ static const struct nla_policy devconf_mpls_policy[NETCONFA_MAX + 1] = { [NETCONFA_IFINDEX] = { .len = sizeof(int) }, }; +static int mpls_netconf_valid_get_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(struct netconfmsg))) { + NL_SET_ERR_MSG_MOD(extack, + "Invalid header for netconf get request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse(nlh, sizeof(struct netconfmsg), tb, + NETCONFA_MAX, devconf_mpls_policy, extack); + + err = nlmsg_parse_strict(nlh, sizeof(struct netconfmsg), tb, + NETCONFA_MAX, devconf_mpls_policy, extack); + if (err) + return err; + + for (i = 0; i <= NETCONFA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case NETCONFA_IFINDEX: + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in netconf get request"); + return -EINVAL; + } + } + + return 0; +} + static int mpls_netconf_get_devconf(struct sk_buff *in_skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { struct net *net = sock_net(in_skb->sk); struct nlattr *tb[NETCONFA_MAX + 1]; - struct netconfmsg *ncm; struct net_device *dev; struct mpls_dev *mdev; struct sk_buff *skb; int ifindex; int err; - err = nlmsg_parse(nlh, sizeof(*ncm), tb, NETCONFA_MAX, - devconf_mpls_policy, extack); + err = mpls_netconf_valid_get_req(in_skb, nlh, tb, extack); if (err < 0) goto errout; @@ -1838,6 +1874,9 @@ static int rtm_to_route_config(struct sk_buff *skb, goto errout; break; } + case RTA_GATEWAY: + NL_SET_ERR_MSG(extack, "MPLS does not support RTA_GATEWAY attribute"); + goto errout; case RTA_VIA: { if (nla_get_via(nla, &cfg->rc_via_alen, @@ -2236,6 +2275,64 @@ errout: rtnl_set_sk_err(net, RTNLGRP_MPLS_ROUTE, err); } +static int mpls_valid_getroute_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct rtmsg *rtm; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + NL_SET_ERR_MSG_MOD(extack, + "Invalid header for get route request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_mpls_policy, extack); + + rtm = nlmsg_data(nlh); + if ((rtm->rtm_dst_len && rtm->rtm_dst_len != 20) || + rtm->rtm_src_len || rtm->rtm_tos || rtm->rtm_table || + rtm->rtm_protocol || rtm->rtm_scope || rtm->rtm_type) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for get route request"); + return -EINVAL; + } + if (rtm->rtm_flags & ~RTM_F_FIB_MATCH) { + NL_SET_ERR_MSG_MOD(extack, + "Invalid flags for get route request"); + return -EINVAL; + } + + err = nlmsg_parse_strict(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_mpls_policy, extack); + if (err) + return err; + + if ((tb[RTA_DST] || tb[RTA_NEWDST]) && !rtm->rtm_dst_len) { + NL_SET_ERR_MSG_MOD(extack, "rtm_dst_len must be 20 for MPLS"); + return -EINVAL; + } + + for (i = 0; i <= RTA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case RTA_DST: + case RTA_NEWDST: + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in get route request"); + return -EINVAL; + } + } + + return 0; +} + static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh, struct netlink_ext_ack *extack) { @@ -2255,8 +2352,7 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh, u8 n_labels; int err; - err = nlmsg_parse(in_nlh, sizeof(*rtm), tb, RTA_MAX, - rtm_mpls_policy, extack); + err = mpls_valid_getroute_req(in_skb, in_nlh, tb, extack); if (err < 0) goto errout; diff --git a/net/mpls/mpls_iptunnel.c b/net/mpls/mpls_iptunnel.c index 94f53a9b7d1a..dda8930f20e7 100644 --- a/net/mpls/mpls_iptunnel.c +++ b/net/mpls/mpls_iptunnel.c @@ -183,8 +183,8 @@ static int mpls_build_state(struct nlattr *nla, &n_labels, NULL, extack)) return -EINVAL; - newts = lwtunnel_state_alloc(sizeof(*tun_encap_info) + - n_labels * sizeof(u32)); + newts = lwtunnel_state_alloc(struct_size(tun_encap_info, label, + n_labels)); if (!newts) return -ENOMEM; diff --git a/net/netfilter/Kconfig b/net/netfilter/Kconfig index beb3a69ce1d4..d43ffb09939b 100644 --- a/net/netfilter/Kconfig +++ b/net/netfilter/Kconfig @@ -174,7 +174,7 @@ config NF_CT_PROTO_DCCP If unsure, say Y. config NF_CT_PROTO_GRE - tristate + bool config NF_CT_PROTO_SCTP bool 'SCTP protocol connection tracking support' @@ -396,7 +396,13 @@ config NETFILTER_NETLINK_GLUE_CT the enqueued via NFNETLINK. config NF_NAT - tristate + tristate "Network Address Translation support" + depends on NF_CONNTRACK + default m if NETFILTER_ADVANCED=n + help + The NAT option allows masquerading, port forwarding and other + forms of full Network Address Port Translation. This can be + controlled by iptables, ip6tables or nft. config NF_NAT_NEEDED bool @@ -431,6 +437,9 @@ config NF_NAT_TFTP config NF_NAT_REDIRECT bool +config NF_NAT_MASQUERADE + bool + config NETFILTER_SYNPROXY tristate @@ -523,6 +532,7 @@ config NFT_LIMIT config NFT_MASQ depends on NF_CONNTRACK depends on NF_NAT + select NF_NAT_MASQUERADE tristate "Netfilter nf_tables masquerade support" help This option adds the "masquerade" expression that you can use @@ -532,6 +542,7 @@ config NFT_REDIR depends on NF_CONNTRACK depends on NF_NAT tristate "Netfilter nf_tables redirect support" + select NF_NAT_REDIRECT help This options adds the "redirect" expression that you can use to perform NAT in the redirect flavour. @@ -539,6 +550,7 @@ config NFT_REDIR config NFT_NAT depends on NF_CONNTRACK select NF_NAT + depends on NF_TABLES_IPV4 || NF_TABLES_IPV6 tristate "Netfilter nf_tables nat module" help This option adds the "nat" expression that you can use to perform diff --git a/net/netfilter/Makefile b/net/netfilter/Makefile index 1ae65a314d7a..4894a85cdd0b 100644 --- a/net/netfilter/Makefile +++ b/net/netfilter/Makefile @@ -13,6 +13,7 @@ nf_conntrack-$(CONFIG_NF_CONNTRACK_EVENTS) += nf_conntrack_ecache.o nf_conntrack-$(CONFIG_NF_CONNTRACK_LABELS) += nf_conntrack_labels.o nf_conntrack-$(CONFIG_NF_CT_PROTO_DCCP) += nf_conntrack_proto_dccp.o nf_conntrack-$(CONFIG_NF_CT_PROTO_SCTP) += nf_conntrack_proto_sctp.o +nf_conntrack-$(CONFIG_NF_CT_PROTO_GRE) += nf_conntrack_proto_gre.o obj-$(CONFIG_NETFILTER) = netfilter.o @@ -25,8 +26,6 @@ obj-$(CONFIG_NETFILTER_NETLINK_OSF) += nfnetlink_osf.o # connection tracking obj-$(CONFIG_NF_CONNTRACK) += nf_conntrack.o -obj-$(CONFIG_NF_CT_PROTO_GRE) += nf_conntrack_proto_gre.o - # netlink interface for nf_conntrack obj-$(CONFIG_NF_CT_NETLINK) += nf_conntrack_netlink.o obj-$(CONFIG_NF_CT_NETLINK_TIMEOUT) += nfnetlink_cttimeout.o @@ -57,6 +56,7 @@ obj-$(CONFIG_NF_LOG_NETDEV) += nf_log_netdev.o obj-$(CONFIG_NF_NAT) += nf_nat.o nf_nat-$(CONFIG_NF_NAT_REDIRECT) += nf_nat_redirect.o +nf_nat-$(CONFIG_NF_NAT_MASQUERADE) += nf_nat_masquerade.o # NAT helpers obj-$(CONFIG_NF_NAT_AMANDA) += nf_nat_amanda.o @@ -110,6 +110,8 @@ obj-$(CONFIG_NFT_OSF) += nft_osf.o obj-$(CONFIG_NFT_TPROXY) += nft_tproxy.o obj-$(CONFIG_NFT_XFRM) += nft_xfrm.o +obj-$(CONFIG_NFT_NAT) += nft_chain_nat.o + # nf_tables netdev obj-$(CONFIG_NFT_DUP_NETDEV) += nft_dup_netdev.o obj-$(CONFIG_NFT_FWD_NETDEV) += nft_fwd_netdev.o diff --git a/net/netfilter/ipvs/Kconfig b/net/netfilter/ipvs/Kconfig index cad48d07c818..8401cefd9f65 100644 --- a/net/netfilter/ipvs/Kconfig +++ b/net/netfilter/ipvs/Kconfig @@ -29,6 +29,7 @@ config IP_VS_IPV6 bool "IPv6 support for IPVS" depends on IPV6 = y || IP_VS = IPV6 select IP6_NF_IPTABLES + select NF_DEFRAG_IPV6 ---help--- Add IPv6 support to IPVS. diff --git a/net/netfilter/ipvs/ip_vs_core.c b/net/netfilter/ipvs/ip_vs_core.c index fe9abf3cc10a..43bbaa32b1d6 100644 --- a/net/netfilter/ipvs/ip_vs_core.c +++ b/net/netfilter/ipvs/ip_vs_core.c @@ -53,6 +53,7 @@ #endif #include <net/ip_vs.h> +#include <linux/indirect_call_wrapper.h> EXPORT_SYMBOL(register_ip_vs_scheduler); @@ -70,6 +71,29 @@ EXPORT_SYMBOL(ip_vs_get_debug_level); #endif EXPORT_SYMBOL(ip_vs_new_conn_out); +#ifdef CONFIG_IP_VS_PROTO_TCP +INDIRECT_CALLABLE_DECLARE(int + tcp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, + struct ip_vs_conn *cp, struct ip_vs_iphdr *iph)); +#endif + +#ifdef CONFIG_IP_VS_PROTO_UDP +INDIRECT_CALLABLE_DECLARE(int + udp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, + struct ip_vs_conn *cp, struct ip_vs_iphdr *iph)); +#endif + +#if defined(CONFIG_IP_VS_PROTO_TCP) && defined(CONFIG_IP_VS_PROTO_UDP) +#define SNAT_CALL(f, ...) \ + INDIRECT_CALL_2(f, tcp_snat_handler, udp_snat_handler, __VA_ARGS__) +#elif defined(CONFIG_IP_VS_PROTO_TCP) +#define SNAT_CALL(f, ...) INDIRECT_CALL_1(f, tcp_snat_handler, __VA_ARGS__) +#elif defined(CONFIG_IP_VS_PROTO_UDP) +#define SNAT_CALL(f, ...) INDIRECT_CALL_1(f, udp_snat_handler, __VA_ARGS__) +#else +#define SNAT_CALL(f, ...) f(__VA_ARGS__) +#endif + static unsigned int ip_vs_net_id __read_mostly; /* netns cnt used for uniqueness */ static atomic_t ipvs_netns_cnt = ATOMIC_INIT(0); @@ -478,7 +502,9 @@ ip_vs_schedule(struct ip_vs_service *svc, struct sk_buff *skb, */ if ((!skb->dev || skb->dev->flags & IFF_LOOPBACK)) { iph->hdr_flags ^= IP_VS_HDR_INVERSE; - cp = pp->conn_in_get(svc->ipvs, svc->af, skb, iph); + cp = INDIRECT_CALL_1(pp->conn_in_get, + ip_vs_conn_in_get_proto, svc->ipvs, + svc->af, skb, iph); iph->hdr_flags ^= IP_VS_HDR_INVERSE; if (cp) { @@ -972,7 +998,8 @@ static int ip_vs_out_icmp(struct netns_ipvs *ipvs, struct sk_buff *skb, ip_vs_fill_iph_skb_icmp(AF_INET, skb, offset, true, &ciph); /* The embedded headers contain source and dest in reverse order */ - cp = pp->conn_out_get(ipvs, AF_INET, skb, &ciph); + cp = INDIRECT_CALL_1(pp->conn_out_get, ip_vs_conn_out_get_proto, + ipvs, AF_INET, skb, &ciph); if (!cp) return NF_ACCEPT; @@ -1028,7 +1055,8 @@ static int ip_vs_out_icmp_v6(struct netns_ipvs *ipvs, struct sk_buff *skb, return NF_ACCEPT; /* The embedded headers contain source and dest in reverse order */ - cp = pp->conn_out_get(ipvs, AF_INET6, skb, &ciph); + cp = INDIRECT_CALL_1(pp->conn_out_get, ip_vs_conn_out_get_proto, + ipvs, AF_INET6, skb, &ciph); if (!cp) return NF_ACCEPT; @@ -1263,7 +1291,8 @@ handle_response(int af, struct sk_buff *skb, struct ip_vs_proto_data *pd, goto drop; /* mangle the packet */ - if (pp->snat_handler && !pp->snat_handler(skb, pp, cp, iph)) + if (pp->snat_handler && + !SNAT_CALL(pp->snat_handler, skb, pp, cp, iph)) goto drop; #ifdef CONFIG_IP_VS_IPV6 @@ -1389,7 +1418,8 @@ ip_vs_out(struct netns_ipvs *ipvs, unsigned int hooknum, struct sk_buff *skb, in /* * Check if the packet belongs to an existing entry */ - cp = pp->conn_out_get(ipvs, af, skb, &iph); + cp = INDIRECT_CALL_1(pp->conn_out_get, ip_vs_conn_out_get_proto, + ipvs, af, skb, &iph); if (likely(cp)) { if (IP_VS_FWD_METHOD(cp) != IP_VS_CONN_F_MASQ) @@ -1536,14 +1566,12 @@ ip_vs_try_to_schedule(struct netns_ipvs *ipvs, int af, struct sk_buff *skb, /* sorry, all this trouble for a no-hit :) */ IP_VS_DBG_PKT(12, af, pp, skb, iph->off, "ip_vs_in: packet continues traversal as normal"); - if (iph->fragoffs) { - /* Fragment that couldn't be mapped to a conn entry - * is missing module nf_defrag_ipv6 - */ - IP_VS_DBG_RL("Unhandled frag, load nf_defrag_ipv6\n"); + + /* Fragment couldn't be mapped to a conn entry */ + if (iph->fragoffs) IP_VS_DBG_PKT(7, af, pp, skb, iph->off, "unhandled fragment"); - } + *verdict = NF_ACCEPT; return 0; } @@ -1644,7 +1672,8 @@ ip_vs_in_icmp(struct netns_ipvs *ipvs, struct sk_buff *skb, int *related, /* The embedded headers contain source and dest in reverse order. * For IPIP this is error for request, not for reply. */ - cp = pp->conn_in_get(ipvs, AF_INET, skb, &ciph); + cp = INDIRECT_CALL_1(pp->conn_in_get, ip_vs_conn_in_get_proto, + ipvs, AF_INET, skb, &ciph); if (!cp) { int v; @@ -1796,7 +1825,8 @@ static int ip_vs_in_icmp_v6(struct netns_ipvs *ipvs, struct sk_buff *skb, /* The embedded headers contain source and dest in reverse order * if not from localhost */ - cp = pp->conn_in_get(ipvs, AF_INET6, skb, &ciph); + cp = INDIRECT_CALL_1(pp->conn_in_get, ip_vs_conn_in_get_proto, + ipvs, AF_INET6, skb, &ciph); if (!cp) { int v; @@ -1925,7 +1955,8 @@ ip_vs_in(struct netns_ipvs *ipvs, unsigned int hooknum, struct sk_buff *skb, int /* * Check if the packet belongs to an existing connection entry */ - cp = pp->conn_in_get(ipvs, af, skb, &iph); + cp = INDIRECT_CALL_1(pp->conn_in_get, ip_vs_conn_in_get_proto, + ipvs, af, skb, &iph); conn_reuse_mode = sysctl_conn_reuse_mode(ipvs); if (conn_reuse_mode && !iph.fragoffs && is_new_conn(skb, &iph) && cp) { diff --git a/net/netfilter/ipvs/ip_vs_ctl.c b/net/netfilter/ipvs/ip_vs_ctl.c index 432141f04af3..053cd96b9c76 100644 --- a/net/netfilter/ipvs/ip_vs_ctl.c +++ b/net/netfilter/ipvs/ip_vs_ctl.c @@ -43,6 +43,7 @@ #ifdef CONFIG_IP_VS_IPV6 #include <net/ipv6.h> #include <net/ip6_route.h> +#include <net/netfilter/ipv6/nf_defrag_ipv6.h> #endif #include <net/route.h> #include <net/sock.h> @@ -900,11 +901,17 @@ ip_vs_new_dest(struct ip_vs_service *svc, struct ip_vs_dest_user_kern *udest, #ifdef CONFIG_IP_VS_IPV6 if (udest->af == AF_INET6) { + int ret; + atype = ipv6_addr_type(&udest->addr.in6); if ((!(atype & IPV6_ADDR_UNICAST) || atype & IPV6_ADDR_LINKLOCAL) && !__ip_vs_addr_is_local_v6(svc->ipvs->net, &udest->addr.in6)) return -EINVAL; + + ret = nf_defrag_ipv6_enable(svc->ipvs->net); + if (ret) + return ret; } else #endif { @@ -1228,6 +1235,10 @@ ip_vs_add_service(struct netns_ipvs *ipvs, struct ip_vs_service_user_kern *u, ret = -EINVAL; goto out_err; } + + ret = nf_defrag_ipv6_enable(ipvs->net); + if (ret) + goto out_err; } #endif @@ -2221,6 +2232,18 @@ static int ip_vs_set_timeout(struct netns_ipvs *ipvs, struct ip_vs_timeout_user u->udp_timeout); #ifdef CONFIG_IP_VS_PROTO_TCP + if (u->tcp_timeout < 0 || u->tcp_timeout > (INT_MAX / HZ) || + u->tcp_fin_timeout < 0 || u->tcp_fin_timeout > (INT_MAX / HZ)) { + return -EINVAL; + } +#endif + +#ifdef CONFIG_IP_VS_PROTO_UDP + if (u->udp_timeout < 0 || u->udp_timeout > (INT_MAX / HZ)) + return -EINVAL; +#endif + +#ifdef CONFIG_IP_VS_PROTO_TCP if (u->tcp_timeout) { pd = ip_vs_proto_data_get(ipvs, IPPROTO_TCP); pd->timeout_table[IP_VS_TCP_S_ESTABLISHED] @@ -2722,8 +2745,7 @@ do_ip_vs_get_ctl(struct sock *sk, int cmd, void __user *user, int *len) int size; get = (struct ip_vs_get_services *)arg; - size = sizeof(*get) + - sizeof(struct ip_vs_service_entry) * get->num_services; + size = struct_size(get, entrytable, get->num_services); if (*len != size) { pr_err("length: %u != %u\n", *len, size); ret = -EINVAL; @@ -2764,8 +2786,7 @@ do_ip_vs_get_ctl(struct sock *sk, int cmd, void __user *user, int *len) int size; get = (struct ip_vs_get_dests *)arg; - size = sizeof(*get) + - sizeof(struct ip_vs_dest_entry) * get->num_dests; + size = struct_size(get, entrytable, get->num_dests); if (*len != size) { pr_err("length: %u != %u\n", *len, size); ret = -EINVAL; @@ -3065,7 +3086,7 @@ static bool ip_vs_is_af_valid(int af) static int ip_vs_genl_parse_service(struct netns_ipvs *ipvs, struct ip_vs_service_user_kern *usvc, - struct nlattr *nla, int full_entry, + struct nlattr *nla, bool full_entry, struct ip_vs_service **ret_svc) { struct nlattr *attrs[IPVS_SVC_ATTR_MAX + 1]; @@ -3152,7 +3173,7 @@ static struct ip_vs_service *ip_vs_genl_find_service(struct netns_ipvs *ipvs, struct ip_vs_service *svc; int ret; - ret = ip_vs_genl_parse_service(ipvs, &usvc, nla, 0, &svc); + ret = ip_vs_genl_parse_service(ipvs, &usvc, nla, false, &svc); return ret ? ERR_PTR(ret) : svc; } @@ -3262,7 +3283,7 @@ out_err: } static int ip_vs_genl_parse_dest(struct ip_vs_dest_user_kern *udest, - struct nlattr *nla, int full_entry) + struct nlattr *nla, bool full_entry) { struct nlattr *attrs[IPVS_DEST_ATTR_MAX + 1]; struct nlattr *nla_addr, *nla_port; @@ -3524,11 +3545,11 @@ out: static int ip_vs_genl_set_cmd(struct sk_buff *skb, struct genl_info *info) { + bool need_full_svc = false, need_full_dest = false; struct ip_vs_service *svc = NULL; struct ip_vs_service_user_kern usvc; struct ip_vs_dest_user_kern udest; int ret = 0, cmd; - int need_full_svc = 0, need_full_dest = 0; struct net *net = sock_net(skb->sk); struct netns_ipvs *ipvs = net_ipvs(net); @@ -3552,7 +3573,7 @@ static int ip_vs_genl_set_cmd(struct sk_buff *skb, struct genl_info *info) * received a valid one. We need a full service specification when * adding / editing a service. Only identifying members otherwise. */ if (cmd == IPVS_CMD_NEW_SERVICE || cmd == IPVS_CMD_SET_SERVICE) - need_full_svc = 1; + need_full_svc = true; ret = ip_vs_genl_parse_service(ipvs, &usvc, info->attrs[IPVS_CMD_ATTR_SERVICE], @@ -3572,7 +3593,7 @@ static int ip_vs_genl_set_cmd(struct sk_buff *skb, struct genl_info *info) if (cmd == IPVS_CMD_NEW_DEST || cmd == IPVS_CMD_SET_DEST || cmd == IPVS_CMD_DEL_DEST) { if (cmd != IPVS_CMD_DEL_DEST) - need_full_dest = 1; + need_full_dest = true; ret = ip_vs_genl_parse_dest(&udest, info->attrs[IPVS_CMD_ATTR_DEST], diff --git a/net/netfilter/ipvs/ip_vs_ftp.c b/net/netfilter/ipvs/ip_vs_ftp.c index 4398a72edec5..fe69d46ff779 100644 --- a/net/netfilter/ipvs/ip_vs_ftp.c +++ b/net/netfilter/ipvs/ip_vs_ftp.c @@ -124,7 +124,7 @@ static int ip_vs_ftp_get_addrport(char *data, char *data_limit, } s = data + plen; if (skip) { - int found = 0; + bool found = false; for (;; s++) { if (s == data_limit) @@ -136,7 +136,7 @@ static int ip_vs_ftp_get_addrport(char *data, char *data_limit, if (!ext && isdigit(*s)) break; if (*s == skip) - found = 1; + found = true; } else if (*s != skip) { break; } diff --git a/net/netfilter/ipvs/ip_vs_proto_ah_esp.c b/net/netfilter/ipvs/ip_vs_proto_ah_esp.c index 5320d39976e1..480598cb0f05 100644 --- a/net/netfilter/ipvs/ip_vs_proto_ah_esp.c +++ b/net/netfilter/ipvs/ip_vs_proto_ah_esp.c @@ -129,7 +129,6 @@ struct ip_vs_protocol ip_vs_protocol_ah = { .conn_out_get = ah_esp_conn_out_get, .snat_handler = NULL, .dnat_handler = NULL, - .csum_check = NULL, .state_transition = NULL, .register_app = NULL, .unregister_app = NULL, @@ -152,7 +151,6 @@ struct ip_vs_protocol ip_vs_protocol_esp = { .conn_out_get = ah_esp_conn_out_get, .snat_handler = NULL, .dnat_handler = NULL, - .csum_check = NULL, .state_transition = NULL, .register_app = NULL, .unregister_app = NULL, diff --git a/net/netfilter/ipvs/ip_vs_proto_sctp.c b/net/netfilter/ipvs/ip_vs_proto_sctp.c index b0cd7d08f2a7..b58ddb7dffd1 100644 --- a/net/netfilter/ipvs/ip_vs_proto_sctp.c +++ b/net/netfilter/ipvs/ip_vs_proto_sctp.c @@ -10,6 +10,9 @@ #include <net/ip_vs.h> static int +sctp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp); + +static int sctp_conn_schedule(struct netns_ipvs *ipvs, int af, struct sk_buff *skb, struct ip_vs_proto_data *pd, int *verdict, struct ip_vs_conn **cpp, @@ -105,7 +108,7 @@ sctp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, int ret; /* Some checks before mangling */ - if (pp->csum_check && !pp->csum_check(cp->af, skb, pp)) + if (!sctp_csum_check(cp->af, skb, pp)) return 0; /* Call application helper if needed */ @@ -152,7 +155,7 @@ sctp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, int ret; /* Some checks before mangling */ - if (pp->csum_check && !pp->csum_check(cp->af, skb, pp)) + if (!sctp_csum_check(cp->af, skb, pp)) return 0; /* Call application helper if needed */ @@ -183,7 +186,7 @@ static int sctp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp) { unsigned int sctphoff; - struct sctphdr *sh, _sctph; + struct sctphdr *sh; __le32 cmp, val; #ifdef CONFIG_IP_VS_IPV6 @@ -193,10 +196,7 @@ sctp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp) #endif sctphoff = ip_hdrlen(skb); - sh = skb_header_pointer(skb, sctphoff, sizeof(_sctph), &_sctph); - if (sh == NULL) - return 0; - + sh = (struct sctphdr *)(skb->data + sctphoff); cmp = sh->checksum; val = sctp_compute_cksum(skb, sctphoff); @@ -587,7 +587,6 @@ struct ip_vs_protocol ip_vs_protocol_sctp = { .conn_out_get = ip_vs_conn_out_get_proto, .snat_handler = sctp_snat_handler, .dnat_handler = sctp_dnat_handler, - .csum_check = sctp_csum_check, .state_name = sctp_state_name, .state_transition = sctp_state_transition, .app_conn_bind = sctp_app_conn_bind, diff --git a/net/netfilter/ipvs/ip_vs_proto_tcp.c b/net/netfilter/ipvs/ip_vs_proto_tcp.c index 1770fc6ce960..00ce07dda980 100644 --- a/net/netfilter/ipvs/ip_vs_proto_tcp.c +++ b/net/netfilter/ipvs/ip_vs_proto_tcp.c @@ -28,10 +28,14 @@ #include <net/ip6_checksum.h> #include <linux/netfilter.h> #include <linux/netfilter_ipv4.h> +#include <linux/indirect_call_wrapper.h> #include <net/ip_vs.h> static int +tcp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp); + +static int tcp_conn_schedule(struct netns_ipvs *ipvs, int af, struct sk_buff *skb, struct ip_vs_proto_data *pd, int *verdict, struct ip_vs_conn **cpp, @@ -143,14 +147,14 @@ tcp_partial_csum_update(int af, struct tcphdr *tcph, } -static int +INDIRECT_CALLABLE_SCOPE int tcp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, struct ip_vs_conn *cp, struct ip_vs_iphdr *iph) { struct tcphdr *tcph; unsigned int tcphoff = iph->len; + bool payload_csum = false; int oldlen; - int payload_csum = 0; #ifdef CONFIG_IP_VS_IPV6 if (cp->af == AF_INET6 && iph->fragoffs) @@ -166,7 +170,7 @@ tcp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, int ret; /* Some checks before mangling */ - if (pp->csum_check && !pp->csum_check(cp->af, skb, pp)) + if (!tcp_csum_check(cp->af, skb, pp)) return 0; /* Call application helper if needed */ @@ -176,7 +180,7 @@ tcp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, if (ret == 1) oldlen = skb->len - tcphoff; else - payload_csum = 1; + payload_csum = true; } tcph = (void *)skb_network_header(skb) + tcphoff; @@ -192,7 +196,7 @@ tcp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, tcp_fast_csum_update(cp->af, tcph, &cp->daddr, &cp->vaddr, cp->dport, cp->vport); if (skb->ip_summed == CHECKSUM_COMPLETE) - skb->ip_summed = (cp->app && pp->csum_check) ? + skb->ip_summed = cp->app ? CHECKSUM_UNNECESSARY : CHECKSUM_NONE; } else { /* full checksum calculation */ @@ -227,8 +231,8 @@ tcp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, { struct tcphdr *tcph; unsigned int tcphoff = iph->len; + bool payload_csum = false; int oldlen; - int payload_csum = 0; #ifdef CONFIG_IP_VS_IPV6 if (cp->af == AF_INET6 && iph->fragoffs) @@ -244,7 +248,7 @@ tcp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, int ret; /* Some checks before mangling */ - if (pp->csum_check && !pp->csum_check(cp->af, skb, pp)) + if (!tcp_csum_check(cp->af, skb, pp)) return 0; /* @@ -257,7 +261,7 @@ tcp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, if (ret == 1) oldlen = skb->len - tcphoff; else - payload_csum = 1; + payload_csum = true; } tcph = (void *)skb_network_header(skb) + tcphoff; @@ -275,7 +279,7 @@ tcp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, tcp_fast_csum_update(cp->af, tcph, &cp->vaddr, &cp->daddr, cp->vport, cp->dport); if (skb->ip_summed == CHECKSUM_COMPLETE) - skb->ip_summed = (cp->app && pp->csum_check) ? + skb->ip_summed = cp->app ? CHECKSUM_UNNECESSARY : CHECKSUM_NONE; } else { /* full checksum calculation */ @@ -736,7 +740,6 @@ struct ip_vs_protocol ip_vs_protocol_tcp = { .conn_out_get = ip_vs_conn_out_get_proto, .snat_handler = tcp_snat_handler, .dnat_handler = tcp_dnat_handler, - .csum_check = tcp_csum_check, .state_name = tcp_state_name, .state_transition = tcp_state_transition, .app_conn_bind = tcp_app_conn_bind, diff --git a/net/netfilter/ipvs/ip_vs_proto_udp.c b/net/netfilter/ipvs/ip_vs_proto_udp.c index 0f53c49025f8..92c078abcb3e 100644 --- a/net/netfilter/ipvs/ip_vs_proto_udp.c +++ b/net/netfilter/ipvs/ip_vs_proto_udp.c @@ -23,12 +23,16 @@ #include <linux/netfilter.h> #include <linux/netfilter_ipv4.h> #include <linux/udp.h> +#include <linux/indirect_call_wrapper.h> #include <net/ip_vs.h> #include <net/ip.h> #include <net/ip6_checksum.h> static int +udp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp); + +static int udp_conn_schedule(struct netns_ipvs *ipvs, int af, struct sk_buff *skb, struct ip_vs_proto_data *pd, int *verdict, struct ip_vs_conn **cpp, @@ -133,14 +137,14 @@ udp_partial_csum_update(int af, struct udphdr *uhdr, } -static int +INDIRECT_CALLABLE_SCOPE int udp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, struct ip_vs_conn *cp, struct ip_vs_iphdr *iph) { struct udphdr *udph; unsigned int udphoff = iph->len; + bool payload_csum = false; int oldlen; - int payload_csum = 0; #ifdef CONFIG_IP_VS_IPV6 if (cp->af == AF_INET6 && iph->fragoffs) @@ -156,7 +160,7 @@ udp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, int ret; /* Some checks before mangling */ - if (pp->csum_check && !pp->csum_check(cp->af, skb, pp)) + if (!udp_csum_check(cp->af, skb, pp)) return 0; /* @@ -168,7 +172,7 @@ udp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, if (ret == 1) oldlen = skb->len - udphoff; else - payload_csum = 1; + payload_csum = true; } udph = (void *)skb_network_header(skb) + udphoff; @@ -186,7 +190,7 @@ udp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, udp_fast_csum_update(cp->af, udph, &cp->daddr, &cp->vaddr, cp->dport, cp->vport); if (skb->ip_summed == CHECKSUM_COMPLETE) - skb->ip_summed = (cp->app && pp->csum_check) ? + skb->ip_summed = cp->app ? CHECKSUM_UNNECESSARY : CHECKSUM_NONE; } else { /* full checksum calculation */ @@ -222,8 +226,8 @@ udp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, { struct udphdr *udph; unsigned int udphoff = iph->len; + bool payload_csum = false; int oldlen; - int payload_csum = 0; #ifdef CONFIG_IP_VS_IPV6 if (cp->af == AF_INET6 && iph->fragoffs) @@ -239,7 +243,7 @@ udp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, int ret; /* Some checks before mangling */ - if (pp->csum_check && !pp->csum_check(cp->af, skb, pp)) + if (!udp_csum_check(cp->af, skb, pp)) return 0; /* @@ -252,7 +256,7 @@ udp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, if (ret == 1) oldlen = skb->len - udphoff; else - payload_csum = 1; + payload_csum = true; } udph = (void *)skb_network_header(skb) + udphoff; @@ -270,7 +274,7 @@ udp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, udp_fast_csum_update(cp->af, udph, &cp->vaddr, &cp->daddr, cp->vport, cp->dport); if (skb->ip_summed == CHECKSUM_COMPLETE) - skb->ip_summed = (cp->app && pp->csum_check) ? + skb->ip_summed = cp->app ? CHECKSUM_UNNECESSARY : CHECKSUM_NONE; } else { /* full checksum calculation */ @@ -494,7 +498,6 @@ struct ip_vs_protocol ip_vs_protocol_udp = { .conn_out_get = ip_vs_conn_out_get_proto, .snat_handler = udp_snat_handler, .dnat_handler = udp_dnat_handler, - .csum_check = udp_csum_check, .state_transition = udp_state_transition, .state_name = udp_state_name, .register_app = udp_register_app, diff --git a/net/netfilter/ipvs/ip_vs_xmit.c b/net/netfilter/ipvs/ip_vs_xmit.c index 473cce2a5231..175349fcf91f 100644 --- a/net/netfilter/ipvs/ip_vs_xmit.c +++ b/net/netfilter/ipvs/ip_vs_xmit.c @@ -126,7 +126,7 @@ static struct rtable *do_output_route4(struct net *net, __be32 daddr, { struct flowi4 fl4; struct rtable *rt; - int loop = 0; + bool loop = false; memset(&fl4, 0, sizeof(fl4)); fl4.daddr = daddr; @@ -149,7 +149,7 @@ retry: ip_rt_put(rt); *saddr = fl4.saddr; flowi4_update_output(&fl4, 0, 0, daddr, fl4.saddr); - loop++; + loop = true; goto retry; } *saddr = fl4.saddr; diff --git a/net/netfilter/nf_conntrack_amanda.c b/net/netfilter/nf_conntrack_amanda.c index 20edd589fe06..f2681ec5b5f6 100644 --- a/net/netfilter/nf_conntrack_amanda.c +++ b/net/netfilter/nf_conntrack_amanda.c @@ -54,6 +54,7 @@ enum amanda_strings { SEARCH_DATA, SEARCH_MESG, SEARCH_INDEX, + SEARCH_STATE, }; static struct { @@ -81,6 +82,10 @@ static struct { .string = "INDEX ", .len = 6, }, + [SEARCH_STATE] = { + .string = "STATE ", + .len = 6, + }, }; static int amanda_help(struct sk_buff *skb, @@ -124,7 +129,7 @@ static int amanda_help(struct sk_buff *skb, goto out; stop += start; - for (i = SEARCH_DATA; i <= SEARCH_INDEX; i++) { + for (i = SEARCH_DATA; i <= SEARCH_STATE; i++) { off = skb_find_text(skb, start, stop, search[i].ts); if (off == UINT_MAX) continue; @@ -168,7 +173,7 @@ out: } static const struct nf_conntrack_expect_policy amanda_exp_policy = { - .max_expected = 3, + .max_expected = 4, .timeout = 180, }; diff --git a/net/netfilter/nf_conntrack_core.c b/net/netfilter/nf_conntrack_core.c index 741b533148ba..82bfbeef46af 100644 --- a/net/netfilter/nf_conntrack_core.c +++ b/net/netfilter/nf_conntrack_core.c @@ -51,7 +51,6 @@ #include <net/netfilter/nf_conntrack_labels.h> #include <net/netfilter/nf_conntrack_synproxy.h> #include <net/netfilter/nf_nat.h> -#include <net/netfilter/nf_nat_core.h> #include <net/netfilter/nf_nat_helper.h> #include <net/netns/hash.h> #include <net/ip.h> @@ -222,6 +221,24 @@ static u32 hash_conntrack(const struct net *net, return scale_hash(hash_conntrack_raw(tuple, net)); } +static bool nf_ct_get_tuple_ports(const struct sk_buff *skb, + unsigned int dataoff, + struct nf_conntrack_tuple *tuple) +{ struct { + __be16 sport; + __be16 dport; + } _inet_hdr, *inet_hdr; + + /* Actually only need first 4 bytes to get ports. */ + inet_hdr = skb_header_pointer(skb, dataoff, sizeof(_inet_hdr), &_inet_hdr); + if (!inet_hdr) + return false; + + tuple->src.u.udp.port = inet_hdr->sport; + tuple->dst.u.udp.port = inet_hdr->dport; + return true; +} + static bool nf_ct_get_tuple(const struct sk_buff *skb, unsigned int nhoff, @@ -229,16 +246,11 @@ nf_ct_get_tuple(const struct sk_buff *skb, u_int16_t l3num, u_int8_t protonum, struct net *net, - struct nf_conntrack_tuple *tuple, - const struct nf_conntrack_l4proto *l4proto) + struct nf_conntrack_tuple *tuple) { unsigned int size; const __be32 *ap; __be32 _addrs[8]; - struct { - __be16 sport; - __be16 dport; - } _inet_hdr, *inet_hdr; memset(tuple, 0, sizeof(*tuple)); @@ -274,16 +286,36 @@ nf_ct_get_tuple(const struct sk_buff *skb, tuple->dst.protonum = protonum; tuple->dst.dir = IP_CT_DIR_ORIGINAL; - if (unlikely(l4proto->pkt_to_tuple)) - return l4proto->pkt_to_tuple(skb, dataoff, net, tuple); - - /* Actually only need first 4 bytes to get ports. */ - inet_hdr = skb_header_pointer(skb, dataoff, sizeof(_inet_hdr), &_inet_hdr); - if (!inet_hdr) - return false; + switch (protonum) { +#if IS_ENABLED(CONFIG_IPV6) + case IPPROTO_ICMPV6: + return icmpv6_pkt_to_tuple(skb, dataoff, net, tuple); +#endif + case IPPROTO_ICMP: + return icmp_pkt_to_tuple(skb, dataoff, net, tuple); +#ifdef CONFIG_NF_CT_PROTO_GRE + case IPPROTO_GRE: + return gre_pkt_to_tuple(skb, dataoff, net, tuple); +#endif + case IPPROTO_TCP: + case IPPROTO_UDP: /* fallthrough */ + return nf_ct_get_tuple_ports(skb, dataoff, tuple); +#ifdef CONFIG_NF_CT_PROTO_UDPLITE + case IPPROTO_UDPLITE: + return nf_ct_get_tuple_ports(skb, dataoff, tuple); +#endif +#ifdef CONFIG_NF_CT_PROTO_SCTP + case IPPROTO_SCTP: + return nf_ct_get_tuple_ports(skb, dataoff, tuple); +#endif +#ifdef CONFIG_NF_CT_PROTO_DCCP + case IPPROTO_DCCP: + return nf_ct_get_tuple_ports(skb, dataoff, tuple); +#endif + default: + break; + } - tuple->src.u.udp.port = inet_hdr->sport; - tuple->dst.u.udp.port = inet_hdr->dport; return true; } @@ -366,33 +398,20 @@ bool nf_ct_get_tuplepr(const struct sk_buff *skb, unsigned int nhoff, u_int16_t l3num, struct net *net, struct nf_conntrack_tuple *tuple) { - const struct nf_conntrack_l4proto *l4proto; u8 protonum; int protoff; - int ret; - - rcu_read_lock(); protoff = get_l4proto(skb, nhoff, l3num, &protonum); - if (protoff <= 0) { - rcu_read_unlock(); + if (protoff <= 0) return false; - } - - l4proto = __nf_ct_l4proto_find(protonum); - - ret = nf_ct_get_tuple(skb, nhoff, protoff, l3num, protonum, net, tuple, - l4proto); - rcu_read_unlock(); - return ret; + return nf_ct_get_tuple(skb, nhoff, protoff, l3num, protonum, net, tuple); } EXPORT_SYMBOL_GPL(nf_ct_get_tuplepr); bool nf_ct_invert_tuple(struct nf_conntrack_tuple *inverse, - const struct nf_conntrack_tuple *orig, - const struct nf_conntrack_l4proto *l4proto) + const struct nf_conntrack_tuple *orig) { memset(inverse, 0, sizeof(*inverse)); @@ -415,8 +434,14 @@ nf_ct_invert_tuple(struct nf_conntrack_tuple *inverse, inverse->dst.protonum = orig->dst.protonum; - if (unlikely(l4proto->invert_tuple)) - return l4proto->invert_tuple(inverse, orig); + switch (orig->dst.protonum) { + case IPPROTO_ICMP: + return nf_conntrack_invert_icmp_tuple(inverse, orig); +#if IS_ENABLED(CONFIG_IPV6) + case IPPROTO_ICMPV6: + return nf_conntrack_invert_icmpv6_tuple(inverse, orig); +#endif + } inverse->src.u.all = orig->dst.u.all; inverse->dst.u.all = orig->src.u.all; @@ -526,11 +551,20 @@ void nf_ct_tmpl_free(struct nf_conn *tmpl) } EXPORT_SYMBOL_GPL(nf_ct_tmpl_free); +static void destroy_gre_conntrack(struct nf_conn *ct) +{ +#ifdef CONFIG_NF_CT_PROTO_GRE + struct nf_conn *master = ct->master; + + if (master) + nf_ct_gre_keymap_destroy(master); +#endif +} + static void destroy_conntrack(struct nf_conntrack *nfct) { struct nf_conn *ct = (struct nf_conn *)nfct; - const struct nf_conntrack_l4proto *l4proto; pr_debug("destroy_conntrack(%p)\n", ct); WARN_ON(atomic_read(&nfct->use) != 0); @@ -539,9 +573,9 @@ destroy_conntrack(struct nf_conntrack *nfct) nf_ct_tmpl_free(ct); return; } - l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); - if (l4proto->destroy) - l4proto->destroy(ct); + + if (unlikely(nf_ct_protonum(ct) == IPPROTO_GRE)) + destroy_gre_conntrack(ct); local_bh_disable(); /* Expectations will have been removed in clean_from_lists, @@ -840,7 +874,7 @@ static int nf_ct_resolve_clash(struct net *net, struct sk_buff *skb, enum ip_conntrack_info oldinfo; struct nf_conn *loser_ct = nf_ct_get(skb, &oldinfo); - l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); + l4proto = nf_ct_l4proto_find(nf_ct_protonum(ct)); if (l4proto->allow_clash && !nf_ct_is_dying(ct) && atomic_inc_not_zero(&ct->ct_general.use)) { @@ -901,10 +935,18 @@ __nf_conntrack_confirm(struct sk_buff *skb) * REJECT will give spurious warnings here. */ - /* No external references means no one else could have - * confirmed us. + /* Another skb with the same unconfirmed conntrack may + * win the race. This may happen for bridge(br_flood) + * or broadcast/multicast packets do skb_clone with + * unconfirmed conntrack. */ - WARN_ON(nf_ct_is_confirmed(ct)); + if (unlikely(nf_ct_is_confirmed(ct))) { + WARN_ON_ONCE(1); + nf_conntrack_double_unlock(hash, reply_hash); + local_bh_enable(); + return NF_DROP; + } + pr_debug("Confirming conntrack %p\n", ct); /* We have to check the DYING flag after unlink to prevent * a race against nf_ct_get_next_corpse() possibly called from @@ -1007,6 +1049,22 @@ nf_conntrack_tuple_taken(const struct nf_conntrack_tuple *tuple, } if (nf_ct_key_equal(h, tuple, zone, net)) { + /* Tuple is taken already, so caller will need to find + * a new source port to use. + * + * Only exception: + * If the *original tuples* are identical, then both + * conntracks refer to the same flow. + * This is a rare situation, it can occur e.g. when + * more than one UDP packet is sent from same socket + * in different threads. + * + * Let nf_ct_resolve_clash() deal with this later. + */ + if (nf_ct_tuple_equal(&ignored_conntrack->tuplehash[IP_CT_DIR_ORIGINAL].tuple, + &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple)) + continue; + NF_CT_STAT_INC_ATOMIC(net, found); rcu_read_unlock(); return 1; @@ -1112,7 +1170,7 @@ static bool gc_worker_can_early_drop(const struct nf_conn *ct) if (!test_bit(IPS_ASSURED_BIT, &ct->status)) return true; - l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); + l4proto = nf_ct_l4proto_find(nf_ct_protonum(ct)); if (l4proto->can_early_drop && l4proto->can_early_drop(ct)) return true; @@ -1342,7 +1400,6 @@ EXPORT_SYMBOL_GPL(nf_conntrack_free); static noinline struct nf_conntrack_tuple_hash * init_conntrack(struct net *net, struct nf_conn *tmpl, const struct nf_conntrack_tuple *tuple, - const struct nf_conntrack_l4proto *l4proto, struct sk_buff *skb, unsigned int dataoff, u32 hash) { @@ -1355,7 +1412,7 @@ init_conntrack(struct net *net, struct nf_conn *tmpl, struct nf_conn_timeout *timeout_ext; struct nf_conntrack_zone tmp; - if (!nf_ct_invert_tuple(&repl_tuple, tuple, l4proto)) { + if (!nf_ct_invert_tuple(&repl_tuple, tuple)) { pr_debug("Can't invert tuple.\n"); return NULL; } @@ -1437,7 +1494,6 @@ resolve_normal_ct(struct nf_conn *tmpl, struct sk_buff *skb, unsigned int dataoff, u_int8_t protonum, - const struct nf_conntrack_l4proto *l4proto, const struct nf_hook_state *state) { const struct nf_conntrack_zone *zone; @@ -1450,7 +1506,7 @@ resolve_normal_ct(struct nf_conn *tmpl, if (!nf_ct_get_tuple(skb, skb_network_offset(skb), dataoff, state->pf, protonum, state->net, - &tuple, l4proto)) { + &tuple)) { pr_debug("Can't get tuple\n"); return 0; } @@ -1460,7 +1516,7 @@ resolve_normal_ct(struct nf_conn *tmpl, hash = hash_conntrack_raw(&tuple, state->net); h = __nf_conntrack_find_get(state->net, zone, &tuple, hash); if (!h) { - h = init_conntrack(state->net, tmpl, &tuple, l4proto, + h = init_conntrack(state->net, tmpl, &tuple, skb, dataoff, hash); if (!h) return 0; @@ -1522,10 +1578,66 @@ nf_conntrack_handle_icmp(struct nf_conn *tmpl, return ret; } +static int generic_packet(struct nf_conn *ct, struct sk_buff *skb, + enum ip_conntrack_info ctinfo) +{ + const unsigned int *timeout = nf_ct_timeout_lookup(ct); + + if (!timeout) + timeout = &nf_generic_pernet(nf_ct_net(ct))->timeout; + + nf_ct_refresh_acct(ct, ctinfo, skb, *timeout); + return NF_ACCEPT; +} + +/* Returns verdict for packet, or -1 for invalid. */ +static int nf_conntrack_handle_packet(struct nf_conn *ct, + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) +{ + switch (nf_ct_protonum(ct)) { + case IPPROTO_TCP: + return nf_conntrack_tcp_packet(ct, skb, dataoff, + ctinfo, state); + case IPPROTO_UDP: + return nf_conntrack_udp_packet(ct, skb, dataoff, + ctinfo, state); + case IPPROTO_ICMP: + return nf_conntrack_icmp_packet(ct, skb, ctinfo, state); +#if IS_ENABLED(CONFIG_IPV6) + case IPPROTO_ICMPV6: + return nf_conntrack_icmpv6_packet(ct, skb, ctinfo, state); +#endif +#ifdef CONFIG_NF_CT_PROTO_UDPLITE + case IPPROTO_UDPLITE: + return nf_conntrack_udplite_packet(ct, skb, dataoff, + ctinfo, state); +#endif +#ifdef CONFIG_NF_CT_PROTO_SCTP + case IPPROTO_SCTP: + return nf_conntrack_sctp_packet(ct, skb, dataoff, + ctinfo, state); +#endif +#ifdef CONFIG_NF_CT_PROTO_DCCP + case IPPROTO_DCCP: + return nf_conntrack_dccp_packet(ct, skb, dataoff, + ctinfo, state); +#endif +#ifdef CONFIG_NF_CT_PROTO_GRE + case IPPROTO_GRE: + return nf_conntrack_gre_packet(ct, skb, dataoff, + ctinfo, state); +#endif + } + + return generic_packet(ct, skb, ctinfo); +} + unsigned int nf_conntrack_in(struct sk_buff *skb, const struct nf_hook_state *state) { - const struct nf_conntrack_l4proto *l4proto; enum ip_conntrack_info ctinfo; struct nf_conn *ct, *tmpl; u_int8_t protonum; @@ -1552,8 +1664,6 @@ nf_conntrack_in(struct sk_buff *skb, const struct nf_hook_state *state) goto out; } - l4proto = __nf_ct_l4proto_find(protonum); - if (protonum == IPPROTO_ICMP || protonum == IPPROTO_ICMPV6) { ret = nf_conntrack_handle_icmp(tmpl, skb, dataoff, protonum, state); @@ -1567,7 +1677,7 @@ nf_conntrack_in(struct sk_buff *skb, const struct nf_hook_state *state) } repeat: ret = resolve_normal_ct(tmpl, skb, dataoff, - protonum, l4proto, state); + protonum, state); if (ret < 0) { /* Too stressed to deal. */ NF_CT_STAT_INC_ATOMIC(state->net, drop); @@ -1583,7 +1693,7 @@ repeat: goto out; } - ret = l4proto->packet(ct, skb, dataoff, ctinfo, state); + ret = nf_conntrack_handle_packet(ct, skb, dataoff, ctinfo, state); if (ret <= 0) { /* Invalid: inverse of the return code tells * the netfilter core what to do */ @@ -1614,19 +1724,6 @@ out: } EXPORT_SYMBOL_GPL(nf_conntrack_in); -bool nf_ct_invert_tuplepr(struct nf_conntrack_tuple *inverse, - const struct nf_conntrack_tuple *orig) -{ - bool ret; - - rcu_read_lock(); - ret = nf_ct_invert_tuple(inverse, orig, - __nf_ct_l4proto_find(orig->dst.protonum)); - rcu_read_unlock(); - return ret; -} -EXPORT_SYMBOL_GPL(nf_ct_invert_tuplepr); - /* Alter reply tuple (maybe alter helper). This is for NAT, and is implicitly racy: see __nf_conntrack_confirm */ void nf_conntrack_alter_reply(struct nf_conn *ct, @@ -1654,11 +1751,9 @@ EXPORT_SYMBOL_GPL(nf_conntrack_alter_reply); void __nf_ct_refresh_acct(struct nf_conn *ct, enum ip_conntrack_info ctinfo, const struct sk_buff *skb, - unsigned long extra_jiffies, - int do_acct) + u32 extra_jiffies, + bool do_acct) { - WARN_ON(!skb); - /* Only update if this is not a fixed timeout */ if (test_bit(IPS_FIXED_TIMEOUT_BIT, &ct->status)) goto acct; @@ -1667,7 +1762,8 @@ void __nf_ct_refresh_acct(struct nf_conn *ct, if (nf_ct_is_confirmed(ct)) extra_jiffies += nfct_time_stamp; - ct->timeout = extra_jiffies; + if (ct->timeout != extra_jiffies) + ct->timeout = extra_jiffies; acct: if (do_acct) nf_ct_acct_update(ct, ctinfo, skb->len); @@ -1757,7 +1853,6 @@ static void nf_conntrack_attach(struct sk_buff *nskb, const struct sk_buff *skb) static int nf_conntrack_update(struct net *net, struct sk_buff *skb) { - const struct nf_conntrack_l4proto *l4proto; struct nf_conntrack_tuple_hash *h; struct nf_conntrack_tuple tuple; enum ip_conntrack_info ctinfo; @@ -1778,10 +1873,8 @@ static int nf_conntrack_update(struct net *net, struct sk_buff *skb) if (dataoff <= 0) return -1; - l4proto = nf_ct_l4proto_find_get(l4num); - if (!nf_ct_get_tuple(skb, skb_network_offset(skb), dataoff, l3num, - l4num, net, &tuple, l4proto)) + l4num, net, &tuple)) return -1; if (ct->status & IPS_SRC_NAT) { @@ -2387,6 +2480,7 @@ int nf_conntrack_init_net(struct net *net) int cpu; BUILD_BUG_ON(IP_CT_UNTRACKED == IP_CT_NUMBER); + BUILD_BUG_ON_NOT_POWER_OF_2(CONNTRACK_LOCKS); atomic_set(&net->ct.count, 0); net->ct.pcpu_lists = alloc_percpu(struct ct_pcpu); @@ -2413,15 +2507,10 @@ int nf_conntrack_init_net(struct net *net) nf_conntrack_tstamp_pernet_init(net); nf_conntrack_ecache_pernet_init(net); nf_conntrack_helper_pernet_init(net); + nf_conntrack_proto_pernet_init(net); - ret = nf_conntrack_proto_pernet_init(net); - if (ret < 0) - goto err_proto; return 0; -err_proto: - nf_conntrack_ecache_pernet_fini(net); - nf_conntrack_expect_pernet_fini(net); err_expect: free_percpu(net->ct.stat); err_pcpu_lists: diff --git a/net/netfilter/nf_conntrack_expect.c b/net/netfilter/nf_conntrack_expect.c index 3034038bfdf0..334d6e5b7762 100644 --- a/net/netfilter/nf_conntrack_expect.c +++ b/net/netfilter/nf_conntrack_expect.c @@ -610,7 +610,7 @@ static int exp_seq_show(struct seq_file *s, void *v) expect->tuple.src.l3num, expect->tuple.dst.protonum); print_tuple(s, &expect->tuple, - __nf_ct_l4proto_find(expect->tuple.dst.protonum)); + nf_ct_l4proto_find(expect->tuple.dst.protonum)); if (expect->flags & NF_CT_EXPECT_PERMANENT) { seq_puts(s, "PERMANENT"); diff --git a/net/netfilter/nf_conntrack_netlink.c b/net/netfilter/nf_conntrack_netlink.c index 1213beb5a714..66c596d287a5 100644 --- a/net/netfilter/nf_conntrack_netlink.c +++ b/net/netfilter/nf_conntrack_netlink.c @@ -46,7 +46,7 @@ #include <net/netfilter/nf_conntrack_labels.h> #include <net/netfilter/nf_conntrack_synproxy.h> #ifdef CONFIG_NF_NAT_NEEDED -#include <net/netfilter/nf_nat_core.h> +#include <net/netfilter/nf_nat.h> #include <net/netfilter/nf_nat_helper.h> #endif @@ -134,7 +134,7 @@ static int ctnetlink_dump_tuples(struct sk_buff *skb, ret = ctnetlink_dump_tuples_ip(skb, tuple); if (ret >= 0) { - l4proto = __nf_ct_l4proto_find(tuple->dst.protonum); + l4proto = nf_ct_l4proto_find(tuple->dst.protonum); ret = ctnetlink_dump_tuples_proto(skb, tuple, l4proto); } rcu_read_unlock(); @@ -182,7 +182,7 @@ static int ctnetlink_dump_protoinfo(struct sk_buff *skb, struct nf_conn *ct) struct nlattr *nest_proto; int ret; - l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); + l4proto = nf_ct_l4proto_find(nf_ct_protonum(ct)); if (!l4proto->to_nlattr) return 0; @@ -590,7 +590,7 @@ static size_t ctnetlink_proto_size(const struct nf_conn *ct) len = nla_policy_len(cta_ip_nla_policy, CTA_IP_MAX + 1); len *= 3u; /* ORIG, REPLY, MASTER */ - l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); + l4proto = nf_ct_l4proto_find(nf_ct_protonum(ct)); len += l4proto->nlattr_size; if (l4proto->nlattr_tuple_size) { len4 = l4proto->nlattr_tuple_size(); @@ -1059,7 +1059,7 @@ static int ctnetlink_parse_tuple_proto(struct nlattr *attr, tuple->dst.protonum = nla_get_u8(tb[CTA_PROTO_NUM]); rcu_read_lock(); - l4proto = __nf_ct_l4proto_find(tuple->dst.protonum); + l4proto = nf_ct_l4proto_find(tuple->dst.protonum); if (likely(l4proto->nlattr_to_tuple)) { ret = nla_validate_nested(attr, CTA_PROTO_MAX, @@ -1722,11 +1722,9 @@ static int ctnetlink_change_protoinfo(struct nf_conn *ct, if (err < 0) return err; - rcu_read_lock(); - l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); + l4proto = nf_ct_l4proto_find(nf_ct_protonum(ct)); if (l4proto->from_nlattr) err = l4proto->from_nlattr(tb, ct); - rcu_read_unlock(); return err; } @@ -2676,8 +2674,8 @@ static int ctnetlink_exp_dump_mask(struct sk_buff *skb, rcu_read_lock(); ret = ctnetlink_dump_tuples_ip(skb, &m); if (ret >= 0) { - l4proto = __nf_ct_l4proto_find(tuple->dst.protonum); - ret = ctnetlink_dump_tuples_proto(skb, &m, l4proto); + l4proto = nf_ct_l4proto_find(tuple->dst.protonum); + ret = ctnetlink_dump_tuples_proto(skb, &m, l4proto); } rcu_read_unlock(); diff --git a/net/netfilter/nf_conntrack_pptp.c b/net/netfilter/nf_conntrack_pptp.c index 11562f2a08bb..976f1dcb97f0 100644 --- a/net/netfilter/nf_conntrack_pptp.c +++ b/net/netfilter/nf_conntrack_pptp.c @@ -121,7 +121,7 @@ static void pptp_expectfn(struct nf_conn *ct, struct nf_conntrack_expect *exp_other; /* obviously this tuple inversion only works until you do NAT */ - nf_ct_invert_tuplepr(&inv_t, &exp->tuple); + nf_ct_invert_tuple(&inv_t, &exp->tuple); pr_debug("trying to unexpect other dir: "); nf_ct_dump_tuple(&inv_t); diff --git a/net/netfilter/nf_conntrack_proto.c b/net/netfilter/nf_conntrack_proto.c index 859f5d07a915..b9403a266a2e 100644 --- a/net/netfilter/nf_conntrack_proto.c +++ b/net/netfilter/nf_conntrack_proto.c @@ -43,40 +43,9 @@ extern unsigned int nf_conntrack_net_id; -static struct nf_conntrack_l4proto __rcu *nf_ct_protos[MAX_NF_CT_PROTO + 1] __read_mostly; - static DEFINE_MUTEX(nf_ct_proto_mutex); #ifdef CONFIG_SYSCTL -static int -nf_ct_register_sysctl(struct net *net, - struct ctl_table_header **header, - const char *path, - struct ctl_table *table) -{ - if (*header == NULL) { - *header = register_net_sysctl(net, path, table); - if (*header == NULL) - return -ENOMEM; - } - - return 0; -} - -static void -nf_ct_unregister_sysctl(struct ctl_table_header **header, - struct ctl_table **table, - unsigned int users) -{ - if (users > 0) - return; - - unregister_net_sysctl_table(*header); - kfree(*table); - *header = NULL; - *table = NULL; -} - __printf(5, 6) void nf_l4proto_log_invalid(const struct sk_buff *skb, struct net *net, @@ -124,295 +93,82 @@ void nf_ct_l4proto_log_invalid(const struct sk_buff *skb, EXPORT_SYMBOL_GPL(nf_ct_l4proto_log_invalid); #endif -const struct nf_conntrack_l4proto *__nf_ct_l4proto_find(u8 l4proto) -{ - if (unlikely(l4proto >= ARRAY_SIZE(nf_ct_protos))) - return &nf_conntrack_l4proto_generic; - - return rcu_dereference(nf_ct_protos[l4proto]); -} -EXPORT_SYMBOL_GPL(__nf_ct_l4proto_find); - -const struct nf_conntrack_l4proto *nf_ct_l4proto_find_get(u8 l4num) -{ - const struct nf_conntrack_l4proto *p; - - rcu_read_lock(); - p = __nf_ct_l4proto_find(l4num); - if (!try_module_get(p->me)) - p = &nf_conntrack_l4proto_generic; - rcu_read_unlock(); - - return p; -} -EXPORT_SYMBOL_GPL(nf_ct_l4proto_find_get); - -void nf_ct_l4proto_put(const struct nf_conntrack_l4proto *p) -{ - module_put(p->me); -} -EXPORT_SYMBOL_GPL(nf_ct_l4proto_put); - -static int kill_l4proto(struct nf_conn *i, void *data) -{ - const struct nf_conntrack_l4proto *l4proto; - l4proto = data; - return nf_ct_protonum(i) == l4proto->l4proto; -} - -static struct nf_proto_net *nf_ct_l4proto_net(struct net *net, - const struct nf_conntrack_l4proto *l4proto) -{ - if (l4proto->get_net_proto) { - /* statically built-in protocols use static per-net */ - return l4proto->get_net_proto(net); - } else if (l4proto->net_id) { - /* ... and loadable protocols use dynamic per-net */ - return net_generic(net, *l4proto->net_id); - } - return NULL; -} - -static -int nf_ct_l4proto_register_sysctl(struct net *net, - struct nf_proto_net *pn) -{ - int err = 0; - -#ifdef CONFIG_SYSCTL - if (pn->ctl_table != NULL) { - err = nf_ct_register_sysctl(net, - &pn->ctl_table_header, - "net/netfilter", - pn->ctl_table); - if (err < 0) { - if (!pn->users) { - kfree(pn->ctl_table); - pn->ctl_table = NULL; - } - } - } -#endif /* CONFIG_SYSCTL */ - return err; -} - -static -void nf_ct_l4proto_unregister_sysctl(struct nf_proto_net *pn) -{ -#ifdef CONFIG_SYSCTL - if (pn->ctl_table_header != NULL) - nf_ct_unregister_sysctl(&pn->ctl_table_header, - &pn->ctl_table, - pn->users); -#endif /* CONFIG_SYSCTL */ -} - -/* FIXME: Allow NULL functions and sub in pointers to generic for - them. --RR */ -int nf_ct_l4proto_register_one(const struct nf_conntrack_l4proto *l4proto) -{ - int ret = 0; - - if ((l4proto->to_nlattr && l4proto->nlattr_size == 0) || - (l4proto->tuple_to_nlattr && !l4proto->nlattr_tuple_size)) - return -EINVAL; - - mutex_lock(&nf_ct_proto_mutex); - if (rcu_dereference_protected( - nf_ct_protos[l4proto->l4proto], - lockdep_is_held(&nf_ct_proto_mutex) - ) != &nf_conntrack_l4proto_generic) { - ret = -EBUSY; - goto out_unlock; - } - - rcu_assign_pointer(nf_ct_protos[l4proto->l4proto], l4proto); -out_unlock: - mutex_unlock(&nf_ct_proto_mutex); - return ret; -} -EXPORT_SYMBOL_GPL(nf_ct_l4proto_register_one); - -int nf_ct_l4proto_pernet_register_one(struct net *net, - const struct nf_conntrack_l4proto *l4proto) -{ - int ret = 0; - struct nf_proto_net *pn = NULL; - - if (l4proto->init_net) { - ret = l4proto->init_net(net); - if (ret < 0) - goto out; - } - - pn = nf_ct_l4proto_net(net, l4proto); - if (pn == NULL) - goto out; - - ret = nf_ct_l4proto_register_sysctl(net, pn); - if (ret < 0) - goto out; - - pn->users++; -out: - return ret; -} -EXPORT_SYMBOL_GPL(nf_ct_l4proto_pernet_register_one); - -static void __nf_ct_l4proto_unregister_one(const struct nf_conntrack_l4proto *l4proto) - -{ - BUG_ON(l4proto->l4proto >= ARRAY_SIZE(nf_ct_protos)); - - BUG_ON(rcu_dereference_protected( - nf_ct_protos[l4proto->l4proto], - lockdep_is_held(&nf_ct_proto_mutex) - ) != l4proto); - rcu_assign_pointer(nf_ct_protos[l4proto->l4proto], - &nf_conntrack_l4proto_generic); -} - -void nf_ct_l4proto_unregister_one(const struct nf_conntrack_l4proto *l4proto) -{ - mutex_lock(&nf_ct_proto_mutex); - __nf_ct_l4proto_unregister_one(l4proto); - mutex_unlock(&nf_ct_proto_mutex); - - synchronize_net(); - /* Remove all contrack entries for this protocol */ - nf_ct_iterate_destroy(kill_l4proto, (void *)l4proto); -} -EXPORT_SYMBOL_GPL(nf_ct_l4proto_unregister_one); - -void nf_ct_l4proto_pernet_unregister_one(struct net *net, - const struct nf_conntrack_l4proto *l4proto) -{ - struct nf_proto_net *pn = nf_ct_l4proto_net(net, l4proto); - - if (pn == NULL) - return; - - pn->users--; - nf_ct_l4proto_unregister_sysctl(pn); -} -EXPORT_SYMBOL_GPL(nf_ct_l4proto_pernet_unregister_one); - -static void -nf_ct_l4proto_unregister(const struct nf_conntrack_l4proto * const l4proto[], - unsigned int num_proto) -{ - int i; - - mutex_lock(&nf_ct_proto_mutex); - for (i = 0; i < num_proto; i++) - __nf_ct_l4proto_unregister_one(l4proto[i]); - mutex_unlock(&nf_ct_proto_mutex); - - synchronize_net(); - - for (i = 0; i < num_proto; i++) - nf_ct_iterate_destroy(kill_l4proto, (void *)l4proto[i]); -} - -static int -nf_ct_l4proto_register(const struct nf_conntrack_l4proto * const l4proto[], - unsigned int num_proto) -{ - int ret = -EINVAL; - unsigned int i; - - for (i = 0; i < num_proto; i++) { - ret = nf_ct_l4proto_register_one(l4proto[i]); - if (ret < 0) - break; - } - if (i != num_proto) { - pr_err("nf_conntrack: can't register l4 %d proto.\n", - l4proto[i]->l4proto); - nf_ct_l4proto_unregister(l4proto, i); - } - return ret; -} - -int nf_ct_l4proto_pernet_register(struct net *net, - const struct nf_conntrack_l4proto *const l4proto[], - unsigned int num_proto) +const struct nf_conntrack_l4proto *nf_ct_l4proto_find(u8 l4proto) { - int ret = -EINVAL; - unsigned int i; - - for (i = 0; i < num_proto; i++) { - ret = nf_ct_l4proto_pernet_register_one(net, l4proto[i]); - if (ret < 0) - break; - } - if (i != num_proto) { - pr_err("nf_conntrack %d: pernet registration failed\n", - l4proto[i]->l4proto); - nf_ct_l4proto_pernet_unregister(net, l4proto, i); + switch (l4proto) { + case IPPROTO_UDP: return &nf_conntrack_l4proto_udp; + case IPPROTO_TCP: return &nf_conntrack_l4proto_tcp; + case IPPROTO_ICMP: return &nf_conntrack_l4proto_icmp; +#ifdef CONFIG_NF_CT_PROTO_DCCP + case IPPROTO_DCCP: return &nf_conntrack_l4proto_dccp; +#endif +#ifdef CONFIG_NF_CT_PROTO_SCTP + case IPPROTO_SCTP: return &nf_conntrack_l4proto_sctp; +#endif +#ifdef CONFIG_NF_CT_PROTO_UDPLITE + case IPPROTO_UDPLITE: return &nf_conntrack_l4proto_udplite; +#endif +#ifdef CONFIG_NF_CT_PROTO_GRE + case IPPROTO_GRE: return &nf_conntrack_l4proto_gre; +#endif +#if IS_ENABLED(CONFIG_IPV6) + case IPPROTO_ICMPV6: return &nf_conntrack_l4proto_icmpv6; +#endif /* CONFIG_IPV6 */ } - return ret; -} -EXPORT_SYMBOL_GPL(nf_ct_l4proto_pernet_register); -void nf_ct_l4proto_pernet_unregister(struct net *net, - const struct nf_conntrack_l4proto *const l4proto[], - unsigned int num_proto) -{ - while (num_proto-- != 0) - nf_ct_l4proto_pernet_unregister_one(net, l4proto[num_proto]); -} -EXPORT_SYMBOL_GPL(nf_ct_l4proto_pernet_unregister); + return &nf_conntrack_l4proto_generic; +}; +EXPORT_SYMBOL_GPL(nf_ct_l4proto_find); -static unsigned int ipv4_helper(void *priv, - struct sk_buff *skb, - const struct nf_hook_state *state) +static unsigned int nf_confirm(struct sk_buff *skb, + unsigned int protoff, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo) { - struct nf_conn *ct; - enum ip_conntrack_info ctinfo; const struct nf_conn_help *help; - const struct nf_conntrack_helper *helper; - - /* This is where we call the helper: as the packet goes out. */ - ct = nf_ct_get(skb, &ctinfo); - if (!ct || ctinfo == IP_CT_RELATED_REPLY) - return NF_ACCEPT; help = nfct_help(ct); - if (!help) - return NF_ACCEPT; + if (help) { + const struct nf_conntrack_helper *helper; + int ret; + + /* rcu_read_lock()ed by nf_hook_thresh */ + helper = rcu_dereference(help->helper); + if (helper) { + ret = helper->help(skb, + protoff, + ct, ctinfo); + if (ret != NF_ACCEPT) + return ret; + } + } - /* rcu_read_lock()ed by nf_hook_thresh */ - helper = rcu_dereference(help->helper); - if (!helper) - return NF_ACCEPT; + if (test_bit(IPS_SEQ_ADJUST_BIT, &ct->status) && + !nf_is_loopback_packet(skb)) { + if (!nf_ct_seq_adjust(skb, ct, ctinfo, protoff)) { + NF_CT_STAT_INC_ATOMIC(nf_ct_net(ct), drop); + return NF_DROP; + } + } - return helper->help(skb, skb_network_offset(skb) + ip_hdrlen(skb), - ct, ctinfo); + /* We've seen it coming out the other side: confirm it */ + return nf_conntrack_confirm(skb); } static unsigned int ipv4_confirm(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) { - struct nf_conn *ct; enum ip_conntrack_info ctinfo; + struct nf_conn *ct; ct = nf_ct_get(skb, &ctinfo); if (!ct || ctinfo == IP_CT_RELATED_REPLY) - goto out; + return nf_conntrack_confirm(skb); - /* adjust seqs for loopback traffic only in outgoing direction */ - if (test_bit(IPS_SEQ_ADJUST_BIT, &ct->status) && - !nf_is_loopback_packet(skb)) { - if (!nf_ct_seq_adjust(skb, ct, ctinfo, ip_hdrlen(skb))) { - NF_CT_STAT_INC_ATOMIC(nf_ct_net(ct), drop); - return NF_DROP; - } - } -out: - /* We've seen it coming out the other side: confirm it */ - return nf_conntrack_confirm(skb); + return nf_confirm(skb, + skb_network_offset(skb) + ip_hdrlen(skb), + ct, ctinfo); } static unsigned int ipv4_conntrack_in(void *priv, @@ -461,24 +217,12 @@ static const struct nf_hook_ops ipv4_conntrack_ops[] = { .priority = NF_IP_PRI_CONNTRACK, }, { - .hook = ipv4_helper, - .pf = NFPROTO_IPV4, - .hooknum = NF_INET_POST_ROUTING, - .priority = NF_IP_PRI_CONNTRACK_HELPER, - }, - { .hook = ipv4_confirm, .pf = NFPROTO_IPV4, .hooknum = NF_INET_POST_ROUTING, .priority = NF_IP_PRI_CONNTRACK_CONFIRM, }, { - .hook = ipv4_helper, - .pf = NFPROTO_IPV4, - .hooknum = NF_INET_LOCAL_IN, - .priority = NF_IP_PRI_CONNTRACK_HELPER, - }, - { .hook = ipv4_confirm, .pf = NFPROTO_IPV4, .hooknum = NF_INET_LOCAL_IN, @@ -623,31 +367,21 @@ static unsigned int ipv6_confirm(void *priv, struct nf_conn *ct; enum ip_conntrack_info ctinfo; unsigned char pnum = ipv6_hdr(skb)->nexthdr; - int protoff; __be16 frag_off; + int protoff; ct = nf_ct_get(skb, &ctinfo); if (!ct || ctinfo == IP_CT_RELATED_REPLY) - goto out; + return nf_conntrack_confirm(skb); protoff = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &pnum, &frag_off); if (protoff < 0 || (frag_off & htons(~0x7)) != 0) { pr_debug("proto header not found\n"); - goto out; + return nf_conntrack_confirm(skb); } - /* adjust seqs for loopback traffic only in outgoing direction */ - if (test_bit(IPS_SEQ_ADJUST_BIT, &ct->status) && - !nf_is_loopback_packet(skb)) { - if (!nf_ct_seq_adjust(skb, ct, ctinfo, protoff)) { - NF_CT_STAT_INC_ATOMIC(nf_ct_net(ct), drop); - return NF_DROP; - } - } -out: - /* We've seen it coming out the other side: confirm it */ - return nf_conntrack_confirm(skb); + return nf_confirm(skb, protoff, ct, ctinfo); } static unsigned int ipv6_conntrack_in(void *priv, @@ -664,42 +398,6 @@ static unsigned int ipv6_conntrack_local(void *priv, return nf_conntrack_in(skb, state); } -static unsigned int ipv6_helper(void *priv, - struct sk_buff *skb, - const struct nf_hook_state *state) -{ - struct nf_conn *ct; - const struct nf_conn_help *help; - const struct nf_conntrack_helper *helper; - enum ip_conntrack_info ctinfo; - __be16 frag_off; - int protoff; - u8 nexthdr; - - /* This is where we call the helper: as the packet goes out. */ - ct = nf_ct_get(skb, &ctinfo); - if (!ct || ctinfo == IP_CT_RELATED_REPLY) - return NF_ACCEPT; - - help = nfct_help(ct); - if (!help) - return NF_ACCEPT; - /* rcu_read_lock()ed by nf_hook_thresh */ - helper = rcu_dereference(help->helper); - if (!helper) - return NF_ACCEPT; - - nexthdr = ipv6_hdr(skb)->nexthdr; - protoff = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &nexthdr, - &frag_off); - if (protoff < 0 || (frag_off & htons(~0x7)) != 0) { - pr_debug("proto header not found\n"); - return NF_ACCEPT; - } - - return helper->help(skb, protoff, ct, ctinfo); -} - static const struct nf_hook_ops ipv6_conntrack_ops[] = { { .hook = ipv6_conntrack_in, @@ -714,24 +412,12 @@ static const struct nf_hook_ops ipv6_conntrack_ops[] = { .priority = NF_IP6_PRI_CONNTRACK, }, { - .hook = ipv6_helper, - .pf = NFPROTO_IPV6, - .hooknum = NF_INET_POST_ROUTING, - .priority = NF_IP6_PRI_CONNTRACK_HELPER, - }, - { .hook = ipv6_confirm, .pf = NFPROTO_IPV6, .hooknum = NF_INET_POST_ROUTING, .priority = NF_IP6_PRI_LAST, }, { - .hook = ipv6_helper, - .pf = NFPROTO_IPV6, - .hooknum = NF_INET_LOCAL_IN, - .priority = NF_IP6_PRI_CONNTRACK_HELPER, - }, - { .hook = ipv6_confirm, .pf = NFPROTO_IPV6, .hooknum = NF_INET_LOCAL_IN, @@ -874,27 +560,9 @@ void nf_ct_netns_put(struct net *net, uint8_t nfproto) } EXPORT_SYMBOL_GPL(nf_ct_netns_put); -static const struct nf_conntrack_l4proto * const builtin_l4proto[] = { - &nf_conntrack_l4proto_tcp, - &nf_conntrack_l4proto_udp, - &nf_conntrack_l4proto_icmp, -#ifdef CONFIG_NF_CT_PROTO_DCCP - &nf_conntrack_l4proto_dccp, -#endif -#ifdef CONFIG_NF_CT_PROTO_SCTP - &nf_conntrack_l4proto_sctp, -#endif -#ifdef CONFIG_NF_CT_PROTO_UDPLITE - &nf_conntrack_l4proto_udplite, -#endif -#if IS_ENABLED(CONFIG_IPV6) - &nf_conntrack_l4proto_icmpv6, -#endif /* CONFIG_IPV6 */ -}; - int nf_conntrack_proto_init(void) { - int ret = 0, i; + int ret; ret = nf_register_sockopt(&so_getorigdst); if (ret < 0) @@ -906,18 +574,8 @@ int nf_conntrack_proto_init(void) goto cleanup_sockopt; #endif - for (i = 0; i < ARRAY_SIZE(nf_ct_protos); i++) - RCU_INIT_POINTER(nf_ct_protos[i], - &nf_conntrack_l4proto_generic); - - ret = nf_ct_l4proto_register(builtin_l4proto, - ARRAY_SIZE(builtin_l4proto)); - if (ret < 0) - goto cleanup_sockopt2; - return ret; -cleanup_sockopt2: - nf_unregister_sockopt(&so_getorigdst); + #if IS_ENABLED(CONFIG_IPV6) cleanup_sockopt: nf_unregister_sockopt(&so_getorigdst6); @@ -933,43 +591,33 @@ void nf_conntrack_proto_fini(void) #endif } -int nf_conntrack_proto_pernet_init(struct net *net) +void nf_conntrack_proto_pernet_init(struct net *net) { - int err; - struct nf_proto_net *pn = nf_ct_l4proto_net(net, - &nf_conntrack_l4proto_generic); - - err = nf_conntrack_l4proto_generic.init_net(net); - if (err < 0) - return err; - err = nf_ct_l4proto_register_sysctl(net, - pn); - if (err < 0) - return err; - - err = nf_ct_l4proto_pernet_register(net, builtin_l4proto, - ARRAY_SIZE(builtin_l4proto)); - if (err < 0) { - nf_ct_l4proto_unregister_sysctl(pn); - return err; - } - - pn->users++; - return 0; + nf_conntrack_generic_init_net(net); + nf_conntrack_udp_init_net(net); + nf_conntrack_tcp_init_net(net); + nf_conntrack_icmp_init_net(net); +#if IS_ENABLED(CONFIG_IPV6) + nf_conntrack_icmpv6_init_net(net); +#endif +#ifdef CONFIG_NF_CT_PROTO_DCCP + nf_conntrack_dccp_init_net(net); +#endif +#ifdef CONFIG_NF_CT_PROTO_SCTP + nf_conntrack_sctp_init_net(net); +#endif +#ifdef CONFIG_NF_CT_PROTO_GRE + nf_conntrack_gre_init_net(net); +#endif } void nf_conntrack_proto_pernet_fini(struct net *net) { - struct nf_proto_net *pn = nf_ct_l4proto_net(net, - &nf_conntrack_l4proto_generic); - - nf_ct_l4proto_pernet_unregister(net, builtin_l4proto, - ARRAY_SIZE(builtin_l4proto)); - pn->users--; - nf_ct_l4proto_unregister_sysctl(pn); +#ifdef CONFIG_NF_CT_PROTO_GRE + nf_ct_gre_keymap_flush(net); +#endif } - module_param_call(hashsize, nf_conntrack_set_hashsize, param_get_uint, &nf_conntrack_htable_size, 0600); diff --git a/net/netfilter/nf_conntrack_proto_dccp.c b/net/netfilter/nf_conntrack_proto_dccp.c index 023c1445bc39..6fca80587505 100644 --- a/net/netfilter/nf_conntrack_proto_dccp.c +++ b/net/netfilter/nf_conntrack_proto_dccp.c @@ -472,9 +472,10 @@ out_invalid: return true; } -static int dccp_packet(struct nf_conn *ct, struct sk_buff *skb, - unsigned int dataoff, enum ip_conntrack_info ctinfo, - const struct nf_hook_state *state) +int nf_conntrack_dccp_packet(struct nf_conn *ct, struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); struct dccp_hdr _dh, *dh; @@ -723,123 +724,28 @@ dccp_timeout_nla_policy[CTA_TIMEOUT_DCCP_MAX+1] = { }; #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ -#ifdef CONFIG_SYSCTL -/* template, data assigned later */ -static struct ctl_table dccp_sysctl_table[] = { - { - .procname = "nf_conntrack_dccp_timeout_request", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_dccp_timeout_respond", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_dccp_timeout_partopen", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_dccp_timeout_open", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_dccp_timeout_closereq", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_dccp_timeout_closing", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_dccp_timeout_timewait", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_dccp_loose", - .maxlen = sizeof(int), - .mode = 0644, - .proc_handler = proc_dointvec, - }, - { } -}; -#endif /* CONFIG_SYSCTL */ - -static int dccp_kmemdup_sysctl_table(struct net *net, struct nf_proto_net *pn, - struct nf_dccp_net *dn) -{ -#ifdef CONFIG_SYSCTL - if (pn->ctl_table) - return 0; - - pn->ctl_table = kmemdup(dccp_sysctl_table, - sizeof(dccp_sysctl_table), - GFP_KERNEL); - if (!pn->ctl_table) - return -ENOMEM; - - pn->ctl_table[0].data = &dn->dccp_timeout[CT_DCCP_REQUEST]; - pn->ctl_table[1].data = &dn->dccp_timeout[CT_DCCP_RESPOND]; - pn->ctl_table[2].data = &dn->dccp_timeout[CT_DCCP_PARTOPEN]; - pn->ctl_table[3].data = &dn->dccp_timeout[CT_DCCP_OPEN]; - pn->ctl_table[4].data = &dn->dccp_timeout[CT_DCCP_CLOSEREQ]; - pn->ctl_table[5].data = &dn->dccp_timeout[CT_DCCP_CLOSING]; - pn->ctl_table[6].data = &dn->dccp_timeout[CT_DCCP_TIMEWAIT]; - pn->ctl_table[7].data = &dn->dccp_loose; - - /* Don't export sysctls to unprivileged users */ - if (net->user_ns != &init_user_ns) - pn->ctl_table[0].procname = NULL; -#endif - return 0; -} - -static int dccp_init_net(struct net *net) +void nf_conntrack_dccp_init_net(struct net *net) { struct nf_dccp_net *dn = nf_dccp_pernet(net); - struct nf_proto_net *pn = &dn->pn; - - if (!pn->users) { - /* default values */ - dn->dccp_loose = 1; - dn->dccp_timeout[CT_DCCP_REQUEST] = 2 * DCCP_MSL; - dn->dccp_timeout[CT_DCCP_RESPOND] = 4 * DCCP_MSL; - dn->dccp_timeout[CT_DCCP_PARTOPEN] = 4 * DCCP_MSL; - dn->dccp_timeout[CT_DCCP_OPEN] = 12 * 3600 * HZ; - dn->dccp_timeout[CT_DCCP_CLOSEREQ] = 64 * HZ; - dn->dccp_timeout[CT_DCCP_CLOSING] = 64 * HZ; - dn->dccp_timeout[CT_DCCP_TIMEWAIT] = 2 * DCCP_MSL; - - /* timeouts[0] is unused, make it same as SYN_SENT so - * ->timeouts[0] contains 'new' timeout, like udp or icmp. - */ - dn->dccp_timeout[CT_DCCP_NONE] = dn->dccp_timeout[CT_DCCP_REQUEST]; - } - return dccp_kmemdup_sysctl_table(net, pn, dn); -} - -static struct nf_proto_net *dccp_get_net_proto(struct net *net) -{ - return &net->ct.nf_ct_proto.dccp.pn; + /* default values */ + dn->dccp_loose = 1; + dn->dccp_timeout[CT_DCCP_REQUEST] = 2 * DCCP_MSL; + dn->dccp_timeout[CT_DCCP_RESPOND] = 4 * DCCP_MSL; + dn->dccp_timeout[CT_DCCP_PARTOPEN] = 4 * DCCP_MSL; + dn->dccp_timeout[CT_DCCP_OPEN] = 12 * 3600 * HZ; + dn->dccp_timeout[CT_DCCP_CLOSEREQ] = 64 * HZ; + dn->dccp_timeout[CT_DCCP_CLOSING] = 64 * HZ; + dn->dccp_timeout[CT_DCCP_TIMEWAIT] = 2 * DCCP_MSL; + + /* timeouts[0] is unused, make it same as SYN_SENT so + * ->timeouts[0] contains 'new' timeout, like udp or icmp. + */ + dn->dccp_timeout[CT_DCCP_NONE] = dn->dccp_timeout[CT_DCCP_REQUEST]; } const struct nf_conntrack_l4proto nf_conntrack_l4proto_dccp = { .l4proto = IPPROTO_DCCP, - .packet = dccp_packet, .can_early_drop = dccp_can_early_drop, #ifdef CONFIG_NF_CONNTRACK_PROCFS .print_conntrack = dccp_print_conntrack, @@ -862,6 +768,4 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_dccp = { .nla_policy = dccp_timeout_nla_policy, }, #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .init_net = dccp_init_net, - .get_net_proto = dccp_get_net_proto, }; diff --git a/net/netfilter/nf_conntrack_proto_generic.c b/net/netfilter/nf_conntrack_proto_generic.c index 5da19d5fbc76..0f526fafecae 100644 --- a/net/netfilter/nf_conntrack_proto_generic.c +++ b/net/netfilter/nf_conntrack_proto_generic.c @@ -15,50 +15,6 @@ static const unsigned int nf_ct_generic_timeout = 600*HZ; -static bool nf_generic_should_process(u8 proto) -{ - switch (proto) { -#ifdef CONFIG_NF_CT_PROTO_GRE_MODULE - case IPPROTO_GRE: - return false; -#endif - default: - return true; - } -} - -static bool generic_pkt_to_tuple(const struct sk_buff *skb, - unsigned int dataoff, - struct net *net, struct nf_conntrack_tuple *tuple) -{ - tuple->src.u.all = 0; - tuple->dst.u.all = 0; - - return true; -} - -/* Returns verdict for packet, or -1 for invalid. */ -static int generic_packet(struct nf_conn *ct, - struct sk_buff *skb, - unsigned int dataoff, - enum ip_conntrack_info ctinfo, - const struct nf_hook_state *state) -{ - const unsigned int *timeout = nf_ct_timeout_lookup(ct); - - if (!nf_generic_should_process(nf_ct_protonum(ct))) { - pr_warn_once("conntrack: generic helper won't handle protocol %d. Please consider loading the specific helper module.\n", - nf_ct_protonum(ct)); - return -NF_ACCEPT; - } - - if (!timeout) - timeout = &nf_generic_pernet(nf_ct_net(ct))->timeout; - - nf_ct_refresh_acct(ct, ctinfo, skb, *timeout); - return NF_ACCEPT; -} - #ifdef CONFIG_NF_CONNTRACK_TIMEOUT #include <linux/netfilter/nfnetlink.h> @@ -104,53 +60,16 @@ generic_timeout_nla_policy[CTA_TIMEOUT_GENERIC_MAX+1] = { }; #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ -#ifdef CONFIG_SYSCTL -static struct ctl_table generic_sysctl_table[] = { - { - .procname = "nf_conntrack_generic_timeout", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { } -}; -#endif /* CONFIG_SYSCTL */ - -static int generic_kmemdup_sysctl_table(struct nf_proto_net *pn, - struct nf_generic_net *gn) -{ -#ifdef CONFIG_SYSCTL - pn->ctl_table = kmemdup(generic_sysctl_table, - sizeof(generic_sysctl_table), - GFP_KERNEL); - if (!pn->ctl_table) - return -ENOMEM; - - pn->ctl_table[0].data = &gn->timeout; -#endif - return 0; -} - -static int generic_init_net(struct net *net) +void nf_conntrack_generic_init_net(struct net *net) { struct nf_generic_net *gn = nf_generic_pernet(net); - struct nf_proto_net *pn = &gn->pn; gn->timeout = nf_ct_generic_timeout; - - return generic_kmemdup_sysctl_table(pn, gn); -} - -static struct nf_proto_net *generic_get_net_proto(struct net *net) -{ - return &net->ct.nf_ct_proto.generic.pn; } const struct nf_conntrack_l4proto nf_conntrack_l4proto_generic = { .l4proto = 255, - .pkt_to_tuple = generic_pkt_to_tuple, - .packet = generic_packet, #ifdef CONFIG_NF_CONNTRACK_TIMEOUT .ctnl_timeout = { .nlattr_to_obj = generic_timeout_nlattr_to_obj, @@ -160,6 +79,4 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_generic = .nla_policy = generic_timeout_nla_policy, }, #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .init_net = generic_init_net, - .get_net_proto = generic_get_net_proto, }; diff --git a/net/netfilter/nf_conntrack_proto_gre.c b/net/netfilter/nf_conntrack_proto_gre.c index 8899b51aad44..ee9ab10a32e4 100644 --- a/net/netfilter/nf_conntrack_proto_gre.c +++ b/net/netfilter/nf_conntrack_proto_gre.c @@ -48,24 +48,25 @@ static const unsigned int gre_timeouts[GRE_CT_MAX] = { [GRE_CT_REPLIED] = 180*HZ, }; -static unsigned int proto_gre_net_id __read_mostly; +/* used when expectation is added */ +static DEFINE_SPINLOCK(keymap_lock); -static inline struct netns_proto_gre *gre_pernet(struct net *net) +static inline struct nf_gre_net *gre_pernet(struct net *net) { - return net_generic(net, proto_gre_net_id); + return &net->ct.nf_ct_proto.gre; } -static void nf_ct_gre_keymap_flush(struct net *net) +void nf_ct_gre_keymap_flush(struct net *net) { - struct netns_proto_gre *net_gre = gre_pernet(net); + struct nf_gre_net *net_gre = gre_pernet(net); struct nf_ct_gre_keymap *km, *tmp; - write_lock_bh(&net_gre->keymap_lock); + spin_lock_bh(&keymap_lock); list_for_each_entry_safe(km, tmp, &net_gre->keymap_list, list) { - list_del(&km->list); - kfree(km); + list_del_rcu(&km->list); + kfree_rcu(km, rcu); } - write_unlock_bh(&net_gre->keymap_lock); + spin_unlock_bh(&keymap_lock); } static inline int gre_key_cmpfn(const struct nf_ct_gre_keymap *km, @@ -81,18 +82,16 @@ static inline int gre_key_cmpfn(const struct nf_ct_gre_keymap *km, /* look up the source key for a given tuple */ static __be16 gre_keymap_lookup(struct net *net, struct nf_conntrack_tuple *t) { - struct netns_proto_gre *net_gre = gre_pernet(net); + struct nf_gre_net *net_gre = gre_pernet(net); struct nf_ct_gre_keymap *km; __be16 key = 0; - read_lock_bh(&net_gre->keymap_lock); - list_for_each_entry(km, &net_gre->keymap_list, list) { + list_for_each_entry_rcu(km, &net_gre->keymap_list, list) { if (gre_key_cmpfn(km, t)) { key = km->tuple.src.u.gre.key; break; } } - read_unlock_bh(&net_gre->keymap_lock); pr_debug("lookup src key 0x%x for ", key); nf_ct_dump_tuple(t); @@ -105,21 +104,17 @@ int nf_ct_gre_keymap_add(struct nf_conn *ct, enum ip_conntrack_dir dir, struct nf_conntrack_tuple *t) { struct net *net = nf_ct_net(ct); - struct netns_proto_gre *net_gre = gre_pernet(net); + struct nf_gre_net *net_gre = gre_pernet(net); struct nf_ct_pptp_master *ct_pptp_info = nfct_help_data(ct); struct nf_ct_gre_keymap **kmp, *km; kmp = &ct_pptp_info->keymap[dir]; if (*kmp) { /* check whether it's a retransmission */ - read_lock_bh(&net_gre->keymap_lock); - list_for_each_entry(km, &net_gre->keymap_list, list) { - if (gre_key_cmpfn(km, t) && km == *kmp) { - read_unlock_bh(&net_gre->keymap_lock); + list_for_each_entry_rcu(km, &net_gre->keymap_list, list) { + if (gre_key_cmpfn(km, t) && km == *kmp) return 0; - } } - read_unlock_bh(&net_gre->keymap_lock); pr_debug("trying to override keymap_%s for ct %p\n", dir == IP_CT_DIR_REPLY ? "reply" : "orig", ct); return -EEXIST; @@ -134,9 +129,9 @@ int nf_ct_gre_keymap_add(struct nf_conn *ct, enum ip_conntrack_dir dir, pr_debug("adding new entry %p: ", km); nf_ct_dump_tuple(&km->tuple); - write_lock_bh(&net_gre->keymap_lock); + spin_lock_bh(&keymap_lock); list_add_tail(&km->list, &net_gre->keymap_list); - write_unlock_bh(&net_gre->keymap_lock); + spin_unlock_bh(&keymap_lock); return 0; } @@ -145,32 +140,30 @@ EXPORT_SYMBOL_GPL(nf_ct_gre_keymap_add); /* destroy the keymap entries associated with specified master ct */ void nf_ct_gre_keymap_destroy(struct nf_conn *ct) { - struct net *net = nf_ct_net(ct); - struct netns_proto_gre *net_gre = gre_pernet(net); struct nf_ct_pptp_master *ct_pptp_info = nfct_help_data(ct); enum ip_conntrack_dir dir; pr_debug("entering for ct %p\n", ct); - write_lock_bh(&net_gre->keymap_lock); + spin_lock_bh(&keymap_lock); for (dir = IP_CT_DIR_ORIGINAL; dir < IP_CT_DIR_MAX; dir++) { if (ct_pptp_info->keymap[dir]) { pr_debug("removing %p from list\n", ct_pptp_info->keymap[dir]); - list_del(&ct_pptp_info->keymap[dir]->list); - kfree(ct_pptp_info->keymap[dir]); + list_del_rcu(&ct_pptp_info->keymap[dir]->list); + kfree_rcu(ct_pptp_info->keymap[dir], rcu); ct_pptp_info->keymap[dir] = NULL; } } - write_unlock_bh(&net_gre->keymap_lock); + spin_unlock_bh(&keymap_lock); } EXPORT_SYMBOL_GPL(nf_ct_gre_keymap_destroy); /* PUBLIC CONNTRACK PROTO HELPER FUNCTIONS */ /* gre hdr info to tuple */ -static bool gre_pkt_to_tuple(const struct sk_buff *skb, unsigned int dataoff, - struct net *net, struct nf_conntrack_tuple *tuple) +bool gre_pkt_to_tuple(const struct sk_buff *skb, unsigned int dataoff, + struct net *net, struct nf_conntrack_tuple *tuple) { const struct pptp_gre_header *pgrehdr; struct pptp_gre_header _pgrehdr; @@ -216,15 +209,15 @@ static void gre_print_conntrack(struct seq_file *s, struct nf_conn *ct) static unsigned int *gre_get_timeouts(struct net *net) { - return gre_pernet(net)->gre_timeouts; + return gre_pernet(net)->timeouts; } /* Returns verdict for packet, and may modify conntrack */ -static int gre_packet(struct nf_conn *ct, - struct sk_buff *skb, - unsigned int dataoff, - enum ip_conntrack_info ctinfo, - const struct nf_hook_state *state) +int nf_conntrack_gre_packet(struct nf_conn *ct, + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { if (state->pf != NFPROTO_IPV4) return -NF_ACCEPT; @@ -256,19 +249,6 @@ static int gre_packet(struct nf_conn *ct, return NF_ACCEPT; } -/* Called when a conntrack entry has already been removed from the hashes - * and is about to be deleted from memory */ -static void gre_destroy(struct nf_conn *ct) -{ - struct nf_conn *master = ct->master; - pr_debug(" entering\n"); - - if (!master) - pr_debug("no master !?!\n"); - else - nf_ct_gre_keymap_destroy(master); -} - #ifdef CONFIG_NF_CONNTRACK_TIMEOUT #include <linux/netfilter/nfnetlink.h> @@ -278,13 +258,13 @@ static int gre_timeout_nlattr_to_obj(struct nlattr *tb[], struct net *net, void *data) { unsigned int *timeouts = data; - struct netns_proto_gre *net_gre = gre_pernet(net); + struct nf_gre_net *net_gre = gre_pernet(net); if (!timeouts) timeouts = gre_get_timeouts(net); /* set default timeouts for GRE. */ - timeouts[GRE_CT_UNREPLIED] = net_gre->gre_timeouts[GRE_CT_UNREPLIED]; - timeouts[GRE_CT_REPLIED] = net_gre->gre_timeouts[GRE_CT_REPLIED]; + timeouts[GRE_CT_UNREPLIED] = net_gre->timeouts[GRE_CT_UNREPLIED]; + timeouts[GRE_CT_REPLIED] = net_gre->timeouts[GRE_CT_REPLIED]; if (tb[CTA_TIMEOUT_GRE_UNREPLIED]) { timeouts[GRE_CT_UNREPLIED] = @@ -320,69 +300,22 @@ gre_timeout_nla_policy[CTA_TIMEOUT_GRE_MAX+1] = { }; #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ -#ifdef CONFIG_SYSCTL -static struct ctl_table gre_sysctl_table[] = { - { - .procname = "nf_conntrack_gre_timeout", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_gre_timeout_stream", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - {} -}; -#endif - -static int gre_kmemdup_sysctl_table(struct net *net, struct nf_proto_net *nf, - struct netns_proto_gre *net_gre) -{ -#ifdef CONFIG_SYSCTL - int i; - - if (nf->ctl_table) - return 0; - - nf->ctl_table = kmemdup(gre_sysctl_table, - sizeof(gre_sysctl_table), - GFP_KERNEL); - if (!nf->ctl_table) - return -ENOMEM; - - for (i = 0; i < GRE_CT_MAX; i++) - nf->ctl_table[i].data = &net_gre->gre_timeouts[i]; -#endif - return 0; -} - -static int gre_init_net(struct net *net) +void nf_conntrack_gre_init_net(struct net *net) { - struct netns_proto_gre *net_gre = gre_pernet(net); - struct nf_proto_net *nf = &net_gre->nf; + struct nf_gre_net *net_gre = gre_pernet(net); int i; - rwlock_init(&net_gre->keymap_lock); INIT_LIST_HEAD(&net_gre->keymap_list); for (i = 0; i < GRE_CT_MAX; i++) - net_gre->gre_timeouts[i] = gre_timeouts[i]; - - return gre_kmemdup_sysctl_table(net, nf, net_gre); + net_gre->timeouts[i] = gre_timeouts[i]; } /* protocol helper struct */ -static const struct nf_conntrack_l4proto nf_conntrack_l4proto_gre4 = { +const struct nf_conntrack_l4proto nf_conntrack_l4proto_gre = { .l4proto = IPPROTO_GRE, - .pkt_to_tuple = gre_pkt_to_tuple, #ifdef CONFIG_NF_CONNTRACK_PROCFS .print_conntrack = gre_print_conntrack, #endif - .packet = gre_packet, - .destroy = gre_destroy, - .me = THIS_MODULE, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, .nlattr_tuple_size = nf_ct_port_nlattr_tuple_size, @@ -398,61 +331,4 @@ static const struct nf_conntrack_l4proto nf_conntrack_l4proto_gre4 = { .nla_policy = gre_timeout_nla_policy, }, #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .net_id = &proto_gre_net_id, - .init_net = gre_init_net, }; - -static int proto_gre_net_init(struct net *net) -{ - int ret = 0; - - ret = nf_ct_l4proto_pernet_register_one(net, - &nf_conntrack_l4proto_gre4); - if (ret < 0) - pr_err("nf_conntrack_gre4: pernet registration failed.\n"); - return ret; -} - -static void proto_gre_net_exit(struct net *net) -{ - nf_ct_l4proto_pernet_unregister_one(net, &nf_conntrack_l4proto_gre4); - nf_ct_gre_keymap_flush(net); -} - -static struct pernet_operations proto_gre_net_ops = { - .init = proto_gre_net_init, - .exit = proto_gre_net_exit, - .id = &proto_gre_net_id, - .size = sizeof(struct netns_proto_gre), -}; - -static int __init nf_ct_proto_gre_init(void) -{ - int ret; - - BUILD_BUG_ON(offsetof(struct netns_proto_gre, nf) != 0); - - ret = register_pernet_subsys(&proto_gre_net_ops); - if (ret < 0) - goto out_pernet; - ret = nf_ct_l4proto_register_one(&nf_conntrack_l4proto_gre4); - if (ret < 0) - goto out_gre4; - - return 0; -out_gre4: - unregister_pernet_subsys(&proto_gre_net_ops); -out_pernet: - return ret; -} - -static void __exit nf_ct_proto_gre_fini(void) -{ - nf_ct_l4proto_unregister_one(&nf_conntrack_l4proto_gre4); - unregister_pernet_subsys(&proto_gre_net_ops); -} - -module_init(nf_ct_proto_gre_init); -module_exit(nf_ct_proto_gre_fini); - -MODULE_LICENSE("GPL"); diff --git a/net/netfilter/nf_conntrack_proto_icmp.c b/net/netfilter/nf_conntrack_proto_icmp.c index de64d8a5fdfd..7df477996b16 100644 --- a/net/netfilter/nf_conntrack_proto_icmp.c +++ b/net/netfilter/nf_conntrack_proto_icmp.c @@ -25,8 +25,8 @@ static const unsigned int nf_ct_icmp_timeout = 30*HZ; -static bool icmp_pkt_to_tuple(const struct sk_buff *skb, unsigned int dataoff, - struct net *net, struct nf_conntrack_tuple *tuple) +bool icmp_pkt_to_tuple(const struct sk_buff *skb, unsigned int dataoff, + struct net *net, struct nf_conntrack_tuple *tuple) { const struct icmphdr *hp; struct icmphdr _hdr; @@ -54,8 +54,8 @@ static const u_int8_t invmap[] = { [ICMP_ADDRESSREPLY] = ICMP_ADDRESS + 1 }; -static bool icmp_invert_tuple(struct nf_conntrack_tuple *tuple, - const struct nf_conntrack_tuple *orig) +bool nf_conntrack_invert_icmp_tuple(struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_tuple *orig) { if (orig->dst.u.icmp.type >= sizeof(invmap) || !invmap[orig->dst.u.icmp.type]) @@ -68,11 +68,10 @@ static bool icmp_invert_tuple(struct nf_conntrack_tuple *tuple, } /* Returns verdict for packet, or -1 for invalid. */ -static int icmp_packet(struct nf_conn *ct, - struct sk_buff *skb, - unsigned int dataoff, - enum ip_conntrack_info ctinfo, - const struct nf_hook_state *state) +int nf_conntrack_icmp_packet(struct nf_conn *ct, + struct sk_buff *skb, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { /* Do not immediately delete the connection after the first successful reply to avoid excessive conntrackd traffic @@ -110,7 +109,6 @@ icmp_error_message(struct nf_conn *tmpl, struct sk_buff *skb, const struct nf_hook_state *state) { struct nf_conntrack_tuple innertuple, origtuple; - const struct nf_conntrack_l4proto *innerproto; const struct nf_conntrack_tuple_hash *h; const struct nf_conntrack_zone *zone; enum ip_conntrack_info ctinfo; @@ -128,12 +126,9 @@ icmp_error_message(struct nf_conn *tmpl, struct sk_buff *skb, return -NF_ACCEPT; } - /* rcu_read_lock()ed by nf_hook_thresh */ - innerproto = __nf_ct_l4proto_find(origtuple.dst.protonum); - /* Ordinarily, we'd expect the inverted tupleproto, but it's been preserved inside the ICMP. */ - if (!nf_ct_invert_tuple(&innertuple, &origtuple, innerproto)) { + if (!nf_ct_invert_tuple(&innertuple, &origtuple)) { pr_debug("icmp_error_message: no match\n"); return -NF_ACCEPT; } @@ -303,56 +298,16 @@ icmp_timeout_nla_policy[CTA_TIMEOUT_ICMP_MAX+1] = { }; #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ -#ifdef CONFIG_SYSCTL -static struct ctl_table icmp_sysctl_table[] = { - { - .procname = "nf_conntrack_icmp_timeout", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { } -}; -#endif /* CONFIG_SYSCTL */ - -static int icmp_kmemdup_sysctl_table(struct nf_proto_net *pn, - struct nf_icmp_net *in) -{ -#ifdef CONFIG_SYSCTL - pn->ctl_table = kmemdup(icmp_sysctl_table, - sizeof(icmp_sysctl_table), - GFP_KERNEL); - if (!pn->ctl_table) - return -ENOMEM; - - pn->ctl_table[0].data = &in->timeout; -#endif - return 0; -} - -static int icmp_init_net(struct net *net) +void nf_conntrack_icmp_init_net(struct net *net) { struct nf_icmp_net *in = nf_icmp_pernet(net); - struct nf_proto_net *pn = &in->pn; in->timeout = nf_ct_icmp_timeout; - - return icmp_kmemdup_sysctl_table(pn, in); -} - -static struct nf_proto_net *icmp_get_net_proto(struct net *net) -{ - return &net->ct.nf_ct_proto.icmp.pn; } const struct nf_conntrack_l4proto nf_conntrack_l4proto_icmp = { .l4proto = IPPROTO_ICMP, - .pkt_to_tuple = icmp_pkt_to_tuple, - .invert_tuple = icmp_invert_tuple, - .packet = icmp_packet, - .destroy = NULL, - .me = NULL, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) .tuple_to_nlattr = icmp_tuple_to_nlattr, .nlattr_tuple_size = icmp_nlattr_tuple_size, @@ -368,6 +323,4 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_icmp = .nla_policy = icmp_timeout_nla_policy, }, #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .init_net = icmp_init_net, - .get_net_proto = icmp_get_net_proto, }; diff --git a/net/netfilter/nf_conntrack_proto_icmpv6.c b/net/netfilter/nf_conntrack_proto_icmpv6.c index a15eefb8e317..bec4a3211658 100644 --- a/net/netfilter/nf_conntrack_proto_icmpv6.c +++ b/net/netfilter/nf_conntrack_proto_icmpv6.c @@ -30,10 +30,10 @@ static const unsigned int nf_ct_icmpv6_timeout = 30*HZ; -static bool icmpv6_pkt_to_tuple(const struct sk_buff *skb, - unsigned int dataoff, - struct net *net, - struct nf_conntrack_tuple *tuple) +bool icmpv6_pkt_to_tuple(const struct sk_buff *skb, + unsigned int dataoff, + struct net *net, + struct nf_conntrack_tuple *tuple) { const struct icmp6hdr *hp; struct icmp6hdr _hdr; @@ -67,8 +67,8 @@ static const u_int8_t noct_valid_new[] = { [ICMPV6_MLD2_REPORT - 130] = 1 }; -static bool icmpv6_invert_tuple(struct nf_conntrack_tuple *tuple, - const struct nf_conntrack_tuple *orig) +bool nf_conntrack_invert_icmpv6_tuple(struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_tuple *orig) { int type = orig->dst.u.icmp.type - 128; if (type < 0 || type >= sizeof(invmap) || !invmap[type]) @@ -86,11 +86,10 @@ static unsigned int *icmpv6_get_timeouts(struct net *net) } /* Returns verdict for packet, or -1 for invalid. */ -static int icmpv6_packet(struct nf_conn *ct, - struct sk_buff *skb, - unsigned int dataoff, - enum ip_conntrack_info ctinfo, - const struct nf_hook_state *state) +int nf_conntrack_icmpv6_packet(struct nf_conn *ct, + struct sk_buff *skb, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { unsigned int *timeout = nf_ct_timeout_lookup(ct); static const u8 valid_new[] = { @@ -131,7 +130,6 @@ icmpv6_error_message(struct net *net, struct nf_conn *tmpl, { struct nf_conntrack_tuple intuple, origtuple; const struct nf_conntrack_tuple_hash *h; - const struct nf_conntrack_l4proto *inproto; enum ip_conntrack_info ctinfo; struct nf_conntrack_zone tmp; @@ -147,12 +145,9 @@ icmpv6_error_message(struct net *net, struct nf_conn *tmpl, return -NF_ACCEPT; } - /* rcu_read_lock()ed by nf_hook_thresh */ - inproto = __nf_ct_l4proto_find(origtuple.dst.protonum); - /* Ordinarily, we'd expect the inverted tupleproto, but it's been preserved inside the ICMP. */ - if (!nf_ct_invert_tuple(&intuple, &origtuple, inproto)) { + if (!nf_ct_invert_tuple(&intuple, &origtuple)) { pr_debug("icmpv6_error: Can't invert tuple\n"); return -NF_ACCEPT; } @@ -314,54 +309,16 @@ icmpv6_timeout_nla_policy[CTA_TIMEOUT_ICMPV6_MAX+1] = { }; #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ -#ifdef CONFIG_SYSCTL -static struct ctl_table icmpv6_sysctl_table[] = { - { - .procname = "nf_conntrack_icmpv6_timeout", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { } -}; -#endif /* CONFIG_SYSCTL */ - -static int icmpv6_kmemdup_sysctl_table(struct nf_proto_net *pn, - struct nf_icmp_net *in) -{ -#ifdef CONFIG_SYSCTL - pn->ctl_table = kmemdup(icmpv6_sysctl_table, - sizeof(icmpv6_sysctl_table), - GFP_KERNEL); - if (!pn->ctl_table) - return -ENOMEM; - - pn->ctl_table[0].data = &in->timeout; -#endif - return 0; -} - -static int icmpv6_init_net(struct net *net) +void nf_conntrack_icmpv6_init_net(struct net *net) { struct nf_icmp_net *in = nf_icmpv6_pernet(net); - struct nf_proto_net *pn = &in->pn; in->timeout = nf_ct_icmpv6_timeout; - - return icmpv6_kmemdup_sysctl_table(pn, in); -} - -static struct nf_proto_net *icmpv6_get_net_proto(struct net *net) -{ - return &net->ct.nf_ct_proto.icmpv6.pn; } const struct nf_conntrack_l4proto nf_conntrack_l4proto_icmpv6 = { .l4proto = IPPROTO_ICMPV6, - .pkt_to_tuple = icmpv6_pkt_to_tuple, - .invert_tuple = icmpv6_invert_tuple, - .packet = icmpv6_packet, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) .tuple_to_nlattr = icmpv6_tuple_to_nlattr, .nlattr_tuple_size = icmpv6_nlattr_tuple_size, @@ -377,6 +334,4 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_icmpv6 = .nla_policy = icmpv6_timeout_nla_policy, }, #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .init_net = icmpv6_init_net, - .get_net_proto = icmpv6_get_net_proto, }; diff --git a/net/netfilter/nf_conntrack_proto_sctp.c b/net/netfilter/nf_conntrack_proto_sctp.c index d53e3e78f605..a7818101ad80 100644 --- a/net/netfilter/nf_conntrack_proto_sctp.c +++ b/net/netfilter/nf_conntrack_proto_sctp.c @@ -357,11 +357,11 @@ out_invalid: } /* Returns verdict for packet, or -NF_ACCEPT for invalid. */ -static int sctp_packet(struct nf_conn *ct, - struct sk_buff *skb, - unsigned int dataoff, - enum ip_conntrack_info ctinfo, - const struct nf_hook_state *state) +int nf_conntrack_sctp_packet(struct nf_conn *ct, + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { enum sctp_conntrack new_state, old_state; enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); @@ -642,116 +642,18 @@ sctp_timeout_nla_policy[CTA_TIMEOUT_SCTP_MAX+1] = { }; #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - -#ifdef CONFIG_SYSCTL -static struct ctl_table sctp_sysctl_table[] = { - { - .procname = "nf_conntrack_sctp_timeout_closed", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_sctp_timeout_cookie_wait", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_sctp_timeout_cookie_echoed", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_sctp_timeout_established", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_sctp_timeout_shutdown_sent", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_sctp_timeout_shutdown_recd", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_sctp_timeout_shutdown_ack_sent", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_sctp_timeout_heartbeat_sent", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_sctp_timeout_heartbeat_acked", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { } -}; -#endif - -static int sctp_kmemdup_sysctl_table(struct nf_proto_net *pn, - struct nf_sctp_net *sn) -{ -#ifdef CONFIG_SYSCTL - if (pn->ctl_table) - return 0; - - pn->ctl_table = kmemdup(sctp_sysctl_table, - sizeof(sctp_sysctl_table), - GFP_KERNEL); - if (!pn->ctl_table) - return -ENOMEM; - - pn->ctl_table[0].data = &sn->timeouts[SCTP_CONNTRACK_CLOSED]; - pn->ctl_table[1].data = &sn->timeouts[SCTP_CONNTRACK_COOKIE_WAIT]; - pn->ctl_table[2].data = &sn->timeouts[SCTP_CONNTRACK_COOKIE_ECHOED]; - pn->ctl_table[3].data = &sn->timeouts[SCTP_CONNTRACK_ESTABLISHED]; - pn->ctl_table[4].data = &sn->timeouts[SCTP_CONNTRACK_SHUTDOWN_SENT]; - pn->ctl_table[5].data = &sn->timeouts[SCTP_CONNTRACK_SHUTDOWN_RECD]; - pn->ctl_table[6].data = &sn->timeouts[SCTP_CONNTRACK_SHUTDOWN_ACK_SENT]; - pn->ctl_table[7].data = &sn->timeouts[SCTP_CONNTRACK_HEARTBEAT_SENT]; - pn->ctl_table[8].data = &sn->timeouts[SCTP_CONNTRACK_HEARTBEAT_ACKED]; -#endif - return 0; -} - -static int sctp_init_net(struct net *net) +void nf_conntrack_sctp_init_net(struct net *net) { struct nf_sctp_net *sn = nf_sctp_pernet(net); - struct nf_proto_net *pn = &sn->pn; - - if (!pn->users) { - int i; - - for (i = 0; i < SCTP_CONNTRACK_MAX; i++) - sn->timeouts[i] = sctp_timeouts[i]; - - /* timeouts[0] is unused, init it so ->timeouts[0] contains - * 'new' timeout, like udp or icmp. - */ - sn->timeouts[0] = sctp_timeouts[SCTP_CONNTRACK_CLOSED]; - } + int i; - return sctp_kmemdup_sysctl_table(pn, sn); -} + for (i = 0; i < SCTP_CONNTRACK_MAX; i++) + sn->timeouts[i] = sctp_timeouts[i]; -static struct nf_proto_net *sctp_get_net_proto(struct net *net) -{ - return &net->ct.nf_ct_proto.sctp.pn; + /* timeouts[0] is unused, init it so ->timeouts[0] contains + * 'new' timeout, like udp or icmp. + */ + sn->timeouts[0] = sctp_timeouts[SCTP_CONNTRACK_CLOSED]; } const struct nf_conntrack_l4proto nf_conntrack_l4proto_sctp = { @@ -759,9 +661,7 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_sctp = { #ifdef CONFIG_NF_CONNTRACK_PROCFS .print_conntrack = sctp_print_conntrack, #endif - .packet = sctp_packet, .can_early_drop = sctp_can_early_drop, - .me = THIS_MODULE, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) .nlattr_size = SCTP_NLATTR_SIZE, .to_nlattr = sctp_to_nlattr, @@ -780,6 +680,4 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_sctp = { .nla_policy = sctp_timeout_nla_policy, }, #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .init_net = sctp_init_net, - .get_net_proto = sctp_get_net_proto, }; diff --git a/net/netfilter/nf_conntrack_proto_tcp.c b/net/netfilter/nf_conntrack_proto_tcp.c index 4dcbd51a8e97..a06875a466a4 100644 --- a/net/netfilter/nf_conntrack_proto_tcp.c +++ b/net/netfilter/nf_conntrack_proto_tcp.c @@ -828,12 +828,18 @@ static noinline bool tcp_new(struct nf_conn *ct, const struct sk_buff *skb, return true; } +static bool nf_conntrack_tcp_established(const struct nf_conn *ct) +{ + return ct->proto.tcp.state == TCP_CONNTRACK_ESTABLISHED && + test_bit(IPS_ASSURED_BIT, &ct->status); +} + /* Returns verdict for packet, or -1 for invalid. */ -static int tcp_packet(struct nf_conn *ct, - struct sk_buff *skb, - unsigned int dataoff, - enum ip_conntrack_info ctinfo, - const struct nf_hook_state *state) +int nf_conntrack_tcp_packet(struct nf_conn *ct, + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { struct net *net = nf_ct_net(ct); struct nf_tcp_net *tn = nf_tcp_pernet(net); @@ -1030,16 +1036,38 @@ static int tcp_packet(struct nf_conn *ct, new_state = TCP_CONNTRACK_ESTABLISHED; break; case TCP_CONNTRACK_CLOSE: - if (index == TCP_RST_SET - && (ct->proto.tcp.seen[!dir].flags & IP_CT_TCP_FLAG_MAXACK_SET) - && before(ntohl(th->seq), ct->proto.tcp.seen[!dir].td_maxack)) { - /* Invalid RST */ - spin_unlock_bh(&ct->lock); - nf_ct_l4proto_log_invalid(skb, ct, "invalid rst"); - return -NF_ACCEPT; + if (index != TCP_RST_SET) + break; + + if (ct->proto.tcp.seen[!dir].flags & IP_CT_TCP_FLAG_MAXACK_SET) { + u32 seq = ntohl(th->seq); + + if (before(seq, ct->proto.tcp.seen[!dir].td_maxack)) { + /* Invalid RST */ + spin_unlock_bh(&ct->lock); + nf_ct_l4proto_log_invalid(skb, ct, "invalid rst"); + return -NF_ACCEPT; + } + + if (!nf_conntrack_tcp_established(ct) || + seq == ct->proto.tcp.seen[!dir].td_maxack) + break; + + /* Check if rst is part of train, such as + * foo:80 > bar:4379: P, 235946583:235946602(19) ack 42 + * foo:80 > bar:4379: R, 235946602:235946602(0) ack 42 + */ + if (ct->proto.tcp.last_index == TCP_ACK_SET && + ct->proto.tcp.last_dir == dir && + seq == ct->proto.tcp.last_end) + break; + + /* ... RST sequence number doesn't match exactly, keep + * established state to allow a possible challenge ACK. + */ + new_state = old_state; } - if (index == TCP_RST_SET - && ((test_bit(IPS_SEEN_REPLY_BIT, &ct->status) + if (((test_bit(IPS_SEEN_REPLY_BIT, &ct->status) && ct->proto.tcp.last_index == TCP_SYN_SET) || (!test_bit(IPS_ASSURED_BIT, &ct->status) && ct->proto.tcp.last_index == TCP_ACK_SET)) @@ -1055,7 +1083,7 @@ static int tcp_packet(struct nf_conn *ct, * segments we ignored. */ goto in_window; } - /* Just fall through */ + break; default: /* Keep compilers happy. */ break; @@ -1090,6 +1118,8 @@ static int tcp_packet(struct nf_conn *ct, if (ct->proto.tcp.retrans >= tn->tcp_max_retrans && timeouts[new_state] > timeouts[TCP_CONNTRACK_RETRANS]) timeout = timeouts[TCP_CONNTRACK_RETRANS]; + else if (unlikely(index == TCP_RST_SET)) + timeout = timeouts[TCP_CONNTRACK_CLOSE]; else if ((ct->proto.tcp.seen[0].flags | ct->proto.tcp.seen[1].flags) & IP_CT_TCP_FLAG_DATA_UNACKNOWLEDGED && timeouts[new_state] > timeouts[TCP_CONNTRACK_UNACK]) @@ -1387,146 +1417,21 @@ static const struct nla_policy tcp_timeout_nla_policy[CTA_TIMEOUT_TCP_MAX+1] = { }; #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ -#ifdef CONFIG_SYSCTL -static struct ctl_table tcp_sysctl_table[] = { - { - .procname = "nf_conntrack_tcp_timeout_syn_sent", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_tcp_timeout_syn_recv", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_tcp_timeout_established", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_tcp_timeout_fin_wait", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_tcp_timeout_close_wait", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_tcp_timeout_last_ack", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_tcp_timeout_time_wait", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_tcp_timeout_close", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_tcp_timeout_max_retrans", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_tcp_timeout_unacknowledged", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_tcp_loose", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec, - }, - { - .procname = "nf_conntrack_tcp_be_liberal", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec, - }, - { - .procname = "nf_conntrack_tcp_max_retrans", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec, - }, - { } -}; -#endif /* CONFIG_SYSCTL */ - -static int tcp_kmemdup_sysctl_table(struct nf_proto_net *pn, - struct nf_tcp_net *tn) -{ -#ifdef CONFIG_SYSCTL - if (pn->ctl_table) - return 0; - - pn->ctl_table = kmemdup(tcp_sysctl_table, - sizeof(tcp_sysctl_table), - GFP_KERNEL); - if (!pn->ctl_table) - return -ENOMEM; - - pn->ctl_table[0].data = &tn->timeouts[TCP_CONNTRACK_SYN_SENT]; - pn->ctl_table[1].data = &tn->timeouts[TCP_CONNTRACK_SYN_RECV]; - pn->ctl_table[2].data = &tn->timeouts[TCP_CONNTRACK_ESTABLISHED]; - pn->ctl_table[3].data = &tn->timeouts[TCP_CONNTRACK_FIN_WAIT]; - pn->ctl_table[4].data = &tn->timeouts[TCP_CONNTRACK_CLOSE_WAIT]; - pn->ctl_table[5].data = &tn->timeouts[TCP_CONNTRACK_LAST_ACK]; - pn->ctl_table[6].data = &tn->timeouts[TCP_CONNTRACK_TIME_WAIT]; - pn->ctl_table[7].data = &tn->timeouts[TCP_CONNTRACK_CLOSE]; - pn->ctl_table[8].data = &tn->timeouts[TCP_CONNTRACK_RETRANS]; - pn->ctl_table[9].data = &tn->timeouts[TCP_CONNTRACK_UNACK]; - pn->ctl_table[10].data = &tn->tcp_loose; - pn->ctl_table[11].data = &tn->tcp_be_liberal; - pn->ctl_table[12].data = &tn->tcp_max_retrans; -#endif - return 0; -} - -static int tcp_init_net(struct net *net) +void nf_conntrack_tcp_init_net(struct net *net) { struct nf_tcp_net *tn = nf_tcp_pernet(net); - struct nf_proto_net *pn = &tn->pn; - - if (!pn->users) { - int i; - - for (i = 0; i < TCP_CONNTRACK_TIMEOUT_MAX; i++) - tn->timeouts[i] = tcp_timeouts[i]; + int i; - /* timeouts[0] is unused, make it same as SYN_SENT so - * ->timeouts[0] contains 'new' timeout, like udp or icmp. - */ - tn->timeouts[0] = tcp_timeouts[TCP_CONNTRACK_SYN_SENT]; - tn->tcp_loose = nf_ct_tcp_loose; - tn->tcp_be_liberal = nf_ct_tcp_be_liberal; - tn->tcp_max_retrans = nf_ct_tcp_max_retrans; - } + for (i = 0; i < TCP_CONNTRACK_TIMEOUT_MAX; i++) + tn->timeouts[i] = tcp_timeouts[i]; - return tcp_kmemdup_sysctl_table(pn, tn); -} - -static struct nf_proto_net *tcp_get_net_proto(struct net *net) -{ - return &net->ct.nf_ct_proto.tcp.pn; + /* timeouts[0] is unused, make it same as SYN_SENT so + * ->timeouts[0] contains 'new' timeout, like udp or icmp. + */ + tn->timeouts[0] = tcp_timeouts[TCP_CONNTRACK_SYN_SENT]; + tn->tcp_loose = nf_ct_tcp_loose; + tn->tcp_be_liberal = nf_ct_tcp_be_liberal; + tn->tcp_max_retrans = nf_ct_tcp_max_retrans; } const struct nf_conntrack_l4proto nf_conntrack_l4proto_tcp = @@ -1535,7 +1440,6 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_tcp = #ifdef CONFIG_NF_CONNTRACK_PROCFS .print_conntrack = tcp_print_conntrack, #endif - .packet = tcp_packet, .can_early_drop = tcp_can_early_drop, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) .to_nlattr = tcp_to_nlattr, @@ -1556,6 +1460,4 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_tcp = .nla_policy = tcp_timeout_nla_policy, }, #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .init_net = tcp_init_net, - .get_net_proto = tcp_get_net_proto, }; diff --git a/net/netfilter/nf_conntrack_proto_udp.c b/net/netfilter/nf_conntrack_proto_udp.c index b4f5d5e82031..951366dfbec3 100644 --- a/net/netfilter/nf_conntrack_proto_udp.c +++ b/net/netfilter/nf_conntrack_proto_udp.c @@ -85,11 +85,11 @@ static bool udp_error(struct sk_buff *skb, } /* Returns verdict for packet, and may modify conntracktype */ -static int udp_packet(struct nf_conn *ct, - struct sk_buff *skb, - unsigned int dataoff, - enum ip_conntrack_info ctinfo, - const struct nf_hook_state *state) +int nf_conntrack_udp_packet(struct nf_conn *ct, + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { unsigned int *timeouts; @@ -177,11 +177,11 @@ static bool udplite_error(struct sk_buff *skb, } /* Returns verdict for packet, and may modify conntracktype */ -static int udplite_packet(struct nf_conn *ct, - struct sk_buff *skb, - unsigned int dataoff, - enum ip_conntrack_info ctinfo, - const struct nf_hook_state *state) +int nf_conntrack_udplite_packet(struct nf_conn *ct, + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { unsigned int *timeouts; @@ -260,66 +260,19 @@ udp_timeout_nla_policy[CTA_TIMEOUT_UDP_MAX+1] = { }; #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ -#ifdef CONFIG_SYSCTL -static struct ctl_table udp_sysctl_table[] = { - { - .procname = "nf_conntrack_udp_timeout", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { - .procname = "nf_conntrack_udp_timeout_stream", - .maxlen = sizeof(unsigned int), - .mode = 0644, - .proc_handler = proc_dointvec_jiffies, - }, - { } -}; -#endif /* CONFIG_SYSCTL */ - -static int udp_kmemdup_sysctl_table(struct nf_proto_net *pn, - struct nf_udp_net *un) -{ -#ifdef CONFIG_SYSCTL - if (pn->ctl_table) - return 0; - pn->ctl_table = kmemdup(udp_sysctl_table, - sizeof(udp_sysctl_table), - GFP_KERNEL); - if (!pn->ctl_table) - return -ENOMEM; - pn->ctl_table[0].data = &un->timeouts[UDP_CT_UNREPLIED]; - pn->ctl_table[1].data = &un->timeouts[UDP_CT_REPLIED]; -#endif - return 0; -} - -static int udp_init_net(struct net *net) +void nf_conntrack_udp_init_net(struct net *net) { struct nf_udp_net *un = nf_udp_pernet(net); - struct nf_proto_net *pn = &un->pn; + int i; - if (!pn->users) { - int i; - - for (i = 0; i < UDP_CT_MAX; i++) - un->timeouts[i] = udp_timeouts[i]; - } - - return udp_kmemdup_sysctl_table(pn, un); -} - -static struct nf_proto_net *udp_get_net_proto(struct net *net) -{ - return &net->ct.nf_ct_proto.udp.pn; + for (i = 0; i < UDP_CT_MAX; i++) + un->timeouts[i] = udp_timeouts[i]; } const struct nf_conntrack_l4proto nf_conntrack_l4proto_udp = { .l4proto = IPPROTO_UDP, .allow_clash = true, - .packet = udp_packet, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, @@ -335,8 +288,6 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_udp = .nla_policy = udp_timeout_nla_policy, }, #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .init_net = udp_init_net, - .get_net_proto = udp_get_net_proto, }; #ifdef CONFIG_NF_CT_PROTO_UDPLITE @@ -344,7 +295,6 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_udplite = { .l4proto = IPPROTO_UDPLITE, .allow_clash = true, - .packet = udplite_packet, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, @@ -360,7 +310,5 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_udplite = .nla_policy = udp_timeout_nla_policy, }, #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .init_net = udp_init_net, - .get_net_proto = udp_get_net_proto, }; #endif diff --git a/net/netfilter/nf_conntrack_sip.c b/net/netfilter/nf_conntrack_sip.c index c8d2b6688a2a..f067c6b50857 100644 --- a/net/netfilter/nf_conntrack_sip.c +++ b/net/netfilter/nf_conntrack_sip.c @@ -21,6 +21,8 @@ #include <linux/tcp.h> #include <linux/netfilter.h> +#include <net/route.h> +#include <net/ip6_route.h> #include <net/netfilter/nf_conntrack.h> #include <net/netfilter/nf_conntrack_core.h> #include <net/netfilter/nf_conntrack_expect.h> @@ -54,6 +56,11 @@ module_param(sip_direct_media, int, 0600); MODULE_PARM_DESC(sip_direct_media, "Expect Media streams between signalling " "endpoints only (default 1)"); +static int sip_external_media __read_mostly = 0; +module_param(sip_external_media, int, 0600); +MODULE_PARM_DESC(sip_external_media, "Expect Media streams between external " + "endpoints (default 0)"); + const struct nf_nat_sip_hooks *nf_nat_sip_hooks; EXPORT_SYMBOL_GPL(nf_nat_sip_hooks); @@ -861,6 +868,41 @@ static int set_expected_rtp_rtcp(struct sk_buff *skb, unsigned int protoff, if (!nf_inet_addr_cmp(daddr, &ct->tuplehash[dir].tuple.src.u3)) return NF_ACCEPT; saddr = &ct->tuplehash[!dir].tuple.src.u3; + } else if (sip_external_media) { + struct net_device *dev = skb_dst(skb)->dev; + struct net *net = dev_net(dev); + struct rtable *rt; + struct flowi4 fl4 = {}; +#if IS_ENABLED(CONFIG_IPV6) + struct flowi6 fl6 = {}; +#endif + struct dst_entry *dst = NULL; + + switch (nf_ct_l3num(ct)) { + case NFPROTO_IPV4: + fl4.daddr = daddr->ip; + rt = ip_route_output_key(net, &fl4); + if (!IS_ERR(rt)) + dst = &rt->dst; + break; + +#if IS_ENABLED(CONFIG_IPV6) + case NFPROTO_IPV6: + fl6.daddr = daddr->in6; + dst = ip6_route_output(net, NULL, &fl6); + if (dst->error) { + dst_release(dst); + dst = NULL; + } + break; +#endif + } + + /* Don't predict any conntracks when media endpoint is reachable + * through the same interface as the signalling peer. + */ + if (dst && dst->dev == dev) + return NF_ACCEPT; } /* We need to check whether the registration exists before attempting diff --git a/net/netfilter/nf_conntrack_standalone.c b/net/netfilter/nf_conntrack_standalone.c index b6177fd73304..c2ae14c720b4 100644 --- a/net/netfilter/nf_conntrack_standalone.c +++ b/net/netfilter/nf_conntrack_standalone.c @@ -24,6 +24,10 @@ #include <net/netfilter/nf_conntrack_timestamp.h> #include <linux/rculist_nulls.h> +static bool enable_hooks __read_mostly; +MODULE_PARM_DESC(enable_hooks, "Always enable conntrack hooks"); +module_param(enable_hooks, bool, 0000); + unsigned int nf_conntrack_net_id __read_mostly; #ifdef CONFIG_NF_CONNTRACK_PROCFS @@ -310,8 +314,7 @@ static int ct_seq_show(struct seq_file *s, void *v) if (!net_eq(nf_ct_net(ct), net)) goto release; - l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); - WARN_ON(!l4proto); + l4proto = nf_ct_l4proto_find(nf_ct_protonum(ct)); ret = -ENOSPC; seq_printf(s, "%-8s %u %-8s %u ", @@ -547,8 +550,55 @@ enum nf_ct_sysctl_index { #ifdef CONFIG_NF_CONNTRACK_TIMESTAMP NF_SYSCTL_CT_TIMESTAMP, #endif + NF_SYSCTL_CT_PROTO_TIMEOUT_GENERIC, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_SYN_SENT, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_SYN_RECV, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_ESTABLISHED, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_FIN_WAIT, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_CLOSE_WAIT, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_LAST_ACK, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_TIME_WAIT, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_CLOSE, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_RETRANS, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_UNACK, + NF_SYSCTL_CT_PROTO_TCP_LOOSE, + NF_SYSCTL_CT_PROTO_TCP_LIBERAL, + NF_SYSCTL_CT_PROTO_TCP_MAX_RETRANS, + NF_SYSCTL_CT_PROTO_TIMEOUT_UDP, + NF_SYSCTL_CT_PROTO_TIMEOUT_UDP_STREAM, + NF_SYSCTL_CT_PROTO_TIMEOUT_ICMP, + NF_SYSCTL_CT_PROTO_TIMEOUT_ICMPV6, +#ifdef CONFIG_NF_CT_PROTO_SCTP + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_CLOSED, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_COOKIE_WAIT, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_COOKIE_ECHOED, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_ESTABLISHED, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_SENT, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_RECD, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_ACK_SENT, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_HEARTBEAT_SENT, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_HEARTBEAT_ACKED, +#endif +#ifdef CONFIG_NF_CT_PROTO_DCCP + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_REQUEST, + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_RESPOND, + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_PARTOPEN, + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_OPEN, + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_CLOSEREQ, + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_CLOSING, + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_TIMEWAIT, + NF_SYSCTL_CT_PROTO_DCCP_LOOSE, +#endif +#ifdef CONFIG_NF_CT_PROTO_GRE + NF_SYSCTL_CT_PROTO_TIMEOUT_GRE, + NF_SYSCTL_CT_PROTO_TIMEOUT_GRE_STREAM, +#endif + + __NF_SYSCTL_CT_LAST_SYSCTL, }; +#define NF_SYSCTL_CT_LAST_SYSCTL (__NF_SYSCTL_CT_LAST_SYSCTL + 1) + static struct ctl_table nf_ct_sysctl_table[] = { [NF_SYSCTL_CT_MAX] = { .procname = "nf_conntrack_max", @@ -626,7 +676,235 @@ static struct ctl_table nf_ct_sysctl_table[] = { .proc_handler = proc_dointvec, }, #endif - { } + [NF_SYSCTL_CT_PROTO_TIMEOUT_GENERIC] = { + .procname = "nf_conntrack_generic_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_SYN_SENT] = { + .procname = "nf_conntrack_tcp_timeout_syn_sent", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_SYN_RECV] = { + .procname = "nf_conntrack_tcp_timeout_syn_recv", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_ESTABLISHED] = { + .procname = "nf_conntrack_tcp_timeout_established", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_FIN_WAIT] = { + .procname = "nf_conntrack_tcp_timeout_fin_wait", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_CLOSE_WAIT] = { + .procname = "nf_conntrack_tcp_timeout_close_wait", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_LAST_ACK] = { + .procname = "nf_conntrack_tcp_timeout_last_ack", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_TIME_WAIT] = { + .procname = "nf_conntrack_tcp_timeout_time_wait", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_CLOSE] = { + .procname = "nf_conntrack_tcp_timeout_close", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_RETRANS] = { + .procname = "nf_conntrack_tcp_timeout_max_retrans", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_UNACK] = { + .procname = "nf_conntrack_tcp_timeout_unacknowledged", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TCP_LOOSE] = { + .procname = "nf_conntrack_tcp_loose", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + [NF_SYSCTL_CT_PROTO_TCP_LIBERAL] = { + .procname = "nf_conntrack_tcp_be_liberal", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + [NF_SYSCTL_CT_PROTO_TCP_MAX_RETRANS] = { + .procname = "nf_conntrack_tcp_max_retrans", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_UDP] = { + .procname = "nf_conntrack_udp_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_UDP_STREAM] = { + .procname = "nf_conntrack_udp_timeout_stream", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_ICMP] = { + .procname = "nf_conntrack_icmp_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_ICMPV6] = { + .procname = "nf_conntrack_icmpv6_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, +#ifdef CONFIG_NF_CT_PROTO_SCTP + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_CLOSED] = { + .procname = "nf_conntrack_sctp_timeout_closed", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_COOKIE_WAIT] = { + .procname = "nf_conntrack_sctp_timeout_cookie_wait", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_COOKIE_ECHOED] = { + .procname = "nf_conntrack_sctp_timeout_cookie_echoed", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_ESTABLISHED] = { + .procname = "nf_conntrack_sctp_timeout_established", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_SENT] = { + .procname = "nf_conntrack_sctp_timeout_shutdown_sent", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_RECD] = { + .procname = "nf_conntrack_sctp_timeout_shutdown_recd", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_ACK_SENT] = { + .procname = "nf_conntrack_sctp_timeout_shutdown_ack_sent", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_HEARTBEAT_SENT] = { + .procname = "nf_conntrack_sctp_timeout_heartbeat_sent", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_HEARTBEAT_ACKED] = { + .procname = "nf_conntrack_sctp_timeout_heartbeat_acked", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, +#endif +#ifdef CONFIG_NF_CT_PROTO_DCCP + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_REQUEST] = { + .procname = "nf_conntrack_dccp_timeout_request", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_RESPOND] = { + .procname = "nf_conntrack_dccp_timeout_respond", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_PARTOPEN] = { + .procname = "nf_conntrack_dccp_timeout_partopen", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_OPEN] = { + .procname = "nf_conntrack_dccp_timeout_open", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_CLOSEREQ] = { + .procname = "nf_conntrack_dccp_timeout_closereq", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_CLOSING] = { + .procname = "nf_conntrack_dccp_timeout_closing", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_TIMEWAIT] = { + .procname = "nf_conntrack_dccp_timeout_timewait", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_DCCP_LOOSE] = { + .procname = "nf_conntrack_dccp_loose", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#endif +#ifdef CONFIG_NF_CT_PROTO_GRE + [NF_SYSCTL_CT_PROTO_TIMEOUT_GRE] = { + .procname = "nf_conntrack_gre_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_GRE_STREAM] = { + .procname = "nf_conntrack_gre_timeout_stream", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, +#endif + {} }; static struct ctl_table nf_ct_netfilter_table[] = { @@ -640,14 +918,103 @@ static struct ctl_table nf_ct_netfilter_table[] = { { } }; +static void nf_conntrack_standalone_init_tcp_sysctl(struct net *net, + struct ctl_table *table) +{ + struct nf_tcp_net *tn = nf_tcp_pernet(net); + +#define XASSIGN(XNAME, tn) \ + table[NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_ ## XNAME].data = \ + &(tn)->timeouts[TCP_CONNTRACK_ ## XNAME] + + XASSIGN(SYN_SENT, tn); + XASSIGN(SYN_RECV, tn); + XASSIGN(ESTABLISHED, tn); + XASSIGN(FIN_WAIT, tn); + XASSIGN(CLOSE_WAIT, tn); + XASSIGN(LAST_ACK, tn); + XASSIGN(TIME_WAIT, tn); + XASSIGN(CLOSE, tn); + XASSIGN(RETRANS, tn); + XASSIGN(UNACK, tn); +#undef XASSIGN +#define XASSIGN(XNAME, rval) \ + table[NF_SYSCTL_CT_PROTO_TCP_ ## XNAME].data = (rval) + + XASSIGN(LOOSE, &tn->tcp_loose); + XASSIGN(LIBERAL, &tn->tcp_be_liberal); + XASSIGN(MAX_RETRANS, &tn->tcp_max_retrans); +#undef XASSIGN +} + +static void nf_conntrack_standalone_init_sctp_sysctl(struct net *net, + struct ctl_table *table) +{ +#ifdef CONFIG_NF_CT_PROTO_SCTP + struct nf_sctp_net *sn = nf_sctp_pernet(net); + +#define XASSIGN(XNAME, sn) \ + table[NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_ ## XNAME].data = \ + &(sn)->timeouts[SCTP_CONNTRACK_ ## XNAME] + + XASSIGN(CLOSED, sn); + XASSIGN(COOKIE_WAIT, sn); + XASSIGN(COOKIE_ECHOED, sn); + XASSIGN(ESTABLISHED, sn); + XASSIGN(SHUTDOWN_SENT, sn); + XASSIGN(SHUTDOWN_RECD, sn); + XASSIGN(SHUTDOWN_ACK_SENT, sn); + XASSIGN(HEARTBEAT_SENT, sn); + XASSIGN(HEARTBEAT_ACKED, sn); +#undef XASSIGN +#endif +} + +static void nf_conntrack_standalone_init_dccp_sysctl(struct net *net, + struct ctl_table *table) +{ +#ifdef CONFIG_NF_CT_PROTO_DCCP + struct nf_dccp_net *dn = nf_dccp_pernet(net); + +#define XASSIGN(XNAME, dn) \ + table[NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_ ## XNAME].data = \ + &(dn)->dccp_timeout[CT_DCCP_ ## XNAME] + + XASSIGN(REQUEST, dn); + XASSIGN(RESPOND, dn); + XASSIGN(PARTOPEN, dn); + XASSIGN(OPEN, dn); + XASSIGN(CLOSEREQ, dn); + XASSIGN(CLOSING, dn); + XASSIGN(TIMEWAIT, dn); +#undef XASSIGN + + table[NF_SYSCTL_CT_PROTO_DCCP_LOOSE].data = &dn->dccp_loose; +#endif +} + +static void nf_conntrack_standalone_init_gre_sysctl(struct net *net, + struct ctl_table *table) +{ +#ifdef CONFIG_NF_CT_PROTO_GRE + struct nf_gre_net *gn = nf_gre_pernet(net); + + table[NF_SYSCTL_CT_PROTO_TIMEOUT_GRE].data = &gn->timeouts[GRE_CT_UNREPLIED]; + table[NF_SYSCTL_CT_PROTO_TIMEOUT_GRE_STREAM].data = &gn->timeouts[GRE_CT_REPLIED]; +#endif +} + static int nf_conntrack_standalone_init_sysctl(struct net *net) { + struct nf_udp_net *un = nf_udp_pernet(net); struct ctl_table *table; + BUILD_BUG_ON(ARRAY_SIZE(nf_ct_sysctl_table) != NF_SYSCTL_CT_LAST_SYSCTL); + table = kmemdup(nf_ct_sysctl_table, sizeof(nf_ct_sysctl_table), GFP_KERNEL); if (!table) - goto out_kmemdup; + return -ENOMEM; table[NF_SYSCTL_CT_COUNT].data = &net->ct.count; table[NF_SYSCTL_CT_CHECKSUM].data = &net->ct.sysctl_checksum; @@ -655,6 +1022,16 @@ static int nf_conntrack_standalone_init_sysctl(struct net *net) #ifdef CONFIG_NF_CONNTRACK_EVENTS table[NF_SYSCTL_CT_EVENTS].data = &net->ct.sysctl_events; #endif + table[NF_SYSCTL_CT_PROTO_TIMEOUT_GENERIC].data = &nf_generic_pernet(net)->timeout; + table[NF_SYSCTL_CT_PROTO_TIMEOUT_ICMP].data = &nf_icmp_pernet(net)->timeout; + table[NF_SYSCTL_CT_PROTO_TIMEOUT_ICMPV6].data = &nf_icmpv6_pernet(net)->timeout; + table[NF_SYSCTL_CT_PROTO_TIMEOUT_UDP].data = &un->timeouts[UDP_CT_UNREPLIED]; + table[NF_SYSCTL_CT_PROTO_TIMEOUT_UDP_STREAM].data = &un->timeouts[UDP_CT_REPLIED]; + + nf_conntrack_standalone_init_tcp_sysctl(net, table); + nf_conntrack_standalone_init_sctp_sysctl(net, table); + nf_conntrack_standalone_init_dccp_sysctl(net, table); + nf_conntrack_standalone_init_gre_sysctl(net, table); /* Don't export sysctls to unprivileged users */ if (net->user_ns != &init_user_ns) { @@ -680,7 +1057,6 @@ static int nf_conntrack_standalone_init_sysctl(struct net *net) out_unregister_netfilter: kfree(table); -out_kmemdup: return -ENOMEM; } @@ -703,31 +1079,47 @@ static void nf_conntrack_standalone_fini_sysctl(struct net *net) } #endif /* CONFIG_SYSCTL */ +static void nf_conntrack_fini_net(struct net *net) +{ + if (enable_hooks) + nf_ct_netns_put(net, NFPROTO_INET); + + nf_conntrack_standalone_fini_proc(net); + nf_conntrack_standalone_fini_sysctl(net); +} + static int nf_conntrack_pernet_init(struct net *net) { int ret; - ret = nf_conntrack_init_net(net); + net->ct.sysctl_checksum = 1; + + ret = nf_conntrack_standalone_init_sysctl(net); if (ret < 0) - goto out_init; + return ret; ret = nf_conntrack_standalone_init_proc(net); if (ret < 0) goto out_proc; - net->ct.sysctl_checksum = 1; - net->ct.sysctl_log_invalid = 0; - ret = nf_conntrack_standalone_init_sysctl(net); + ret = nf_conntrack_init_net(net); if (ret < 0) - goto out_sysctl; + goto out_init_net; + + if (enable_hooks) { + ret = nf_ct_netns_get(net, NFPROTO_INET); + if (ret < 0) + goto out_hooks; + } return 0; -out_sysctl: +out_hooks: + nf_conntrack_cleanup_net(net); +out_init_net: nf_conntrack_standalone_fini_proc(net); out_proc: - nf_conntrack_cleanup_net(net); -out_init: + nf_conntrack_standalone_fini_sysctl(net); return ret; } @@ -735,10 +1127,9 @@ static void nf_conntrack_pernet_exit(struct list_head *net_exit_list) { struct net *net; - list_for_each_entry(net, net_exit_list, exit_list) { - nf_conntrack_standalone_fini_sysctl(net); - nf_conntrack_standalone_fini_proc(net); - } + list_for_each_entry(net, net_exit_list, exit_list) + nf_conntrack_fini_net(net); + nf_conntrack_cleanup_net_list(net_exit_list); } diff --git a/net/netfilter/nf_flow_table_core.c b/net/netfilter/nf_flow_table_core.c index c0c72ae9df42..7aabfd4b1e50 100644 --- a/net/netfilter/nf_flow_table_core.c +++ b/net/netfilter/nf_flow_table_core.c @@ -121,7 +121,7 @@ static void flow_offload_fixup_ct_state(struct nf_conn *ct) if (l4num == IPPROTO_TCP) flow_offload_fixup_tcp(&ct->proto.tcp); - l4proto = __nf_ct_l4proto_find(l4num); + l4proto = nf_ct_l4proto_find(l4num); if (!l4proto) return; diff --git a/net/netfilter/nf_nat_core.c b/net/netfilter/nf_nat_core.c index d159e9e7835b..af7dc6537758 100644 --- a/net/netfilter/nf_nat_core.c +++ b/net/netfilter/nf_nat_core.c @@ -22,8 +22,6 @@ #include <net/netfilter/nf_conntrack.h> #include <net/netfilter/nf_conntrack_core.h> #include <net/netfilter/nf_nat.h> -#include <net/netfilter/nf_nat_l3proto.h> -#include <net/netfilter/nf_nat_core.h> #include <net/netfilter/nf_nat_helper.h> #include <net/netfilter/nf_conntrack_helper.h> #include <net/netfilter/nf_conntrack_seqadj.h> @@ -35,8 +33,6 @@ static spinlock_t nf_nat_locks[CONNTRACK_LOCKS]; static DEFINE_MUTEX(nf_nat_proto_mutex); -static const struct nf_nat_l3proto __rcu *nf_nat_l3protos[NFPROTO_NUMPROTO] - __read_mostly; static unsigned int nat_net_id __read_mostly; static struct hlist_head *nf_nat_bysource __read_mostly; @@ -58,16 +54,75 @@ struct nat_net { struct nf_nat_hooks_net nat_proto_net[NFPROTO_NUMPROTO]; }; -inline const struct nf_nat_l3proto * -__nf_nat_l3proto_find(u8 family) +#ifdef CONFIG_XFRM +static void nf_nat_ipv4_decode_session(struct sk_buff *skb, + const struct nf_conn *ct, + enum ip_conntrack_dir dir, + unsigned long statusbit, + struct flowi *fl) { - return rcu_dereference(nf_nat_l3protos[family]); + const struct nf_conntrack_tuple *t = &ct->tuplehash[dir].tuple; + struct flowi4 *fl4 = &fl->u.ip4; + + if (ct->status & statusbit) { + fl4->daddr = t->dst.u3.ip; + if (t->dst.protonum == IPPROTO_TCP || + t->dst.protonum == IPPROTO_UDP || + t->dst.protonum == IPPROTO_UDPLITE || + t->dst.protonum == IPPROTO_DCCP || + t->dst.protonum == IPPROTO_SCTP) + fl4->fl4_dport = t->dst.u.all; + } + + statusbit ^= IPS_NAT_MASK; + + if (ct->status & statusbit) { + fl4->saddr = t->src.u3.ip; + if (t->dst.protonum == IPPROTO_TCP || + t->dst.protonum == IPPROTO_UDP || + t->dst.protonum == IPPROTO_UDPLITE || + t->dst.protonum == IPPROTO_DCCP || + t->dst.protonum == IPPROTO_SCTP) + fl4->fl4_sport = t->src.u.all; + } +} + +static void nf_nat_ipv6_decode_session(struct sk_buff *skb, + const struct nf_conn *ct, + enum ip_conntrack_dir dir, + unsigned long statusbit, + struct flowi *fl) +{ +#if IS_ENABLED(CONFIG_IPV6) + const struct nf_conntrack_tuple *t = &ct->tuplehash[dir].tuple; + struct flowi6 *fl6 = &fl->u.ip6; + + if (ct->status & statusbit) { + fl6->daddr = t->dst.u3.in6; + if (t->dst.protonum == IPPROTO_TCP || + t->dst.protonum == IPPROTO_UDP || + t->dst.protonum == IPPROTO_UDPLITE || + t->dst.protonum == IPPROTO_DCCP || + t->dst.protonum == IPPROTO_SCTP) + fl6->fl6_dport = t->dst.u.all; + } + + statusbit ^= IPS_NAT_MASK; + + if (ct->status & statusbit) { + fl6->saddr = t->src.u3.in6; + if (t->dst.protonum == IPPROTO_TCP || + t->dst.protonum == IPPROTO_UDP || + t->dst.protonum == IPPROTO_UDPLITE || + t->dst.protonum == IPPROTO_DCCP || + t->dst.protonum == IPPROTO_SCTP) + fl6->fl6_sport = t->src.u.all; + } +#endif } -#ifdef CONFIG_XFRM static void __nf_nat_decode_session(struct sk_buff *skb, struct flowi *fl) { - const struct nf_nat_l3proto *l3proto; const struct nf_conn *ct; enum ip_conntrack_info ctinfo; enum ip_conntrack_dir dir; @@ -79,17 +134,20 @@ static void __nf_nat_decode_session(struct sk_buff *skb, struct flowi *fl) return; family = nf_ct_l3num(ct); - l3proto = __nf_nat_l3proto_find(family); - if (l3proto == NULL) - return; - dir = CTINFO2DIR(ctinfo); if (dir == IP_CT_DIR_ORIGINAL) statusbit = IPS_DST_NAT; else statusbit = IPS_SRC_NAT; - l3proto->decode_session(skb, ct, dir, statusbit, fl); + switch (family) { + case NFPROTO_IPV4: + nf_nat_ipv4_decode_session(skb, ct, dir, statusbit, fl); + return; + case NFPROTO_IPV6: + nf_nat_ipv6_decode_session(skb, ct, dir, statusbit, fl); + return; + } } int nf_xfrm_me_harder(struct net *net, struct sk_buff *skb, unsigned int family) @@ -146,7 +204,7 @@ hash_by_src(const struct net *n, const struct nf_conntrack_tuple *tuple) } /* Is this tuple already taken? (not by us) */ -int +static int nf_nat_used_tuple(const struct nf_conntrack_tuple *tuple, const struct nf_conn *ignored_conntrack) { @@ -158,10 +216,9 @@ nf_nat_used_tuple(const struct nf_conntrack_tuple *tuple, */ struct nf_conntrack_tuple reply; - nf_ct_invert_tuplepr(&reply, tuple); + nf_ct_invert_tuple(&reply, tuple); return nf_conntrack_tuple_taken(&reply, ignored_conntrack); } -EXPORT_SYMBOL(nf_nat_used_tuple); static bool nf_nat_inet_in_range(const struct nf_conntrack_tuple *t, const struct nf_nat_range2 *range) @@ -183,7 +240,7 @@ static bool l4proto_in_range(const struct nf_conntrack_tuple *tuple, __be16 port; switch (tuple->dst.protonum) { - case IPPROTO_ICMP: /* fallthrough */ + case IPPROTO_ICMP: case IPPROTO_ICMPV6: return ntohs(tuple->src.u.icmp.id) >= ntohs(min->icmp.id) && ntohs(tuple->src.u.icmp.id) <= ntohs(max->icmp.id); @@ -253,7 +310,7 @@ find_appropriate_src(struct net *net, net_eq(net, nf_ct_net(ct)) && nf_ct_zone_equal(ct, zone, IP_CT_DIR_ORIGINAL)) { /* Copy source part from reply tuple. */ - nf_ct_invert_tuplepr(result, + nf_ct_invert_tuple(result, &ct->tuplehash[IP_CT_DIR_REPLY].tuple); result->dst = tuple->dst; @@ -560,8 +617,8 @@ nf_nat_setup_info(struct nf_conn *ct, * manipulations (future optimization: if num_manips == 0, * orig_tp = ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple) */ - nf_ct_invert_tuplepr(&curr_tuple, - &ct->tuplehash[IP_CT_DIR_REPLY].tuple); + nf_ct_invert_tuple(&curr_tuple, + &ct->tuplehash[IP_CT_DIR_REPLY].tuple); get_unique_tuple(&new_tuple, &curr_tuple, range, ct, maniptype); @@ -569,7 +626,7 @@ nf_nat_setup_info(struct nf_conn *ct, struct nf_conntrack_tuple reply; /* Alter conntrack table so will recognize replies. */ - nf_ct_invert_tuplepr(&reply, &new_tuple); + nf_ct_invert_tuple(&reply, &new_tuple); nf_conntrack_alter_reply(ct, &reply); /* Non-atomic: we own this at the moment. */ @@ -632,23 +689,6 @@ nf_nat_alloc_null_binding(struct nf_conn *ct, unsigned int hooknum) } EXPORT_SYMBOL_GPL(nf_nat_alloc_null_binding); -static unsigned int nf_nat_manip_pkt(struct sk_buff *skb, struct nf_conn *ct, - enum nf_nat_manip_type mtype, - enum ip_conntrack_dir dir) -{ - const struct nf_nat_l3proto *l3proto; - struct nf_conntrack_tuple target; - - /* We are aiming to look like inverse of other direction. */ - nf_ct_invert_tuplepr(&target, &ct->tuplehash[!dir].tuple); - - l3proto = __nf_nat_l3proto_find(target.src.l3num); - if (!l3proto->manip_pkt(skb, 0, &target, mtype)) - return NF_DROP; - - return NF_ACCEPT; -} - /* Do packet manipulations according to nf_nat_setup_info. */ unsigned int nf_nat_packet(struct nf_conn *ct, enum ip_conntrack_info ctinfo, @@ -799,33 +839,6 @@ static int nf_nat_proto_clean(struct nf_conn *ct, void *data) return 0; } -static void nf_nat_l3proto_clean(u8 l3proto) -{ - struct nf_nat_proto_clean clean = { - .l3proto = l3proto, - }; - - nf_ct_iterate_destroy(nf_nat_proto_remove, &clean); -} - -int nf_nat_l3proto_register(const struct nf_nat_l3proto *l3proto) -{ - RCU_INIT_POINTER(nf_nat_l3protos[l3proto->l3proto], l3proto); - return 0; -} -EXPORT_SYMBOL_GPL(nf_nat_l3proto_register); - -void nf_nat_l3proto_unregister(const struct nf_nat_l3proto *l3proto) -{ - mutex_lock(&nf_nat_proto_mutex); - RCU_INIT_POINTER(nf_nat_l3protos[l3proto->l3proto], NULL); - mutex_unlock(&nf_nat_proto_mutex); - synchronize_rcu(); - - nf_nat_l3proto_clean(l3proto->l3proto); -} -EXPORT_SYMBOL_GPL(nf_nat_l3proto_unregister); - /* No one using conntrack by the time this called. */ static void nf_nat_cleanup_conntrack(struct nf_conn *ct) { @@ -888,10 +901,43 @@ static const struct nla_policy nat_nla_policy[CTA_NAT_MAX+1] = { [CTA_NAT_PROTO] = { .type = NLA_NESTED }, }; +static int nf_nat_ipv4_nlattr_to_range(struct nlattr *tb[], + struct nf_nat_range2 *range) +{ + if (tb[CTA_NAT_V4_MINIP]) { + range->min_addr.ip = nla_get_be32(tb[CTA_NAT_V4_MINIP]); + range->flags |= NF_NAT_RANGE_MAP_IPS; + } + + if (tb[CTA_NAT_V4_MAXIP]) + range->max_addr.ip = nla_get_be32(tb[CTA_NAT_V4_MAXIP]); + else + range->max_addr.ip = range->min_addr.ip; + + return 0; +} + +static int nf_nat_ipv6_nlattr_to_range(struct nlattr *tb[], + struct nf_nat_range2 *range) +{ + if (tb[CTA_NAT_V6_MINIP]) { + nla_memcpy(&range->min_addr.ip6, tb[CTA_NAT_V6_MINIP], + sizeof(struct in6_addr)); + range->flags |= NF_NAT_RANGE_MAP_IPS; + } + + if (tb[CTA_NAT_V6_MAXIP]) + nla_memcpy(&range->max_addr.ip6, tb[CTA_NAT_V6_MAXIP], + sizeof(struct in6_addr)); + else + range->max_addr = range->min_addr; + + return 0; +} + static int nfnetlink_parse_nat(const struct nlattr *nat, - const struct nf_conn *ct, struct nf_nat_range2 *range, - const struct nf_nat_l3proto *l3proto) + const struct nf_conn *ct, struct nf_nat_range2 *range) { struct nlattr *tb[CTA_NAT_MAX+1]; int err; @@ -902,8 +948,19 @@ nfnetlink_parse_nat(const struct nlattr *nat, if (err < 0) return err; - err = l3proto->nlattr_to_range(tb, range); - if (err < 0) + switch (nf_ct_l3num(ct)) { + case NFPROTO_IPV4: + err = nf_nat_ipv4_nlattr_to_range(tb, range); + break; + case NFPROTO_IPV6: + err = nf_nat_ipv6_nlattr_to_range(tb, range); + break; + default: + err = -EPROTONOSUPPORT; + break; + } + + if (err) return err; if (!tb[CTA_NAT_PROTO]) @@ -919,7 +976,6 @@ nfnetlink_parse_nat_setup(struct nf_conn *ct, const struct nlattr *attr) { struct nf_nat_range2 range; - const struct nf_nat_l3proto *l3proto; int err; /* Should not happen, restricted to creating new conntracks @@ -928,18 +984,11 @@ nfnetlink_parse_nat_setup(struct nf_conn *ct, if (WARN_ON_ONCE(nf_nat_initialized(ct, manip))) return -EEXIST; - /* Make sure that L3 NAT is there by when we call nf_nat_setup_info to - * attach the null binding, otherwise this may oops. - */ - l3proto = __nf_nat_l3proto_find(nf_ct_l3num(ct)); - if (l3proto == NULL) - return -EAGAIN; - /* No NAT information has been passed, allocate the null-binding */ if (attr == NULL) return __nf_nat_alloc_null_binding(ct, manip) == NF_DROP ? -ENOMEM : 0; - err = nfnetlink_parse_nat(attr, ct, &range, l3proto); + err = nfnetlink_parse_nat(attr, ct, &range); if (err < 0) return err; @@ -1036,7 +1085,6 @@ int nf_nat_register_fn(struct net *net, const struct nf_hook_ops *ops, mutex_unlock(&nf_nat_proto_mutex); return ret; } -EXPORT_SYMBOL_GPL(nf_nat_register_fn); void nf_nat_unregister_fn(struct net *net, const struct nf_hook_ops *ops, unsigned int ops_count) @@ -1085,7 +1133,6 @@ void nf_nat_unregister_fn(struct net *net, const struct nf_hook_ops *ops, unlock: mutex_unlock(&nf_nat_proto_mutex); } -EXPORT_SYMBOL_GPL(nf_nat_unregister_fn); static struct pernet_operations nat_net_ops = { .id = &nat_net_id, diff --git a/net/netfilter/nf_nat_helper.c b/net/netfilter/nf_nat_helper.c index 38793b95d9bc..ccc06f7539d7 100644 --- a/net/netfilter/nf_nat_helper.c +++ b/net/netfilter/nf_nat_helper.c @@ -22,9 +22,6 @@ #include <net/netfilter/nf_conntrack_expect.h> #include <net/netfilter/nf_conntrack_seqadj.h> #include <net/netfilter/nf_nat.h> -#include <net/netfilter/nf_nat_l3proto.h> -#include <net/netfilter/nf_nat_l4proto.h> -#include <net/netfilter/nf_nat_core.h> #include <net/netfilter/nf_nat_helper.h> /* Frobs data inside this packet, which is linear. */ @@ -98,7 +95,6 @@ bool __nf_nat_mangle_tcp_packet(struct sk_buff *skb, const char *rep_buffer, unsigned int rep_len, bool adjust) { - const struct nf_nat_l3proto *l3proto; struct tcphdr *tcph; int oldlen, datalen; @@ -118,9 +114,8 @@ bool __nf_nat_mangle_tcp_packet(struct sk_buff *skb, datalen = skb->len - protoff; - l3proto = __nf_nat_l3proto_find(nf_ct_l3num(ct)); - l3proto->csum_recalc(skb, IPPROTO_TCP, tcph, &tcph->check, - datalen, oldlen); + nf_nat_csum_recalc(skb, nf_ct_l3num(ct), IPPROTO_TCP, + tcph, &tcph->check, datalen, oldlen); if (adjust && rep_len != match_len) nf_ct_seqadj_set(ct, ctinfo, tcph->seq, @@ -150,7 +145,6 @@ nf_nat_mangle_udp_packet(struct sk_buff *skb, const char *rep_buffer, unsigned int rep_len) { - const struct nf_nat_l3proto *l3proto; struct udphdr *udph; int datalen, oldlen; @@ -176,9 +170,8 @@ nf_nat_mangle_udp_packet(struct sk_buff *skb, if (!udph->check && skb->ip_summed != CHECKSUM_PARTIAL) return true; - l3proto = __nf_nat_l3proto_find(nf_ct_l3num(ct)); - l3proto->csum_recalc(skb, IPPROTO_UDP, udph, &udph->check, - datalen, oldlen); + nf_nat_csum_recalc(skb, nf_ct_l3num(ct), IPPROTO_TCP, + udph, &udph->check, datalen, oldlen); return true; } diff --git a/net/ipv4/netfilter/nf_nat_masquerade_ipv4.c b/net/netfilter/nf_nat_masquerade.c index 41327bb99093..d85c4d902e7b 100644 --- a/net/ipv4/netfilter/nf_nat_masquerade_ipv4.c +++ b/net/netfilter/nf_nat_masquerade.c @@ -1,25 +1,18 @@ -/* (C) 1999-2001 Paul `Rusty' Russell - * (C) 2002-2006 Netfilter Core Team <coreteam@netfilter.org> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 2 as - * published by the Free Software Foundation. - */ +// SPDX-License-Identifier: GPL-2.0 #include <linux/types.h> #include <linux/atomic.h> #include <linux/inetdevice.h> -#include <linux/ip.h> -#include <linux/timer.h> #include <linux/netfilter.h> -#include <net/protocol.h> -#include <net/ip.h> -#include <net/checksum.h> -#include <net/route.h> #include <linux/netfilter_ipv4.h> -#include <linux/netfilter/x_tables.h> -#include <net/netfilter/nf_nat.h> +#include <linux/netfilter_ipv6.h> + #include <net/netfilter/ipv4/nf_nat_masquerade.h> +#include <net/netfilter/ipv6/nf_nat_masquerade.h> + +static DEFINE_MUTEX(masq_mutex); +static unsigned int masq_refcnt4 __read_mostly; +static unsigned int masq_refcnt6 __read_mostly; unsigned int nf_nat_masquerade_ipv4(struct sk_buff *skb, unsigned int hooknum, @@ -78,8 +71,6 @@ static int device_cmp(struct nf_conn *i, void *ifindex) if (!nat) return 0; - if (nf_ct_l3num(i) != NFPROTO_IPV4) - return 0; return nat->masq_index == (int)(long)ifindex; } @@ -95,7 +86,6 @@ static int masq_device_event(struct notifier_block *this, * conntracks which were associated with that device, * and forget them. */ - WARN_ON(dev->ifindex == 0); nf_ct_iterate_cleanup_net(net, device_cmp, (void *)(long)dev->ifindex, 0, 0); @@ -147,16 +137,18 @@ static struct notifier_block masq_inet_notifier = { .notifier_call = masq_inet_event, }; -static int masq_refcnt; -static DEFINE_MUTEX(masq_mutex); - int nf_nat_masquerade_ipv4_register_notifier(void) { int ret = 0; mutex_lock(&masq_mutex); + if (WARN_ON_ONCE(masq_refcnt4 == UINT_MAX)) { + ret = -EOVERFLOW; + goto out_unlock; + } + /* check if the notifier was already set */ - if (++masq_refcnt > 1) + if (++masq_refcnt4 > 1) goto out_unlock; /* Register for device down reports */ @@ -174,7 +166,7 @@ int nf_nat_masquerade_ipv4_register_notifier(void) err_unregister: unregister_netdevice_notifier(&masq_dev_notifier); err_dec: - masq_refcnt--; + masq_refcnt4--; out_unlock: mutex_unlock(&masq_mutex); return ret; @@ -185,7 +177,7 @@ void nf_nat_masquerade_ipv4_unregister_notifier(void) { mutex_lock(&masq_mutex); /* check if the notifier still has clients */ - if (--masq_refcnt > 0) + if (--masq_refcnt4 > 0) goto out_unlock; unregister_netdevice_notifier(&masq_dev_notifier); @@ -194,3 +186,180 @@ out_unlock: mutex_unlock(&masq_mutex); } EXPORT_SYMBOL_GPL(nf_nat_masquerade_ipv4_unregister_notifier); + +#if IS_ENABLED(CONFIG_IPV6) +static atomic_t v6_worker_count __read_mostly; + +static int +nat_ipv6_dev_get_saddr(struct net *net, const struct net_device *dev, + const struct in6_addr *daddr, unsigned int srcprefs, + struct in6_addr *saddr) +{ +#ifdef CONFIG_IPV6_MODULE + const struct nf_ipv6_ops *v6_ops = nf_get_ipv6_ops(); + + if (!v6_ops) + return -EHOSTUNREACH; + + return v6_ops->dev_get_saddr(net, dev, daddr, srcprefs, saddr); +#else + return ipv6_dev_get_saddr(net, dev, daddr, srcprefs, saddr); +#endif +} + +unsigned int +nf_nat_masquerade_ipv6(struct sk_buff *skb, const struct nf_nat_range2 *range, + const struct net_device *out) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn_nat *nat; + struct in6_addr src; + struct nf_conn *ct; + struct nf_nat_range2 newrange; + + ct = nf_ct_get(skb, &ctinfo); + WARN_ON(!(ct && (ctinfo == IP_CT_NEW || ctinfo == IP_CT_RELATED || + ctinfo == IP_CT_RELATED_REPLY))); + + if (nat_ipv6_dev_get_saddr(nf_ct_net(ct), out, + &ipv6_hdr(skb)->daddr, 0, &src) < 0) + return NF_DROP; + + nat = nf_ct_nat_ext_add(ct); + if (nat) + nat->masq_index = out->ifindex; + + newrange.flags = range->flags | NF_NAT_RANGE_MAP_IPS; + newrange.min_addr.in6 = src; + newrange.max_addr.in6 = src; + newrange.min_proto = range->min_proto; + newrange.max_proto = range->max_proto; + + return nf_nat_setup_info(ct, &newrange, NF_NAT_MANIP_SRC); +} +EXPORT_SYMBOL_GPL(nf_nat_masquerade_ipv6); + +struct masq_dev_work { + struct work_struct work; + struct net *net; + struct in6_addr addr; + int ifindex; +}; + +static int inet6_cmp(struct nf_conn *ct, void *work) +{ + struct masq_dev_work *w = (struct masq_dev_work *)work; + struct nf_conntrack_tuple *tuple; + + if (!device_cmp(ct, (void *)(long)w->ifindex)) + return 0; + + tuple = &ct->tuplehash[IP_CT_DIR_REPLY].tuple; + + return ipv6_addr_equal(&w->addr, &tuple->dst.u3.in6); +} + +static void iterate_cleanup_work(struct work_struct *work) +{ + struct masq_dev_work *w; + + w = container_of(work, struct masq_dev_work, work); + + nf_ct_iterate_cleanup_net(w->net, inet6_cmp, (void *)w, 0, 0); + + put_net(w->net); + kfree(w); + atomic_dec(&v6_worker_count); + module_put(THIS_MODULE); +} + +/* atomic notifier; can't call nf_ct_iterate_cleanup_net (it can sleep). + * + * Defer it to the system workqueue. + * + * As we can have 'a lot' of inet_events (depending on amount of ipv6 + * addresses being deleted), we also need to limit work item queue. + */ +static int masq_inet6_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct inet6_ifaddr *ifa = ptr; + const struct net_device *dev; + struct masq_dev_work *w; + struct net *net; + + if (event != NETDEV_DOWN || atomic_read(&v6_worker_count) >= 16) + return NOTIFY_DONE; + + dev = ifa->idev->dev; + net = maybe_get_net(dev_net(dev)); + if (!net) + return NOTIFY_DONE; + + if (!try_module_get(THIS_MODULE)) + goto err_module; + + w = kmalloc(sizeof(*w), GFP_ATOMIC); + if (w) { + atomic_inc(&v6_worker_count); + + INIT_WORK(&w->work, iterate_cleanup_work); + w->ifindex = dev->ifindex; + w->net = net; + w->addr = ifa->addr; + schedule_work(&w->work); + + return NOTIFY_DONE; + } + + module_put(THIS_MODULE); + err_module: + put_net(net); + return NOTIFY_DONE; +} + +static struct notifier_block masq_inet6_notifier = { + .notifier_call = masq_inet6_event, +}; + +int nf_nat_masquerade_ipv6_register_notifier(void) +{ + int ret = 0; + + mutex_lock(&masq_mutex); + if (WARN_ON_ONCE(masq_refcnt6 == UINT_MAX)) { + ret = -EOVERFLOW; + goto out_unlock; + } + + /* check if the notifier is already set */ + if (++masq_refcnt6 > 1) + goto out_unlock; + + ret = register_inet6addr_notifier(&masq_inet6_notifier); + if (ret) + goto err_dec; + + mutex_unlock(&masq_mutex); + return ret; +err_dec: + masq_refcnt6--; +out_unlock: + mutex_unlock(&masq_mutex); + return ret; +} +EXPORT_SYMBOL_GPL(nf_nat_masquerade_ipv6_register_notifier); + +void nf_nat_masquerade_ipv6_unregister_notifier(void) +{ + mutex_lock(&masq_mutex); + /* check if the notifier still has clients */ + if (--masq_refcnt6 > 0) + goto out_unlock; + + unregister_inet6addr_notifier(&masq_inet6_notifier); +out_unlock: + mutex_unlock(&masq_mutex); +} +EXPORT_SYMBOL_GPL(nf_nat_masquerade_ipv6_unregister_notifier); +#endif diff --git a/net/netfilter/nf_nat_proto.c b/net/netfilter/nf_nat_proto.c index f83bf9d8c9f5..62743da3004f 100644 --- a/net/netfilter/nf_nat_proto.c +++ b/net/netfilter/nf_nat_proto.c @@ -20,13 +20,26 @@ #include <linux/netfilter.h> #include <net/netfilter/nf_nat.h> -#include <net/netfilter/nf_nat_core.h> -#include <net/netfilter/nf_nat_l3proto.h> -#include <net/netfilter/nf_nat_l4proto.h> + +#include <linux/ipv6.h> +#include <linux/netfilter_ipv6.h> +#include <net/checksum.h> +#include <net/ip6_checksum.h> +#include <net/ip6_route.h> +#include <net/xfrm.h> +#include <net/ipv6.h> + +#include <net/netfilter/nf_conntrack_core.h> +#include <net/netfilter/nf_conntrack.h> +#include <linux/netfilter/nfnetlink_conntrack.h> + +static void nf_csum_update(struct sk_buff *skb, + unsigned int iphdroff, __sum16 *check, + const struct nf_conntrack_tuple *t, + enum nf_nat_manip_type maniptype); static void __udp_manip_pkt(struct sk_buff *skb, - const struct nf_nat_l3proto *l3proto, unsigned int iphdroff, struct udphdr *hdr, const struct nf_conntrack_tuple *tuple, enum nf_nat_manip_type maniptype, bool do_csum) @@ -43,8 +56,7 @@ __udp_manip_pkt(struct sk_buff *skb, portptr = &hdr->dest; } if (do_csum) { - l3proto->csum_update(skb, iphdroff, &hdr->check, - tuple, maniptype); + nf_csum_update(skb, iphdroff, &hdr->check, tuple, maniptype); inet_proto_csum_replace2(&hdr->check, skb, *portptr, newport, false); if (!hdr->check) @@ -54,7 +66,6 @@ __udp_manip_pkt(struct sk_buff *skb, } static bool udp_manip_pkt(struct sk_buff *skb, - const struct nf_nat_l3proto *l3proto, unsigned int iphdroff, unsigned int hdroff, const struct nf_conntrack_tuple *tuple, enum nf_nat_manip_type maniptype) @@ -68,12 +79,11 @@ static bool udp_manip_pkt(struct sk_buff *skb, hdr = (struct udphdr *)(skb->data + hdroff); do_csum = hdr->check || skb->ip_summed == CHECKSUM_PARTIAL; - __udp_manip_pkt(skb, l3proto, iphdroff, hdr, tuple, maniptype, do_csum); + __udp_manip_pkt(skb, iphdroff, hdr, tuple, maniptype, do_csum); return true; } static bool udplite_manip_pkt(struct sk_buff *skb, - const struct nf_nat_l3proto *l3proto, unsigned int iphdroff, unsigned int hdroff, const struct nf_conntrack_tuple *tuple, enum nf_nat_manip_type maniptype) @@ -85,14 +95,13 @@ static bool udplite_manip_pkt(struct sk_buff *skb, return false; hdr = (struct udphdr *)(skb->data + hdroff); - __udp_manip_pkt(skb, l3proto, iphdroff, hdr, tuple, maniptype, true); + __udp_manip_pkt(skb, iphdroff, hdr, tuple, maniptype, true); #endif return true; } static bool sctp_manip_pkt(struct sk_buff *skb, - const struct nf_nat_l3proto *l3proto, unsigned int iphdroff, unsigned int hdroff, const struct nf_conntrack_tuple *tuple, enum nf_nat_manip_type maniptype) @@ -135,7 +144,6 @@ sctp_manip_pkt(struct sk_buff *skb, static bool tcp_manip_pkt(struct sk_buff *skb, - const struct nf_nat_l3proto *l3proto, unsigned int iphdroff, unsigned int hdroff, const struct nf_conntrack_tuple *tuple, enum nf_nat_manip_type maniptype) @@ -171,14 +179,13 @@ tcp_manip_pkt(struct sk_buff *skb, if (hdrsize < sizeof(*hdr)) return true; - l3proto->csum_update(skb, iphdroff, &hdr->check, tuple, maniptype); + nf_csum_update(skb, iphdroff, &hdr->check, tuple, maniptype); inet_proto_csum_replace2(&hdr->check, skb, oldport, newport, false); return true; } static bool dccp_manip_pkt(struct sk_buff *skb, - const struct nf_nat_l3proto *l3proto, unsigned int iphdroff, unsigned int hdroff, const struct nf_conntrack_tuple *tuple, enum nf_nat_manip_type maniptype) @@ -210,8 +217,7 @@ dccp_manip_pkt(struct sk_buff *skb, if (hdrsize < sizeof(*hdr)) return true; - l3proto->csum_update(skb, iphdroff, &hdr->dccph_checksum, - tuple, maniptype); + nf_csum_update(skb, iphdroff, &hdr->dccph_checksum, tuple, maniptype); inet_proto_csum_replace2(&hdr->dccph_checksum, skb, oldport, newport, false); #endif @@ -220,7 +226,6 @@ dccp_manip_pkt(struct sk_buff *skb, static bool icmp_manip_pkt(struct sk_buff *skb, - const struct nf_nat_l3proto *l3proto, unsigned int iphdroff, unsigned int hdroff, const struct nf_conntrack_tuple *tuple, enum nf_nat_manip_type maniptype) @@ -239,7 +244,6 @@ icmp_manip_pkt(struct sk_buff *skb, static bool icmpv6_manip_pkt(struct sk_buff *skb, - const struct nf_nat_l3proto *l3proto, unsigned int iphdroff, unsigned int hdroff, const struct nf_conntrack_tuple *tuple, enum nf_nat_manip_type maniptype) @@ -250,8 +254,7 @@ icmpv6_manip_pkt(struct sk_buff *skb, return false; hdr = (struct icmp6hdr *)(skb->data + hdroff); - l3proto->csum_update(skb, iphdroff, &hdr->icmp6_cksum, - tuple, maniptype); + nf_csum_update(skb, iphdroff, &hdr->icmp6_cksum, tuple, maniptype); if (hdr->icmp6_type == ICMPV6_ECHO_REQUEST || hdr->icmp6_type == ICMPV6_ECHO_REPLY) { inet_proto_csum_replace2(&hdr->icmp6_cksum, skb, @@ -265,7 +268,6 @@ icmpv6_manip_pkt(struct sk_buff *skb, /* manipulate a GRE packet according to maniptype */ static bool gre_manip_pkt(struct sk_buff *skb, - const struct nf_nat_l3proto *l3proto, unsigned int iphdroff, unsigned int hdroff, const struct nf_conntrack_tuple *tuple, enum nf_nat_manip_type maniptype) @@ -304,40 +306,718 @@ gre_manip_pkt(struct sk_buff *skb, return true; } -bool nf_nat_l4proto_manip_pkt(struct sk_buff *skb, - const struct nf_nat_l3proto *l3proto, +static bool l4proto_manip_pkt(struct sk_buff *skb, unsigned int iphdroff, unsigned int hdroff, const struct nf_conntrack_tuple *tuple, enum nf_nat_manip_type maniptype) { switch (tuple->dst.protonum) { case IPPROTO_TCP: - return tcp_manip_pkt(skb, l3proto, iphdroff, hdroff, + return tcp_manip_pkt(skb, iphdroff, hdroff, tuple, maniptype); case IPPROTO_UDP: - return udp_manip_pkt(skb, l3proto, iphdroff, hdroff, + return udp_manip_pkt(skb, iphdroff, hdroff, tuple, maniptype); case IPPROTO_UDPLITE: - return udplite_manip_pkt(skb, l3proto, iphdroff, hdroff, + return udplite_manip_pkt(skb, iphdroff, hdroff, tuple, maniptype); case IPPROTO_SCTP: - return sctp_manip_pkt(skb, l3proto, iphdroff, hdroff, + return sctp_manip_pkt(skb, iphdroff, hdroff, tuple, maniptype); case IPPROTO_ICMP: - return icmp_manip_pkt(skb, l3proto, iphdroff, hdroff, + return icmp_manip_pkt(skb, iphdroff, hdroff, tuple, maniptype); case IPPROTO_ICMPV6: - return icmpv6_manip_pkt(skb, l3proto, iphdroff, hdroff, + return icmpv6_manip_pkt(skb, iphdroff, hdroff, tuple, maniptype); case IPPROTO_DCCP: - return dccp_manip_pkt(skb, l3proto, iphdroff, hdroff, + return dccp_manip_pkt(skb, iphdroff, hdroff, tuple, maniptype); case IPPROTO_GRE: - return gre_manip_pkt(skb, l3proto, iphdroff, hdroff, + return gre_manip_pkt(skb, iphdroff, hdroff, tuple, maniptype); } /* If we don't know protocol -- no error, pass it unmodified. */ return true; } -EXPORT_SYMBOL_GPL(nf_nat_l4proto_manip_pkt); + +static bool nf_nat_ipv4_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, + const struct nf_conntrack_tuple *target, + enum nf_nat_manip_type maniptype) +{ + struct iphdr *iph; + unsigned int hdroff; + + if (!skb_make_writable(skb, iphdroff + sizeof(*iph))) + return false; + + iph = (void *)skb->data + iphdroff; + hdroff = iphdroff + iph->ihl * 4; + + if (!l4proto_manip_pkt(skb, iphdroff, hdroff, target, maniptype)) + return false; + iph = (void *)skb->data + iphdroff; + + if (maniptype == NF_NAT_MANIP_SRC) { + csum_replace4(&iph->check, iph->saddr, target->src.u3.ip); + iph->saddr = target->src.u3.ip; + } else { + csum_replace4(&iph->check, iph->daddr, target->dst.u3.ip); + iph->daddr = target->dst.u3.ip; + } + return true; +} + +static bool nf_nat_ipv6_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, + const struct nf_conntrack_tuple *target, + enum nf_nat_manip_type maniptype) +{ +#if IS_ENABLED(CONFIG_IPV6) + struct ipv6hdr *ipv6h; + __be16 frag_off; + int hdroff; + u8 nexthdr; + + if (!skb_make_writable(skb, iphdroff + sizeof(*ipv6h))) + return false; + + ipv6h = (void *)skb->data + iphdroff; + nexthdr = ipv6h->nexthdr; + hdroff = ipv6_skip_exthdr(skb, iphdroff + sizeof(*ipv6h), + &nexthdr, &frag_off); + if (hdroff < 0) + goto manip_addr; + + if ((frag_off & htons(~0x7)) == 0 && + !l4proto_manip_pkt(skb, iphdroff, hdroff, target, maniptype)) + return false; + + /* must reload, offset might have changed */ + ipv6h = (void *)skb->data + iphdroff; + +manip_addr: + if (maniptype == NF_NAT_MANIP_SRC) + ipv6h->saddr = target->src.u3.in6; + else + ipv6h->daddr = target->dst.u3.in6; + +#endif + return true; +} + +unsigned int nf_nat_manip_pkt(struct sk_buff *skb, struct nf_conn *ct, + enum nf_nat_manip_type mtype, + enum ip_conntrack_dir dir) +{ + struct nf_conntrack_tuple target; + + /* We are aiming to look like inverse of other direction. */ + nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple); + + switch (target.src.l3num) { + case NFPROTO_IPV6: + if (nf_nat_ipv6_manip_pkt(skb, 0, &target, mtype)) + return NF_ACCEPT; + break; + case NFPROTO_IPV4: + if (nf_nat_ipv4_manip_pkt(skb, 0, &target, mtype)) + return NF_ACCEPT; + break; + default: + WARN_ON_ONCE(1); + break; + } + + return NF_DROP; +} + +static void nf_nat_ipv4_csum_update(struct sk_buff *skb, + unsigned int iphdroff, __sum16 *check, + const struct nf_conntrack_tuple *t, + enum nf_nat_manip_type maniptype) +{ + struct iphdr *iph = (struct iphdr *)(skb->data + iphdroff); + __be32 oldip, newip; + + if (maniptype == NF_NAT_MANIP_SRC) { + oldip = iph->saddr; + newip = t->src.u3.ip; + } else { + oldip = iph->daddr; + newip = t->dst.u3.ip; + } + inet_proto_csum_replace4(check, skb, oldip, newip, true); +} + +static void nf_nat_ipv6_csum_update(struct sk_buff *skb, + unsigned int iphdroff, __sum16 *check, + const struct nf_conntrack_tuple *t, + enum nf_nat_manip_type maniptype) +{ +#if IS_ENABLED(CONFIG_IPV6) + const struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + iphdroff); + const struct in6_addr *oldip, *newip; + + if (maniptype == NF_NAT_MANIP_SRC) { + oldip = &ipv6h->saddr; + newip = &t->src.u3.in6; + } else { + oldip = &ipv6h->daddr; + newip = &t->dst.u3.in6; + } + inet_proto_csum_replace16(check, skb, oldip->s6_addr32, + newip->s6_addr32, true); +#endif +} + +static void nf_csum_update(struct sk_buff *skb, + unsigned int iphdroff, __sum16 *check, + const struct nf_conntrack_tuple *t, + enum nf_nat_manip_type maniptype) +{ + switch (t->src.l3num) { + case NFPROTO_IPV4: + nf_nat_ipv4_csum_update(skb, iphdroff, check, t, maniptype); + return; + case NFPROTO_IPV6: + nf_nat_ipv6_csum_update(skb, iphdroff, check, t, maniptype); + return; + } +} + +static void nf_nat_ipv4_csum_recalc(struct sk_buff *skb, + u8 proto, void *data, __sum16 *check, + int datalen, int oldlen) +{ + if (skb->ip_summed != CHECKSUM_PARTIAL) { + const struct iphdr *iph = ip_hdr(skb); + + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = skb_headroom(skb) + skb_network_offset(skb) + + ip_hdrlen(skb); + skb->csum_offset = (void *)check - data; + *check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen, + proto, 0); + } else { + inet_proto_csum_replace2(check, skb, + htons(oldlen), htons(datalen), true); + } +} + +#if IS_ENABLED(CONFIG_IPV6) +static void nf_nat_ipv6_csum_recalc(struct sk_buff *skb, + u8 proto, void *data, __sum16 *check, + int datalen, int oldlen) +{ + if (skb->ip_summed != CHECKSUM_PARTIAL) { + const struct ipv6hdr *ipv6h = ipv6_hdr(skb); + + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = skb_headroom(skb) + skb_network_offset(skb) + + (data - (void *)skb->data); + skb->csum_offset = (void *)check - data; + *check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr, + datalen, proto, 0); + } else { + inet_proto_csum_replace2(check, skb, + htons(oldlen), htons(datalen), true); + } +} +#endif + +void nf_nat_csum_recalc(struct sk_buff *skb, + u8 nfproto, u8 proto, void *data, __sum16 *check, + int datalen, int oldlen) +{ + switch (nfproto) { + case NFPROTO_IPV4: + nf_nat_ipv4_csum_recalc(skb, proto, data, check, + datalen, oldlen); + return; +#if IS_ENABLED(CONFIG_IPV6) + case NFPROTO_IPV6: + nf_nat_ipv6_csum_recalc(skb, proto, data, check, + datalen, oldlen); + return; +#endif + } + + WARN_ON_ONCE(1); +} + +int nf_nat_icmp_reply_translation(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int hooknum) +{ + struct { + struct icmphdr icmp; + struct iphdr ip; + } *inside; + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + enum nf_nat_manip_type manip = HOOK2MANIP(hooknum); + unsigned int hdrlen = ip_hdrlen(skb); + struct nf_conntrack_tuple target; + unsigned long statusbit; + + WARN_ON(ctinfo != IP_CT_RELATED && ctinfo != IP_CT_RELATED_REPLY); + + if (!skb_make_writable(skb, hdrlen + sizeof(*inside))) + return 0; + if (nf_ip_checksum(skb, hooknum, hdrlen, 0)) + return 0; + + inside = (void *)skb->data + hdrlen; + if (inside->icmp.type == ICMP_REDIRECT) { + if ((ct->status & IPS_NAT_DONE_MASK) != IPS_NAT_DONE_MASK) + return 0; + if (ct->status & IPS_NAT_MASK) + return 0; + } + + if (manip == NF_NAT_MANIP_SRC) + statusbit = IPS_SRC_NAT; + else + statusbit = IPS_DST_NAT; + + /* Invert if this is reply direction */ + if (dir == IP_CT_DIR_REPLY) + statusbit ^= IPS_NAT_MASK; + + if (!(ct->status & statusbit)) + return 1; + + if (!nf_nat_ipv4_manip_pkt(skb, hdrlen + sizeof(inside->icmp), + &ct->tuplehash[!dir].tuple, !manip)) + return 0; + + if (skb->ip_summed != CHECKSUM_PARTIAL) { + /* Reloading "inside" here since manip_pkt may reallocate */ + inside = (void *)skb->data + hdrlen; + inside->icmp.checksum = 0; + inside->icmp.checksum = + csum_fold(skb_checksum(skb, hdrlen, + skb->len - hdrlen, 0)); + } + + /* Change outer to look like the reply to an incoming packet */ + nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple); + target.dst.protonum = IPPROTO_ICMP; + if (!nf_nat_ipv4_manip_pkt(skb, 0, &target, manip)) + return 0; + + return 1; +} +EXPORT_SYMBOL_GPL(nf_nat_icmp_reply_translation); + +static unsigned int +nf_nat_ipv4_fn(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct) + return NF_ACCEPT; + + if (ctinfo == IP_CT_RELATED || ctinfo == IP_CT_RELATED_REPLY) { + if (ip_hdr(skb)->protocol == IPPROTO_ICMP) { + if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo, + state->hook)) + return NF_DROP; + else + return NF_ACCEPT; + } + } + + return nf_nat_inet_fn(priv, skb, state); +} + +static unsigned int +nf_nat_ipv4_in(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + unsigned int ret; + __be32 daddr = ip_hdr(skb)->daddr; + + ret = nf_nat_ipv4_fn(priv, skb, state); + if (ret == NF_ACCEPT && daddr != ip_hdr(skb)->daddr) + skb_dst_drop(skb); + + return ret; +} + +static unsigned int +nf_nat_ipv4_out(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ +#ifdef CONFIG_XFRM + const struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + int err; +#endif + unsigned int ret; + + ret = nf_nat_ipv4_fn(priv, skb, state); +#ifdef CONFIG_XFRM + if (ret != NF_ACCEPT) + return ret; + + if (IPCB(skb)->flags & IPSKB_XFRM_TRANSFORMED) + return ret; + + ct = nf_ct_get(skb, &ctinfo); + if (ct) { + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + + if (ct->tuplehash[dir].tuple.src.u3.ip != + ct->tuplehash[!dir].tuple.dst.u3.ip || + (ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMP && + ct->tuplehash[dir].tuple.src.u.all != + ct->tuplehash[!dir].tuple.dst.u.all)) { + err = nf_xfrm_me_harder(state->net, skb, AF_INET); + if (err < 0) + ret = NF_DROP_ERR(err); + } + } +#endif + return ret; +} + +static unsigned int +nf_nat_ipv4_local_fn(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + const struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + unsigned int ret; + int err; + + ret = nf_nat_ipv4_fn(priv, skb, state); + if (ret != NF_ACCEPT) + return ret; + + ct = nf_ct_get(skb, &ctinfo); + if (ct) { + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + + if (ct->tuplehash[dir].tuple.dst.u3.ip != + ct->tuplehash[!dir].tuple.src.u3.ip) { + err = ip_route_me_harder(state->net, skb, RTN_UNSPEC); + if (err < 0) + ret = NF_DROP_ERR(err); + } +#ifdef CONFIG_XFRM + else if (!(IPCB(skb)->flags & IPSKB_XFRM_TRANSFORMED) && + ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMP && + ct->tuplehash[dir].tuple.dst.u.all != + ct->tuplehash[!dir].tuple.src.u.all) { + err = nf_xfrm_me_harder(state->net, skb, AF_INET); + if (err < 0) + ret = NF_DROP_ERR(err); + } +#endif + } + return ret; +} + +static const struct nf_hook_ops nf_nat_ipv4_ops[] = { + /* Before packet filtering, change destination */ + { + .hook = nf_nat_ipv4_in, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_PRE_ROUTING, + .priority = NF_IP_PRI_NAT_DST, + }, + /* After packet filtering, change source */ + { + .hook = nf_nat_ipv4_out, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_POST_ROUTING, + .priority = NF_IP_PRI_NAT_SRC, + }, + /* Before packet filtering, change destination */ + { + .hook = nf_nat_ipv4_local_fn, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP_PRI_NAT_DST, + }, + /* After packet filtering, change source */ + { + .hook = nf_nat_ipv4_fn, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP_PRI_NAT_SRC, + }, +}; + +int nf_nat_ipv4_register_fn(struct net *net, const struct nf_hook_ops *ops) +{ + return nf_nat_register_fn(net, ops, nf_nat_ipv4_ops, ARRAY_SIZE(nf_nat_ipv4_ops)); +} +EXPORT_SYMBOL_GPL(nf_nat_ipv4_register_fn); + +void nf_nat_ipv4_unregister_fn(struct net *net, const struct nf_hook_ops *ops) +{ + nf_nat_unregister_fn(net, ops, ARRAY_SIZE(nf_nat_ipv4_ops)); +} +EXPORT_SYMBOL_GPL(nf_nat_ipv4_unregister_fn); + +#if IS_ENABLED(CONFIG_IPV6) +int nf_nat_icmpv6_reply_translation(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int hooknum, + unsigned int hdrlen) +{ + struct { + struct icmp6hdr icmp6; + struct ipv6hdr ip6; + } *inside; + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + enum nf_nat_manip_type manip = HOOK2MANIP(hooknum); + struct nf_conntrack_tuple target; + unsigned long statusbit; + + WARN_ON(ctinfo != IP_CT_RELATED && ctinfo != IP_CT_RELATED_REPLY); + + if (!skb_make_writable(skb, hdrlen + sizeof(*inside))) + return 0; + if (nf_ip6_checksum(skb, hooknum, hdrlen, IPPROTO_ICMPV6)) + return 0; + + inside = (void *)skb->data + hdrlen; + if (inside->icmp6.icmp6_type == NDISC_REDIRECT) { + if ((ct->status & IPS_NAT_DONE_MASK) != IPS_NAT_DONE_MASK) + return 0; + if (ct->status & IPS_NAT_MASK) + return 0; + } + + if (manip == NF_NAT_MANIP_SRC) + statusbit = IPS_SRC_NAT; + else + statusbit = IPS_DST_NAT; + + /* Invert if this is reply direction */ + if (dir == IP_CT_DIR_REPLY) + statusbit ^= IPS_NAT_MASK; + + if (!(ct->status & statusbit)) + return 1; + + if (!nf_nat_ipv6_manip_pkt(skb, hdrlen + sizeof(inside->icmp6), + &ct->tuplehash[!dir].tuple, !manip)) + return 0; + + if (skb->ip_summed != CHECKSUM_PARTIAL) { + struct ipv6hdr *ipv6h = ipv6_hdr(skb); + + inside = (void *)skb->data + hdrlen; + inside->icmp6.icmp6_cksum = 0; + inside->icmp6.icmp6_cksum = + csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr, + skb->len - hdrlen, IPPROTO_ICMPV6, + skb_checksum(skb, hdrlen, + skb->len - hdrlen, 0)); + } + + nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple); + target.dst.protonum = IPPROTO_ICMPV6; + if (!nf_nat_ipv6_manip_pkt(skb, 0, &target, manip)) + return 0; + + return 1; +} +EXPORT_SYMBOL_GPL(nf_nat_icmpv6_reply_translation); + +static unsigned int +nf_nat_ipv6_fn(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + __be16 frag_off; + int hdrlen; + u8 nexthdr; + + ct = nf_ct_get(skb, &ctinfo); + /* Can't track? It's not due to stress, or conntrack would + * have dropped it. Hence it's the user's responsibilty to + * packet filter it out, or implement conntrack/NAT for that + * protocol. 8) --RR + */ + if (!ct) + return NF_ACCEPT; + + if (ctinfo == IP_CT_RELATED || ctinfo == IP_CT_RELATED_REPLY) { + nexthdr = ipv6_hdr(skb)->nexthdr; + hdrlen = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), + &nexthdr, &frag_off); + + if (hdrlen >= 0 && nexthdr == IPPROTO_ICMPV6) { + if (!nf_nat_icmpv6_reply_translation(skb, ct, ctinfo, + state->hook, + hdrlen)) + return NF_DROP; + else + return NF_ACCEPT; + } + } + + return nf_nat_inet_fn(priv, skb, state); +} + +static unsigned int +nf_nat_ipv6_in(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + unsigned int ret; + struct in6_addr daddr = ipv6_hdr(skb)->daddr; + + ret = nf_nat_ipv6_fn(priv, skb, state); + if (ret != NF_DROP && ret != NF_STOLEN && + ipv6_addr_cmp(&daddr, &ipv6_hdr(skb)->daddr)) + skb_dst_drop(skb); + + return ret; +} + +static unsigned int +nf_nat_ipv6_out(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ +#ifdef CONFIG_XFRM + const struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + int err; +#endif + unsigned int ret; + + ret = nf_nat_ipv6_fn(priv, skb, state); +#ifdef CONFIG_XFRM + if (ret != NF_ACCEPT) + return ret; + + if (IP6CB(skb)->flags & IP6SKB_XFRM_TRANSFORMED) + return ret; + ct = nf_ct_get(skb, &ctinfo); + if (ct) { + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + + if (!nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.src.u3, + &ct->tuplehash[!dir].tuple.dst.u3) || + (ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMPV6 && + ct->tuplehash[dir].tuple.src.u.all != + ct->tuplehash[!dir].tuple.dst.u.all)) { + err = nf_xfrm_me_harder(state->net, skb, AF_INET6); + if (err < 0) + ret = NF_DROP_ERR(err); + } + } +#endif + + return ret; +} + +static int nat_route_me_harder(struct net *net, struct sk_buff *skb) +{ +#ifdef CONFIG_IPV6_MODULE + const struct nf_ipv6_ops *v6_ops = nf_get_ipv6_ops(); + + if (!v6_ops) + return -EHOSTUNREACH; + + return v6_ops->route_me_harder(net, skb); +#else + return ip6_route_me_harder(net, skb); +#endif +} + +static unsigned int +nf_nat_ipv6_local_fn(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + const struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + unsigned int ret; + int err; + + ret = nf_nat_ipv6_fn(priv, skb, state); + if (ret != NF_ACCEPT) + return ret; + + ct = nf_ct_get(skb, &ctinfo); + if (ct) { + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + + if (!nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.dst.u3, + &ct->tuplehash[!dir].tuple.src.u3)) { + err = nat_route_me_harder(state->net, skb); + if (err < 0) + ret = NF_DROP_ERR(err); + } +#ifdef CONFIG_XFRM + else if (!(IP6CB(skb)->flags & IP6SKB_XFRM_TRANSFORMED) && + ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMPV6 && + ct->tuplehash[dir].tuple.dst.u.all != + ct->tuplehash[!dir].tuple.src.u.all) { + err = nf_xfrm_me_harder(state->net, skb, AF_INET6); + if (err < 0) + ret = NF_DROP_ERR(err); + } +#endif + } + + return ret; +} + +static const struct nf_hook_ops nf_nat_ipv6_ops[] = { + /* Before packet filtering, change destination */ + { + .hook = nf_nat_ipv6_in, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_PRE_ROUTING, + .priority = NF_IP6_PRI_NAT_DST, + }, + /* After packet filtering, change source */ + { + .hook = nf_nat_ipv6_out, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_POST_ROUTING, + .priority = NF_IP6_PRI_NAT_SRC, + }, + /* Before packet filtering, change destination */ + { + .hook = nf_nat_ipv6_local_fn, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP6_PRI_NAT_DST, + }, + /* After packet filtering, change source */ + { + .hook = nf_nat_ipv6_fn, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP6_PRI_NAT_SRC, + }, +}; + +int nf_nat_ipv6_register_fn(struct net *net, const struct nf_hook_ops *ops) +{ + return nf_nat_register_fn(net, ops, nf_nat_ipv6_ops, + ARRAY_SIZE(nf_nat_ipv6_ops)); +} +EXPORT_SYMBOL_GPL(nf_nat_ipv6_register_fn); + +void nf_nat_ipv6_unregister_fn(struct net *net, const struct nf_hook_ops *ops) +{ + nf_nat_unregister_fn(net, ops, ARRAY_SIZE(nf_nat_ipv6_ops)); +} +EXPORT_SYMBOL_GPL(nf_nat_ipv6_unregister_fn); +#endif /* CONFIG_IPV6 */ diff --git a/net/netfilter/nf_tables_api.c b/net/netfilter/nf_tables_api.c index fb07f6cfc719..513f93118604 100644 --- a/net/netfilter/nf_tables_api.c +++ b/net/netfilter/nf_tables_api.c @@ -37,10 +37,16 @@ enum { NFT_VALIDATE_DO, }; +static struct rhltable nft_objname_ht; + static u32 nft_chain_hash(const void *data, u32 len, u32 seed); static u32 nft_chain_hash_obj(const void *data, u32 len, u32 seed); static int nft_chain_hash_cmp(struct rhashtable_compare_arg *, const void *); +static u32 nft_objname_hash(const void *data, u32 len, u32 seed); +static u32 nft_objname_hash_obj(const void *data, u32 len, u32 seed); +static int nft_objname_hash_cmp(struct rhashtable_compare_arg *, const void *); + static const struct rhashtable_params nft_chain_ht_params = { .head_offset = offsetof(struct nft_chain, rhlhead), .key_offset = offsetof(struct nft_chain, name), @@ -51,6 +57,15 @@ static const struct rhashtable_params nft_chain_ht_params = { .automatic_shrinking = true, }; +static const struct rhashtable_params nft_objname_ht_params = { + .head_offset = offsetof(struct nft_object, rhlhead), + .key_offset = offsetof(struct nft_object, key), + .hashfn = nft_objname_hash, + .obj_hashfn = nft_objname_hash_obj, + .obj_cmpfn = nft_objname_hash_cmp, + .automatic_shrinking = true, +}; + static void nft_validate_state_update(struct net *net, u8 new_validate_state) { switch (net->nft.validate_state) { @@ -116,6 +131,23 @@ static void nft_trans_destroy(struct nft_trans *trans) kfree(trans); } +static void nft_set_trans_bind(const struct nft_ctx *ctx, struct nft_set *set) +{ + struct net *net = ctx->net; + struct nft_trans *trans; + + if (!nft_set_is_anonymous(set)) + return; + + list_for_each_entry_reverse(trans, &net->nft.commit_list, list) { + if (trans->msg_type == NFT_MSG_NEWSET && + nft_trans_set(trans) == set) { + set->bound = true; + break; + } + } +} + static int nf_tables_register_hook(struct net *net, const struct nft_table *table, struct nft_chain *chain) @@ -211,18 +243,6 @@ static int nft_delchain(struct nft_ctx *ctx) return err; } -/* either expr ops provide both activate/deactivate, or neither */ -static bool nft_expr_check_ops(const struct nft_expr_ops *ops) -{ - if (!ops) - return true; - - if (WARN_ON_ONCE((!ops->activate ^ !ops->deactivate))) - return false; - - return true; -} - static void nft_rule_expr_activate(const struct nft_ctx *ctx, struct nft_rule *rule) { @@ -238,14 +258,15 @@ static void nft_rule_expr_activate(const struct nft_ctx *ctx, } static void nft_rule_expr_deactivate(const struct nft_ctx *ctx, - struct nft_rule *rule) + struct nft_rule *rule, + enum nft_trans_phase phase) { struct nft_expr *expr; expr = nft_expr_first(rule); while (expr != nft_expr_last(rule) && expr->ops) { if (expr->ops->deactivate) - expr->ops->deactivate(ctx, expr); + expr->ops->deactivate(ctx, expr, phase); expr = nft_expr_next(expr); } @@ -296,7 +317,7 @@ static int nft_delrule(struct nft_ctx *ctx, struct nft_rule *rule) nft_trans_destroy(trans); return err; } - nft_rule_expr_deactivate(ctx, rule); + nft_rule_expr_deactivate(ctx, rule, NFT_TRANS_PREPARE); return 0; } @@ -307,6 +328,9 @@ static int nft_delrule_by_chain(struct nft_ctx *ctx) int err; list_for_each_entry(rule, &ctx->chain->rules, list) { + if (!nft_is_active_next(ctx->net, rule)) + continue; + err = nft_delrule(ctx, rule); if (err < 0) return err; @@ -814,6 +838,34 @@ static int nft_chain_hash_cmp(struct rhashtable_compare_arg *arg, return strcmp(chain->name, name); } +static u32 nft_objname_hash(const void *data, u32 len, u32 seed) +{ + const struct nft_object_hash_key *k = data; + + seed ^= hash_ptr(k->table, 32); + + return jhash(k->name, strlen(k->name), seed); +} + +static u32 nft_objname_hash_obj(const void *data, u32 len, u32 seed) +{ + const struct nft_object *obj = data; + + return nft_objname_hash(&obj->key, 0, seed); +} + +static int nft_objname_hash_cmp(struct rhashtable_compare_arg *arg, + const void *ptr) +{ + const struct nft_object_hash_key *k = arg->key; + const struct nft_object *obj = ptr; + + if (obj->key.table != k->table) + return -1; + + return strcmp(obj->key.name, k->name); +} + static int nf_tables_newtable(struct net *net, struct sock *nlsk, struct sk_buff *skb, const struct nlmsghdr *nlh, const struct nlattr * const nla[], @@ -1070,7 +1122,7 @@ nft_chain_lookup_byhandle(const struct nft_table *table, u64 handle, u8 genmask) return ERR_PTR(-ENOENT); } -static bool lockdep_commit_lock_is_held(struct net *net) +static bool lockdep_commit_lock_is_held(const struct net *net) { #ifdef CONFIG_PROVE_LOCKING return lockdep_is_held(&net->nft.commit_mutex); @@ -1929,9 +1981,6 @@ static int nf_tables_delchain(struct net *net, struct sock *nlsk, */ int nft_register_expr(struct nft_expr_type *type) { - if (!nft_expr_check_ops(type->ops)) - return -EINVAL; - nfnl_lock(NFNL_SUBSYS_NFTABLES); if (type->family == NFPROTO_UNSPEC) list_add_tail_rcu(&type->list, &nf_tables_expressions); @@ -2079,10 +2128,6 @@ static int nf_tables_expr_parse(const struct nft_ctx *ctx, err = PTR_ERR(ops); goto err1; } - if (!nft_expr_check_ops(ops)) { - err = -EINVAL; - goto err1; - } } else ops = type->ops; @@ -2117,9 +2162,11 @@ err1: static void nf_tables_expr_destroy(const struct nft_ctx *ctx, struct nft_expr *expr) { + const struct nft_expr_type *type = expr->ops->type; + if (expr->ops->destroy) expr->ops->destroy(ctx, expr); - module_put(expr->ops->type->owner); + module_put(type->owner); } struct nft_expr *nft_expr_init(const struct nft_ctx *ctx, @@ -2127,6 +2174,7 @@ struct nft_expr *nft_expr_init(const struct nft_ctx *ctx, { struct nft_expr_info info; struct nft_expr *expr; + struct module *owner; int err; err = nf_tables_expr_parse(ctx, nla, &info); @@ -2146,7 +2194,11 @@ struct nft_expr *nft_expr_init(const struct nft_ctx *ctx, err3: kfree(expr); err2: - module_put(info.ops->type->owner); + owner = info.ops->type->owner; + if (info.ops->type->release_ops) + info.ops->type->release_ops(info.ops); + + module_put(owner); err1: return ERR_PTR(err); } @@ -2196,6 +2248,7 @@ static const struct nla_policy nft_rule_policy[NFTA_RULE_MAX + 1] = { [NFTA_RULE_USERDATA] = { .type = NLA_BINARY, .len = NFT_USERDATA_MAXLEN }, [NFTA_RULE_ID] = { .type = NLA_U32 }, + [NFTA_RULE_POSITION_ID] = { .type = NLA_U32 }, }; static int nf_tables_fill_rule_info(struct sk_buff *skb, struct net *net, @@ -2511,7 +2564,7 @@ static void nf_tables_rule_destroy(const struct nft_ctx *ctx, static void nf_tables_rule_release(const struct nft_ctx *ctx, struct nft_rule *rule) { - nft_rule_expr_deactivate(ctx, rule); + nft_rule_expr_deactivate(ctx, rule, NFT_TRANS_RELEASE); nf_tables_rule_destroy(ctx, rule); } @@ -2565,6 +2618,9 @@ static int nft_table_validate(struct net *net, const struct nft_table *table) return 0; } +static struct nft_rule *nft_rule_lookup_byid(const struct net *net, + const struct nlattr *nla); + #define NFT_RULE_MAXEXPRS 128 static int nf_tables_newrule(struct net *net, struct sock *nlsk, @@ -2634,6 +2690,12 @@ static int nf_tables_newrule(struct net *net, struct sock *nlsk, NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_POSITION]); return PTR_ERR(old_rule); } + } else if (nla[NFTA_RULE_POSITION_ID]) { + old_rule = nft_rule_lookup_byid(net, nla[NFTA_RULE_POSITION_ID]); + if (IS_ERR(old_rule)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_POSITION_ID]); + return PTR_ERR(old_rule); + } } } @@ -3612,6 +3674,9 @@ err1: static void nft_set_destroy(struct nft_set *set) { + if (WARN_ON(set->use > 0)) + return; + set->ops->destroy(set); module_put(to_set_type(set->ops)->owner); kfree(set->name); @@ -3652,7 +3717,7 @@ static int nf_tables_delset(struct net *net, struct sock *nlsk, NL_SET_BAD_ATTR(extack, attr); return PTR_ERR(set); } - if (!list_empty(&set->bindings) || + if (set->use || (nlh->nlmsg_flags & NLM_F_NONREC && atomic_read(&set->nelems) > 0)) { NL_SET_BAD_ATTR(extack, attr); return -EBUSY; @@ -3682,6 +3747,9 @@ int nf_tables_bind_set(const struct nft_ctx *ctx, struct nft_set *set, struct nft_set_binding *i; struct nft_set_iter iter; + if (set->use == UINT_MAX) + return -EOVERFLOW; + if (!list_empty(&set->bindings) && nft_set_is_anonymous(set)) return -EBUSY; @@ -3708,39 +3776,50 @@ int nf_tables_bind_set(const struct nft_ctx *ctx, struct nft_set *set, bind: binding->chain = ctx->chain; list_add_tail_rcu(&binding->list, &set->bindings); + nft_set_trans_bind(ctx, set); + set->use++; + return 0; } EXPORT_SYMBOL_GPL(nf_tables_bind_set); -void nf_tables_rebind_set(const struct nft_ctx *ctx, struct nft_set *set, - struct nft_set_binding *binding) -{ - if (list_empty(&set->bindings) && nft_set_is_anonymous(set) && - nft_is_active(ctx->net, set)) - list_add_tail_rcu(&set->list, &ctx->table->sets); - - list_add_tail_rcu(&binding->list, &set->bindings); -} -EXPORT_SYMBOL_GPL(nf_tables_rebind_set); - void nf_tables_unbind_set(const struct nft_ctx *ctx, struct nft_set *set, - struct nft_set_binding *binding) + struct nft_set_binding *binding, bool event) { list_del_rcu(&binding->list); - if (list_empty(&set->bindings) && nft_set_is_anonymous(set) && - nft_is_active(ctx->net, set)) + if (list_empty(&set->bindings) && nft_set_is_anonymous(set)) { list_del_rcu(&set->list); + if (event) + nf_tables_set_notify(ctx, set, NFT_MSG_DELSET, + GFP_KERNEL); + } } EXPORT_SYMBOL_GPL(nf_tables_unbind_set); +void nf_tables_deactivate_set(const struct nft_ctx *ctx, struct nft_set *set, + struct nft_set_binding *binding, + enum nft_trans_phase phase) +{ + switch (phase) { + case NFT_TRANS_PREPARE: + set->use--; + return; + case NFT_TRANS_ABORT: + case NFT_TRANS_RELEASE: + set->use--; + /* fall through */ + default: + nf_tables_unbind_set(ctx, set, binding, + phase == NFT_TRANS_COMMIT); + } +} +EXPORT_SYMBOL_GPL(nf_tables_deactivate_set); + void nf_tables_destroy_set(const struct nft_ctx *ctx, struct nft_set *set) { - if (list_empty(&set->bindings) && nft_set_is_anonymous(set) && - nft_is_active(ctx->net, set)) { - nf_tables_set_notify(ctx, set, NFT_MSG_DELSET, GFP_ATOMIC); + if (list_empty(&set->bindings) && nft_set_is_anonymous(set)) nft_set_destroy(set); - } } EXPORT_SYMBOL_GPL(nf_tables_destroy_set); @@ -3851,7 +3930,7 @@ static int nf_tables_fill_setelem(struct sk_buff *skb, if (nft_set_ext_exists(ext, NFT_SET_EXT_OBJREF) && nla_put_string(skb, NFTA_SET_ELEM_OBJREF, - (*nft_set_ext_obj(ext))->name) < 0) + (*nft_set_ext_obj(ext))->key.name) < 0) goto nla_put_failure; if (nft_set_ext_exists(ext, NFT_SET_EXT_FLAGS) && @@ -4384,7 +4463,8 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set, err = -EINVAL; goto err2; } - obj = nft_obj_lookup(ctx->table, nla[NFTA_SET_ELEM_OBJREF], + obj = nft_obj_lookup(ctx->net, ctx->table, + nla[NFTA_SET_ELEM_OBJREF], set->objtype, genmask); if (IS_ERR(obj)) { err = PTR_ERR(obj); @@ -4819,18 +4899,36 @@ void nft_unregister_obj(struct nft_object_type *obj_type) } EXPORT_SYMBOL_GPL(nft_unregister_obj); -struct nft_object *nft_obj_lookup(const struct nft_table *table, +struct nft_object *nft_obj_lookup(const struct net *net, + const struct nft_table *table, const struct nlattr *nla, u32 objtype, u8 genmask) { + struct nft_object_hash_key k = { .table = table }; + char search[NFT_OBJ_MAXNAMELEN]; + struct rhlist_head *tmp, *list; struct nft_object *obj; - list_for_each_entry_rcu(obj, &table->objects, list) { - if (!nla_strcmp(nla, obj->name) && - objtype == obj->ops->type->type && - nft_active_genmask(obj, genmask)) + nla_strlcpy(search, nla, sizeof(search)); + k.name = search; + + WARN_ON_ONCE(!rcu_read_lock_held() && + !lockdep_commit_lock_is_held(net)); + + rcu_read_lock(); + list = rhltable_lookup(&nft_objname_ht, &k, nft_objname_ht_params); + if (!list) + goto out; + + rhl_for_each_entry_rcu(obj, tmp, list, rhlhead) { + if (objtype == obj->ops->type->type && + nft_active_genmask(obj, genmask)) { + rcu_read_unlock(); return obj; + } } +out: + rcu_read_unlock(); return ERR_PTR(-ENOENT); } EXPORT_SYMBOL_GPL(nft_obj_lookup); @@ -4988,7 +5086,7 @@ static int nf_tables_newobj(struct net *net, struct sock *nlsk, } objtype = ntohl(nla_get_be32(nla[NFTA_OBJ_TYPE])); - obj = nft_obj_lookup(table, nla[NFTA_OBJ_NAME], objtype, genmask); + obj = nft_obj_lookup(net, table, nla[NFTA_OBJ_NAME], objtype, genmask); if (IS_ERR(obj)) { err = PTR_ERR(obj); if (err != -ENOENT) { @@ -5014,11 +5112,11 @@ static int nf_tables_newobj(struct net *net, struct sock *nlsk, err = PTR_ERR(obj); goto err1; } - obj->table = table; + obj->key.table = table; obj->handle = nf_tables_alloc_handle(table); - obj->name = nla_strdup(nla[NFTA_OBJ_NAME], GFP_KERNEL); - if (!obj->name) { + obj->key.name = nla_strdup(nla[NFTA_OBJ_NAME], GFP_KERNEL); + if (!obj->key.name) { err = -ENOMEM; goto err2; } @@ -5027,11 +5125,20 @@ static int nf_tables_newobj(struct net *net, struct sock *nlsk, if (err < 0) goto err3; + err = rhltable_insert(&nft_objname_ht, &obj->rhlhead, + nft_objname_ht_params); + if (err < 0) + goto err4; + list_add_tail_rcu(&obj->list, &table->objects); table->use++; return 0; +err4: + /* queued in transaction log */ + INIT_LIST_HEAD(&obj->list); + return err; err3: - kfree(obj->name); + kfree(obj->key.name); err2: if (obj->ops->destroy) obj->ops->destroy(&ctx, obj); @@ -5060,7 +5167,7 @@ static int nf_tables_fill_obj_info(struct sk_buff *skb, struct net *net, nfmsg->res_id = htons(net->nft.base_seq & 0xffff); if (nla_put_string(skb, NFTA_OBJ_TABLE, table->name) || - nla_put_string(skb, NFTA_OBJ_NAME, obj->name) || + nla_put_string(skb, NFTA_OBJ_NAME, obj->key.name) || nla_put_be32(skb, NFTA_OBJ_TYPE, htonl(obj->ops->type->type)) || nla_put_be32(skb, NFTA_OBJ_USE, htonl(obj->use)) || nft_object_dump(skb, NFTA_OBJ_DATA, obj, reset) || @@ -5215,7 +5322,7 @@ static int nf_tables_getobj(struct net *net, struct sock *nlsk, } objtype = ntohl(nla_get_be32(nla[NFTA_OBJ_TYPE])); - obj = nft_obj_lookup(table, nla[NFTA_OBJ_NAME], objtype, genmask); + obj = nft_obj_lookup(net, table, nla[NFTA_OBJ_NAME], objtype, genmask); if (IS_ERR(obj)) { NL_SET_BAD_ATTR(extack, nla[NFTA_OBJ_NAME]); return PTR_ERR(obj); @@ -5246,7 +5353,7 @@ static void nft_obj_destroy(const struct nft_ctx *ctx, struct nft_object *obj) obj->ops->destroy(ctx, obj); module_put(obj->ops->type->owner); - kfree(obj->name); + kfree(obj->key.name); kfree(obj); } @@ -5280,7 +5387,7 @@ static int nf_tables_delobj(struct net *net, struct sock *nlsk, obj = nft_obj_lookup_byhandle(table, attr, objtype, genmask); } else { attr = nla[NFTA_OBJ_NAME]; - obj = nft_obj_lookup(table, attr, objtype, genmask); + obj = nft_obj_lookup(net, table, attr, objtype, genmask); } if (IS_ERR(obj)) { @@ -5297,7 +5404,7 @@ static int nf_tables_delobj(struct net *net, struct sock *nlsk, return nft_delobj(&ctx, obj); } -void nft_obj_notify(struct net *net, struct nft_table *table, +void nft_obj_notify(struct net *net, const struct nft_table *table, struct nft_object *obj, u32 portid, u32 seq, int event, int family, int report, gfp_t gfp) { @@ -6404,6 +6511,12 @@ static void nf_tables_commit_chain(struct net *net, struct nft_chain *chain) nf_tables_commit_chain_free_rules_old(g0); } +static void nft_obj_del(struct nft_object *obj) +{ + rhltable_remove(&nft_objname_ht, &obj->rhlhead, nft_objname_ht_params); + list_del_rcu(&obj->list); +} + static void nft_chain_del(struct nft_chain *chain) { struct nft_table *table = chain->table; @@ -6451,6 +6564,11 @@ static int nf_tables_commit(struct net *net, struct sk_buff *skb) struct nft_chain *chain; struct nft_table *table; + if (list_empty(&net->nft.commit_list)) { + mutex_unlock(&net->nft.commit_mutex); + return 0; + } + /* 0. Validate ruleset, otherwise roll back for error reporting. */ if (nf_tables_validate(net) < 0) return -EAGAIN; @@ -6535,6 +6653,9 @@ static int nf_tables_commit(struct net *net, struct sk_buff *skb) nf_tables_rule_notify(&trans->ctx, nft_trans_rule(trans), NFT_MSG_DELRULE); + nft_rule_expr_deactivate(&trans->ctx, + nft_trans_rule(trans), + NFT_TRANS_COMMIT); break; case NFT_MSG_NEWSET: nft_clear(net, nft_trans_set(trans)); @@ -6580,7 +6701,7 @@ static int nf_tables_commit(struct net *net, struct sk_buff *skb) nft_trans_destroy(trans); break; case NFT_MSG_DELOBJ: - list_del_rcu(&nft_trans_obj(trans)->list); + nft_obj_del(nft_trans_obj(trans)); nf_tables_obj_notify(&trans->ctx, nft_trans_obj(trans), NFT_MSG_DELOBJ); break; @@ -6682,7 +6803,9 @@ static int __nf_tables_abort(struct net *net) case NFT_MSG_NEWRULE: trans->ctx.chain->use--; list_del_rcu(&nft_trans_rule(trans)->list); - nft_rule_expr_deactivate(&trans->ctx, nft_trans_rule(trans)); + nft_rule_expr_deactivate(&trans->ctx, + nft_trans_rule(trans), + NFT_TRANS_ABORT); break; case NFT_MSG_DELRULE: trans->ctx.chain->use++; @@ -6692,6 +6815,10 @@ static int __nf_tables_abort(struct net *net) break; case NFT_MSG_NEWSET: trans->ctx.table->use--; + if (nft_trans_set(trans)->bound) { + nft_trans_destroy(trans); + break; + } list_del_rcu(&nft_trans_set(trans)->list); break; case NFT_MSG_DELSET: @@ -6700,8 +6827,11 @@ static int __nf_tables_abort(struct net *net) nft_trans_destroy(trans); break; case NFT_MSG_NEWSETELEM: + if (nft_trans_elem_set(trans)->bound) { + nft_trans_destroy(trans); + break; + } te = (struct nft_trans_elem *)trans->data; - te->set->ops->remove(net, te->set, &te->elem); atomic_dec(&te->set->nelems); break; @@ -6716,7 +6846,7 @@ static int __nf_tables_abort(struct net *net) break; case NFT_MSG_NEWOBJ: trans->ctx.table->use--; - list_del_rcu(&nft_trans_obj(trans)->list); + nft_obj_del(nft_trans_obj(trans)); break; case NFT_MSG_DELOBJ: trans->ctx.table->use++; @@ -7330,7 +7460,7 @@ static void __nft_release_tables(struct net *net) nft_set_destroy(set); } list_for_each_entry_safe(obj, ne, &table->objects, list) { - list_del(&obj->list); + nft_obj_del(obj); table->use--; nft_obj_destroy(&ctx, obj); } @@ -7392,12 +7522,18 @@ static int __init nf_tables_module_init(void) if (err < 0) goto err3; + err = rhltable_init(&nft_objname_ht, &nft_objname_ht_params); + if (err < 0) + goto err4; + /* must be last */ err = nfnetlink_subsys_register(&nf_tables_subsys); if (err < 0) - goto err4; + goto err5; return err; +err5: + rhltable_destroy(&nft_objname_ht); err4: unregister_netdevice_notifier(&nf_tables_flowtable_notifier); err3: @@ -7417,6 +7553,7 @@ static void __exit nf_tables_module_exit(void) unregister_pernet_subsys(&nf_tables_net_ops); cancel_work_sync(&trans_destroy_work); rcu_barrier(); + rhltable_destroy(&nft_objname_ht); nf_tables_core_module_exit(); } diff --git a/net/netfilter/nf_tables_core.c b/net/netfilter/nf_tables_core.c index a50500232b0a..d0f168c2670f 100644 --- a/net/netfilter/nf_tables_core.c +++ b/net/netfilter/nf_tables_core.c @@ -98,21 +98,23 @@ static noinline void nft_update_chain_stats(const struct nft_chain *chain, const struct nft_pktinfo *pkt) { struct nft_base_chain *base_chain; + struct nft_stats __percpu *pstats; struct nft_stats *stats; base_chain = nft_base_chain(chain); - if (!rcu_access_pointer(base_chain->stats)) - return; - local_bh_disable(); - stats = this_cpu_ptr(rcu_dereference(base_chain->stats)); - if (stats) { + rcu_read_lock(); + pstats = READ_ONCE(base_chain->stats); + if (pstats) { + local_bh_disable(); + stats = this_cpu_ptr(pstats); u64_stats_update_begin(&stats->syncp); stats->pkts++; stats->bytes += pkt->skb->len; u64_stats_update_end(&stats->syncp); + local_bh_enable(); } - local_bh_enable(); + rcu_read_unlock(); } struct nft_jumpstack { @@ -124,14 +126,25 @@ static void expr_call_ops_eval(const struct nft_expr *expr, struct nft_regs *regs, struct nft_pktinfo *pkt) { +#ifdef CONFIG_RETPOLINE unsigned long e = (unsigned long)expr->ops->eval; - - if (e == (unsigned long)nft_meta_get_eval) - nft_meta_get_eval(expr, regs, pkt); - else if (e == (unsigned long)nft_lookup_eval) - nft_lookup_eval(expr, regs, pkt); - else - expr->ops->eval(expr, regs, pkt); +#define X(e, fun) \ + do { if ((e) == (unsigned long)(fun)) \ + return fun(expr, regs, pkt); } while (0) + + X(e, nft_payload_eval); + X(e, nft_cmp_eval); + X(e, nft_meta_get_eval); + X(e, nft_lookup_eval); + X(e, nft_range_eval); + X(e, nft_immediate_eval); + X(e, nft_byteorder_eval); + X(e, nft_dynset_eval); + X(e, nft_rt_get_eval); + X(e, nft_bitwise_eval); +#undef X +#endif /* CONFIG_RETPOLINE */ + expr->ops->eval(expr, regs, pkt); } unsigned int @@ -210,7 +223,6 @@ next_rule: chain = regs.verdict.chain; goto do_chain; case NFT_CONTINUE: - /* fall through */ case NFT_RETURN: nft_trace_packet(&info, chain, rule, NFT_TRACETYPE_RETURN); diff --git a/net/netfilter/nfnetlink_cttimeout.c b/net/netfilter/nfnetlink_cttimeout.c index 109b0d27345a..c69b11ca5aad 100644 --- a/net/netfilter/nfnetlink_cttimeout.c +++ b/net/netfilter/nfnetlink_cttimeout.c @@ -122,7 +122,7 @@ static int cttimeout_new_timeout(struct net *net, struct sock *ctnl, return -EBUSY; } - l4proto = nf_ct_l4proto_find_get(l4num); + l4proto = nf_ct_l4proto_find(l4num); /* This protocol is not supportted, skip. */ if (l4proto->l4proto != l4num) { @@ -152,7 +152,6 @@ static int cttimeout_new_timeout(struct net *net, struct sock *ctnl, err: kfree(timeout); err_proto_put: - nf_ct_l4proto_put(l4proto); return ret; } @@ -302,7 +301,6 @@ static int ctnl_timeout_try_del(struct net *net, struct ctnl_timeout *timeout) if (refcount_dec_if_one(&timeout->refcnt)) { /* We are protected by nfnl mutex. */ list_del_rcu(&timeout->head); - nf_ct_l4proto_put(timeout->timeout.l4proto); nf_ct_untimeout(net, &timeout->timeout); kfree_rcu(timeout, rcu_head); } else { @@ -359,7 +357,7 @@ static int cttimeout_default_set(struct net *net, struct sock *ctnl, return -EINVAL; l4num = nla_get_u8(cda[CTA_TIMEOUT_L4PROTO]); - l4proto = nf_ct_l4proto_find_get(l4num); + l4proto = nf_ct_l4proto_find(l4num); /* This protocol is not supported, skip. */ if (l4proto->l4proto != l4num) { @@ -372,10 +370,8 @@ static int cttimeout_default_set(struct net *net, struct sock *ctnl, if (ret < 0) goto err; - nf_ct_l4proto_put(l4proto); return 0; err: - nf_ct_l4proto_put(l4proto); return ret; } @@ -442,7 +438,7 @@ static int cttimeout_default_get(struct net *net, struct sock *ctnl, l3num = ntohs(nla_get_be16(cda[CTA_TIMEOUT_L3PROTO])); l4num = nla_get_u8(cda[CTA_TIMEOUT_L4PROTO]); - l4proto = nf_ct_l4proto_find_get(l4num); + l4proto = nf_ct_l4proto_find(l4num); err = -EOPNOTSUPP; if (l4proto->l4proto != l4num) @@ -474,12 +470,7 @@ static int cttimeout_default_get(struct net *net, struct sock *ctnl, break; case IPPROTO_GRE: #ifdef CONFIG_NF_CT_PROTO_GRE - if (l4proto->net_id) { - struct netns_proto_gre *net_gre; - - net_gre = net_generic(net, *l4proto->net_id); - timeouts = net_gre->gre_timeouts; - } + timeouts = nf_gre_pernet(net)->timeouts; #endif break; case 255: @@ -516,7 +507,6 @@ static int cttimeout_default_get(struct net *net, struct sock *ctnl, /* this avoids a loop in nfnetlink. */ return ret == -EAGAIN ? -ENOBUFS : ret; err: - nf_ct_l4proto_put(l4proto); return err; } @@ -597,7 +587,6 @@ static void __net_exit cttimeout_net_exit(struct net *net) list_for_each_entry_safe(cur, tmp, &net->nfct_timeout_list, head) { list_del_rcu(&cur->head); - nf_ct_l4proto_put(cur->timeout.l4proto); if (refcount_dec_and_test(&cur->refcnt)) kfree_rcu(cur, rcu_head); diff --git a/net/netfilter/nfnetlink_osf.c b/net/netfilter/nfnetlink_osf.c index 6f41dd74729d..1f1d90c1716b 100644 --- a/net/netfilter/nfnetlink_osf.c +++ b/net/netfilter/nfnetlink_osf.c @@ -66,6 +66,7 @@ static bool nf_osf_match_one(const struct sk_buff *skb, int ttl_check, struct nf_osf_hdr_ctx *ctx) { + const __u8 *optpinit = ctx->optp; unsigned int check_WSS = 0; int fmatch = FMATCH_WRONG; int foptsize, optnum; @@ -155,6 +156,9 @@ static bool nf_osf_match_one(const struct sk_buff *skb, } } + if (fmatch != FMATCH_OK) + ctx->optp = optpinit; + return fmatch == FMATCH_OK; } diff --git a/net/netfilter/nft_bitwise.c b/net/netfilter/nft_bitwise.c index fff8073e2a56..2c75b9e0474e 100644 --- a/net/netfilter/nft_bitwise.c +++ b/net/netfilter/nft_bitwise.c @@ -25,9 +25,8 @@ struct nft_bitwise { struct nft_data xor; }; -static void nft_bitwise_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) +void nft_bitwise_eval(const struct nft_expr *expr, + struct nft_regs *regs, const struct nft_pktinfo *pkt) { const struct nft_bitwise *priv = nft_expr_priv(expr); const u32 *src = ®s->data[priv->sreg]; diff --git a/net/netfilter/nft_byteorder.c b/net/netfilter/nft_byteorder.c index 13d4e421a6b3..19dbc34cc75e 100644 --- a/net/netfilter/nft_byteorder.c +++ b/net/netfilter/nft_byteorder.c @@ -26,9 +26,9 @@ struct nft_byteorder { u8 size; }; -static void nft_byteorder_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) +void nft_byteorder_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) { const struct nft_byteorder *priv = nft_expr_priv(expr); u32 *src = ®s->data[priv->sreg]; diff --git a/net/netfilter/nft_chain_nat.c b/net/netfilter/nft_chain_nat.c new file mode 100644 index 000000000000..ee4852088d50 --- /dev/null +++ b/net/netfilter/nft_chain_nat.c @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <linux/module.h> +#include <linux/netfilter/nf_tables.h> +#include <net/netfilter/nf_nat.h> +#include <net/netfilter/nf_tables.h> +#include <net/netfilter/nf_tables_ipv4.h> +#include <net/netfilter/nf_tables_ipv6.h> + +static unsigned int nft_nat_do_chain(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nft_pktinfo pkt; + + nft_set_pktinfo(&pkt, skb, state); + + switch (state->pf) { +#ifdef CONFIG_NF_TABLES_IPV4 + case NFPROTO_IPV4: + nft_set_pktinfo_ipv4(&pkt, skb); + break; +#endif +#ifdef CONFIG_NF_TABLES_IPV6 + case NFPROTO_IPV6: + nft_set_pktinfo_ipv6(&pkt, skb); + break; +#endif + default: + break; + } + + return nft_do_chain(&pkt, priv); +} + +#ifdef CONFIG_NF_TABLES_IPV4 +static const struct nft_chain_type nft_chain_nat_ipv4 = { + .name = "nat", + .type = NFT_CHAIN_T_NAT, + .family = NFPROTO_IPV4, + .owner = THIS_MODULE, + .hook_mask = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_LOCAL_IN), + .hooks = { + [NF_INET_PRE_ROUTING] = nft_nat_do_chain, + [NF_INET_POST_ROUTING] = nft_nat_do_chain, + [NF_INET_LOCAL_OUT] = nft_nat_do_chain, + [NF_INET_LOCAL_IN] = nft_nat_do_chain, + }, + .ops_register = nf_nat_ipv4_register_fn, + .ops_unregister = nf_nat_ipv4_unregister_fn, +}; +#endif + +#ifdef CONFIG_NF_TABLES_IPV6 +static const struct nft_chain_type nft_chain_nat_ipv6 = { + .name = "nat", + .type = NFT_CHAIN_T_NAT, + .family = NFPROTO_IPV6, + .owner = THIS_MODULE, + .hook_mask = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_LOCAL_IN), + .hooks = { + [NF_INET_PRE_ROUTING] = nft_nat_do_chain, + [NF_INET_POST_ROUTING] = nft_nat_do_chain, + [NF_INET_LOCAL_OUT] = nft_nat_do_chain, + [NF_INET_LOCAL_IN] = nft_nat_do_chain, + }, + .ops_register = nf_nat_ipv6_register_fn, + .ops_unregister = nf_nat_ipv6_unregister_fn, +}; +#endif + +static int __init nft_chain_nat_init(void) +{ +#ifdef CONFIG_NF_TABLES_IPV6 + nft_register_chain_type(&nft_chain_nat_ipv6); +#endif +#ifdef CONFIG_NF_TABLES_IPV4 + nft_register_chain_type(&nft_chain_nat_ipv4); +#endif + + return 0; +} + +static void __exit nft_chain_nat_exit(void) +{ +#ifdef CONFIG_NF_TABLES_IPV4 + nft_unregister_chain_type(&nft_chain_nat_ipv4); +#endif +#ifdef CONFIG_NF_TABLES_IPV6 + nft_unregister_chain_type(&nft_chain_nat_ipv6); +#endif +} + +module_init(nft_chain_nat_init); +module_exit(nft_chain_nat_exit); + +MODULE_LICENSE("GPL"); +#ifdef CONFIG_NF_TABLES_IPV4 +MODULE_ALIAS_NFT_CHAIN(AF_INET, "nat"); +#endif +#ifdef CONFIG_NF_TABLES_IPV6 +MODULE_ALIAS_NFT_CHAIN(AF_INET6, "nat"); +#endif diff --git a/net/netfilter/nft_cmp.c b/net/netfilter/nft_cmp.c index 79d48c1d06f4..f9f1fa66a16e 100644 --- a/net/netfilter/nft_cmp.c +++ b/net/netfilter/nft_cmp.c @@ -24,9 +24,9 @@ struct nft_cmp_expr { enum nft_cmp_ops op:8; }; -static void nft_cmp_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) +void nft_cmp_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) { const struct nft_cmp_expr *priv = nft_expr_priv(expr); int d; diff --git a/net/netfilter/nft_compat.c b/net/netfilter/nft_compat.c index 7334e0b80a5e..469f9da5073b 100644 --- a/net/netfilter/nft_compat.c +++ b/net/netfilter/nft_compat.c @@ -23,19 +23,6 @@ #include <linux/netfilter_arp/arp_tables.h> #include <net/netfilter/nf_tables.h> -struct nft_xt { - struct list_head head; - struct nft_expr_ops ops; - unsigned int refcnt; - - /* Unlike other expressions, ops doesn't have static storage duration. - * nft core assumes they do. We use kfree_rcu so that nft core can - * can check expr->ops->size even after nft_compat->destroy() frees - * the nft_xt struct that holds the ops structure. - */ - struct rcu_head rcu_head; -}; - /* Used for matches where *info is larger than X byte */ #define NFT_MATCH_LARGE_THRESH 192 @@ -43,17 +30,6 @@ struct nft_xt_match_priv { void *info; }; -static bool nft_xt_put(struct nft_xt *xt) -{ - if (--xt->refcnt == 0) { - list_del(&xt->head); - kfree_rcu(xt, rcu_head); - return true; - } - - return false; -} - static int nft_compat_chain_validate_dependency(const struct nft_ctx *ctx, const char *tablename) { @@ -248,7 +224,6 @@ nft_target_init(const struct nft_ctx *ctx, const struct nft_expr *expr, struct xt_target *target = expr->ops->data; struct xt_tgchk_param par; size_t size = XT_ALIGN(nla_len(tb[NFTA_TARGET_INFO])); - struct nft_xt *nft_xt; u16 proto = 0; bool inv = false; union nft_entry e = {}; @@ -272,8 +247,6 @@ nft_target_init(const struct nft_ctx *ctx, const struct nft_expr *expr, if (!target->target) return -EINVAL; - nft_xt = container_of(expr->ops, struct nft_xt, ops); - nft_xt->refcnt++; return 0; } @@ -282,6 +255,7 @@ nft_target_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) { struct xt_target *target = expr->ops->data; void *info = nft_expr_priv(expr); + struct module *me = target->me; struct xt_tgdtor_param par; par.net = ctx->net; @@ -291,8 +265,8 @@ nft_target_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) if (par.target->destroy != NULL) par.target->destroy(&par); - if (nft_xt_put(container_of(expr->ops, struct nft_xt, ops))) - module_put(target->me); + module_put(me); + kfree(expr->ops); } static int nft_extension_dump_info(struct sk_buff *skb, int attr, @@ -465,7 +439,6 @@ __nft_match_init(const struct nft_ctx *ctx, const struct nft_expr *expr, struct xt_match *match = expr->ops->data; struct xt_mtchk_param par; size_t size = XT_ALIGN(nla_len(tb[NFTA_MATCH_INFO])); - struct nft_xt *nft_xt; u16 proto = 0; bool inv = false; union nft_entry e = {}; @@ -481,13 +454,7 @@ __nft_match_init(const struct nft_ctx *ctx, const struct nft_expr *expr, nft_match_set_mtchk_param(&par, ctx, match, info, &e, proto, inv); - ret = xt_check_match(&par, size, proto, inv); - if (ret < 0) - return ret; - - nft_xt = container_of(expr->ops, struct nft_xt, ops); - nft_xt->refcnt++; - return 0; + return xt_check_match(&par, size, proto, inv); } static int @@ -530,8 +497,8 @@ __nft_match_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr, if (par.match->destroy != NULL) par.match->destroy(&par); - if (nft_xt_put(container_of(expr->ops, struct nft_xt, ops))) - module_put(me); + module_put(me); + kfree(expr->ops); } static void @@ -734,22 +701,13 @@ static const struct nfnetlink_subsystem nfnl_compat_subsys = { .cb = nfnl_nft_compat_cb, }; -static LIST_HEAD(nft_match_list); - static struct nft_expr_type nft_match_type; -static bool nft_match_cmp(const struct xt_match *match, - const char *name, u32 rev, u32 family) -{ - return strcmp(match->name, name) == 0 && match->revision == rev && - (match->family == NFPROTO_UNSPEC || match->family == family); -} - static const struct nft_expr_ops * nft_match_select_ops(const struct nft_ctx *ctx, const struct nlattr * const tb[]) { - struct nft_xt *nft_match; + struct nft_expr_ops *ops; struct xt_match *match; unsigned int matchsize; char *mt_name; @@ -765,14 +723,6 @@ nft_match_select_ops(const struct nft_ctx *ctx, rev = ntohl(nla_get_be32(tb[NFTA_MATCH_REV])); family = ctx->family; - /* Re-use the existing match if it's already loaded. */ - list_for_each_entry(nft_match, &nft_match_list, head) { - struct xt_match *match = nft_match->ops.data; - - if (nft_match_cmp(match, mt_name, rev, family)) - return &nft_match->ops; - } - match = xt_request_find_match(family, mt_name, rev); if (IS_ERR(match)) return ERR_PTR(-ENOENT); @@ -782,66 +732,62 @@ nft_match_select_ops(const struct nft_ctx *ctx, goto err; } - /* This is the first time we use this match, allocate operations */ - nft_match = kzalloc(sizeof(struct nft_xt), GFP_KERNEL); - if (nft_match == NULL) { + ops = kzalloc(sizeof(struct nft_expr_ops), GFP_KERNEL); + if (!ops) { err = -ENOMEM; goto err; } - nft_match->refcnt = 0; - nft_match->ops.type = &nft_match_type; - nft_match->ops.eval = nft_match_eval; - nft_match->ops.init = nft_match_init; - nft_match->ops.destroy = nft_match_destroy; - nft_match->ops.dump = nft_match_dump; - nft_match->ops.validate = nft_match_validate; - nft_match->ops.data = match; + ops->type = &nft_match_type; + ops->eval = nft_match_eval; + ops->init = nft_match_init; + ops->destroy = nft_match_destroy; + ops->dump = nft_match_dump; + ops->validate = nft_match_validate; + ops->data = match; matchsize = NFT_EXPR_SIZE(XT_ALIGN(match->matchsize)); if (matchsize > NFT_MATCH_LARGE_THRESH) { matchsize = NFT_EXPR_SIZE(sizeof(struct nft_xt_match_priv)); - nft_match->ops.eval = nft_match_large_eval; - nft_match->ops.init = nft_match_large_init; - nft_match->ops.destroy = nft_match_large_destroy; - nft_match->ops.dump = nft_match_large_dump; + ops->eval = nft_match_large_eval; + ops->init = nft_match_large_init; + ops->destroy = nft_match_large_destroy; + ops->dump = nft_match_large_dump; } - nft_match->ops.size = matchsize; - - list_add(&nft_match->head, &nft_match_list); + ops->size = matchsize; - return &nft_match->ops; + return ops; err: module_put(match->me); return ERR_PTR(err); } +static void nft_match_release_ops(const struct nft_expr_ops *ops) +{ + struct xt_match *match = ops->data; + + module_put(match->me); + kfree(ops); +} + static struct nft_expr_type nft_match_type __read_mostly = { .name = "match", .select_ops = nft_match_select_ops, + .release_ops = nft_match_release_ops, .policy = nft_match_policy, .maxattr = NFTA_MATCH_MAX, .owner = THIS_MODULE, }; -static LIST_HEAD(nft_target_list); - static struct nft_expr_type nft_target_type; -static bool nft_target_cmp(const struct xt_target *tg, - const char *name, u32 rev, u32 family) -{ - return strcmp(tg->name, name) == 0 && tg->revision == rev && - (tg->family == NFPROTO_UNSPEC || tg->family == family); -} - static const struct nft_expr_ops * nft_target_select_ops(const struct nft_ctx *ctx, const struct nlattr * const tb[]) { - struct nft_xt *nft_target; + struct nft_expr_ops *ops; struct xt_target *target; char *tg_name; u32 rev, family; @@ -861,17 +807,6 @@ nft_target_select_ops(const struct nft_ctx *ctx, strcmp(tg_name, "standard") == 0) return ERR_PTR(-EINVAL); - /* Re-use the existing target if it's already loaded. */ - list_for_each_entry(nft_target, &nft_target_list, head) { - struct xt_target *target = nft_target->ops.data; - - if (!target->target) - continue; - - if (nft_target_cmp(target, tg_name, rev, family)) - return &nft_target->ops; - } - target = xt_request_find_target(family, tg_name, rev); if (IS_ERR(target)) return ERR_PTR(-ENOENT); @@ -886,38 +821,43 @@ nft_target_select_ops(const struct nft_ctx *ctx, goto err; } - /* This is the first time we use this target, allocate operations */ - nft_target = kzalloc(sizeof(struct nft_xt), GFP_KERNEL); - if (nft_target == NULL) { + ops = kzalloc(sizeof(struct nft_expr_ops), GFP_KERNEL); + if (!ops) { err = -ENOMEM; goto err; } - nft_target->refcnt = 0; - nft_target->ops.type = &nft_target_type; - nft_target->ops.size = NFT_EXPR_SIZE(XT_ALIGN(target->targetsize)); - nft_target->ops.init = nft_target_init; - nft_target->ops.destroy = nft_target_destroy; - nft_target->ops.dump = nft_target_dump; - nft_target->ops.validate = nft_target_validate; - nft_target->ops.data = target; + ops->type = &nft_target_type; + ops->size = NFT_EXPR_SIZE(XT_ALIGN(target->targetsize)); + ops->init = nft_target_init; + ops->destroy = nft_target_destroy; + ops->dump = nft_target_dump; + ops->validate = nft_target_validate; + ops->data = target; if (family == NFPROTO_BRIDGE) - nft_target->ops.eval = nft_target_eval_bridge; + ops->eval = nft_target_eval_bridge; else - nft_target->ops.eval = nft_target_eval_xt; - - list_add(&nft_target->head, &nft_target_list); + ops->eval = nft_target_eval_xt; - return &nft_target->ops; + return ops; err: module_put(target->me); return ERR_PTR(err); } +static void nft_target_release_ops(const struct nft_expr_ops *ops) +{ + struct xt_target *target = ops->data; + + module_put(target->me); + kfree(ops); +} + static struct nft_expr_type nft_target_type __read_mostly = { .name = "target", .select_ops = nft_target_select_ops, + .release_ops = nft_target_release_ops, .policy = nft_target_policy, .maxattr = NFTA_TARGET_MAX, .owner = THIS_MODULE, @@ -942,7 +882,6 @@ static int __init nft_compat_module_init(void) } return ret; - err_target: nft_unregister_expr(&nft_target_type); err_match: @@ -952,32 +891,6 @@ err_match: static void __exit nft_compat_module_exit(void) { - struct nft_xt *xt, *next; - - /* list should be empty here, it can be non-empty only in case there - * was an error that caused nft_xt expr to not be initialized fully - * and noone else requested the same expression later. - * - * In this case, the lists contain 0-refcount entries that still - * hold module reference. - */ - list_for_each_entry_safe(xt, next, &nft_target_list, head) { - struct xt_target *target = xt->ops.data; - - if (WARN_ON_ONCE(xt->refcnt)) - continue; - module_put(target->me); - kfree(xt); - } - - list_for_each_entry_safe(xt, next, &nft_match_list, head) { - struct xt_match *match = xt->ops.data; - - if (WARN_ON_ONCE(xt->refcnt)) - continue; - module_put(match->me); - kfree(xt); - } nfnetlink_subsys_unregister(&nfnl_compat_subsys); nft_unregister_expr(&nft_target_type); nft_unregister_expr(&nft_match_type); diff --git a/net/netfilter/nft_counter.c b/net/netfilter/nft_counter.c index a61d7edfc290..1a6b06ce6b5b 100644 --- a/net/netfilter/nft_counter.c +++ b/net/netfilter/nft_counter.c @@ -104,7 +104,7 @@ static void nft_counter_obj_destroy(const struct nft_ctx *ctx, nft_counter_do_destroy(priv); } -static void nft_counter_reset(struct nft_counter_percpu_priv __percpu *priv, +static void nft_counter_reset(struct nft_counter_percpu_priv *priv, struct nft_counter *total) { struct nft_counter *this_cpu; diff --git a/net/netfilter/nft_ct.c b/net/netfilter/nft_ct.c index 586627c361df..7b717fad6cdc 100644 --- a/net/netfilter/nft_ct.c +++ b/net/netfilter/nft_ct.c @@ -870,7 +870,7 @@ static int nft_ct_timeout_obj_init(const struct nft_ctx *ctx, l4num = nla_get_u8(tb[NFTA_CT_TIMEOUT_L4PROTO]); priv->l4proto = l4num; - l4proto = nf_ct_l4proto_find_get(l4num); + l4proto = nf_ct_l4proto_find(l4num); if (l4proto->l4proto != l4num) { ret = -EOPNOTSUPP; @@ -902,7 +902,6 @@ static int nft_ct_timeout_obj_init(const struct nft_ctx *ctx, err_free_timeout: kfree(timeout); err_proto_put: - nf_ct_l4proto_put(l4proto); return ret; } @@ -913,7 +912,6 @@ static void nft_ct_timeout_obj_destroy(const struct nft_ctx *ctx, struct nf_ct_timeout *timeout = priv->timeout; nf_ct_untimeout(ctx->net, timeout); - nf_ct_l4proto_put(timeout->l4proto); nf_ct_netns_put(ctx->net, ctx->family); kfree(priv->timeout); } diff --git a/net/netfilter/nft_dynset.c b/net/netfilter/nft_dynset.c index 07d4efd3d851..e461007558e8 100644 --- a/net/netfilter/nft_dynset.c +++ b/net/netfilter/nft_dynset.c @@ -62,9 +62,8 @@ err1: return NULL; } -static void nft_dynset_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) +void nft_dynset_eval(const struct nft_expr *expr, + struct nft_regs *regs, const struct nft_pktinfo *pkt) { const struct nft_dynset *priv = nft_expr_priv(expr); struct nft_set *set = priv->set; @@ -235,20 +234,21 @@ err1: return err; } -static void nft_dynset_activate(const struct nft_ctx *ctx, - const struct nft_expr *expr) +static void nft_dynset_deactivate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + enum nft_trans_phase phase) { struct nft_dynset *priv = nft_expr_priv(expr); - nf_tables_rebind_set(ctx, priv->set, &priv->binding); + nf_tables_deactivate_set(ctx, priv->set, &priv->binding, phase); } -static void nft_dynset_deactivate(const struct nft_ctx *ctx, - const struct nft_expr *expr) +static void nft_dynset_activate(const struct nft_ctx *ctx, + const struct nft_expr *expr) { struct nft_dynset *priv = nft_expr_priv(expr); - nf_tables_unbind_set(ctx, priv->set, &priv->binding); + priv->set->use++; } static void nft_dynset_destroy(const struct nft_ctx *ctx, diff --git a/net/netfilter/nft_hash.c b/net/netfilter/nft_hash.c index c2d237144f74..ea658e6c53e3 100644 --- a/net/netfilter/nft_hash.c +++ b/net/netfilter/nft_hash.c @@ -25,7 +25,6 @@ struct nft_jhash { u32 modulus; u32 seed; u32 offset; - struct nft_set *map; }; static void nft_jhash_eval(const struct nft_expr *expr, @@ -42,33 +41,10 @@ static void nft_jhash_eval(const struct nft_expr *expr, regs->data[priv->dreg] = h + priv->offset; } -static void nft_jhash_map_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) -{ - struct nft_jhash *priv = nft_expr_priv(expr); - const void *data = ®s->data[priv->sreg]; - const struct nft_set *map = priv->map; - const struct nft_set_ext *ext; - u32 result; - bool found; - - result = reciprocal_scale(jhash(data, priv->len, priv->seed), - priv->modulus) + priv->offset; - - found = map->ops->lookup(nft_net(pkt), map, &result, &ext); - if (!found) - return; - - nft_data_copy(®s->data[priv->dreg], - nft_set_ext_data(ext), map->dlen); -} - struct nft_symhash { enum nft_registers dreg:8; u32 modulus; u32 offset; - struct nft_set *map; }; static void nft_symhash_eval(const struct nft_expr *expr, @@ -84,28 +60,6 @@ static void nft_symhash_eval(const struct nft_expr *expr, regs->data[priv->dreg] = h + priv->offset; } -static void nft_symhash_map_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) -{ - struct nft_symhash *priv = nft_expr_priv(expr); - struct sk_buff *skb = pkt->skb; - const struct nft_set *map = priv->map; - const struct nft_set_ext *ext; - u32 result; - bool found; - - result = reciprocal_scale(__skb_get_hash_symmetric(skb), - priv->modulus) + priv->offset; - - found = map->ops->lookup(nft_net(pkt), map, &result, &ext); - if (!found) - return; - - nft_data_copy(®s->data[priv->dreg], - nft_set_ext_data(ext), map->dlen); -} - static const struct nla_policy nft_hash_policy[NFTA_HASH_MAX + 1] = { [NFTA_HASH_SREG] = { .type = NLA_U32 }, [NFTA_HASH_DREG] = { .type = NLA_U32 }, @@ -114,9 +68,6 @@ static const struct nla_policy nft_hash_policy[NFTA_HASH_MAX + 1] = { [NFTA_HASH_SEED] = { .type = NLA_U32 }, [NFTA_HASH_OFFSET] = { .type = NLA_U32 }, [NFTA_HASH_TYPE] = { .type = NLA_U32 }, - [NFTA_HASH_SET_NAME] = { .type = NLA_STRING, - .len = NFT_SET_MAXNAMELEN - 1 }, - [NFTA_HASH_SET_ID] = { .type = NLA_U32 }, }; static int nft_jhash_init(const struct nft_ctx *ctx, @@ -166,20 +117,6 @@ static int nft_jhash_init(const struct nft_ctx *ctx, NFT_DATA_VALUE, sizeof(u32)); } -static int nft_jhash_map_init(const struct nft_ctx *ctx, - const struct nft_expr *expr, - const struct nlattr * const tb[]) -{ - struct nft_jhash *priv = nft_expr_priv(expr); - u8 genmask = nft_genmask_next(ctx->net); - - nft_jhash_init(ctx, expr, tb); - priv->map = nft_set_lookup_global(ctx->net, ctx->table, - tb[NFTA_HASH_SET_NAME], - tb[NFTA_HASH_SET_ID], genmask); - return PTR_ERR_OR_ZERO(priv->map); -} - static int nft_symhash_init(const struct nft_ctx *ctx, const struct nft_expr *expr, const struct nlattr * const tb[]) @@ -206,20 +143,6 @@ static int nft_symhash_init(const struct nft_ctx *ctx, NFT_DATA_VALUE, sizeof(u32)); } -static int nft_symhash_map_init(const struct nft_ctx *ctx, - const struct nft_expr *expr, - const struct nlattr * const tb[]) -{ - struct nft_jhash *priv = nft_expr_priv(expr); - u8 genmask = nft_genmask_next(ctx->net); - - nft_symhash_init(ctx, expr, tb); - priv->map = nft_set_lookup_global(ctx->net, ctx->table, - tb[NFTA_HASH_SET_NAME], - tb[NFTA_HASH_SET_ID], genmask); - return PTR_ERR_OR_ZERO(priv->map); -} - static int nft_jhash_dump(struct sk_buff *skb, const struct nft_expr *expr) { @@ -247,18 +170,6 @@ nla_put_failure: return -1; } -static int nft_jhash_map_dump(struct sk_buff *skb, - const struct nft_expr *expr) -{ - const struct nft_jhash *priv = nft_expr_priv(expr); - - if (nft_jhash_dump(skb, expr) || - nla_put_string(skb, NFTA_HASH_SET_NAME, priv->map->name)) - return -1; - - return 0; -} - static int nft_symhash_dump(struct sk_buff *skb, const struct nft_expr *expr) { @@ -279,18 +190,6 @@ nla_put_failure: return -1; } -static int nft_symhash_map_dump(struct sk_buff *skb, - const struct nft_expr *expr) -{ - const struct nft_symhash *priv = nft_expr_priv(expr); - - if (nft_symhash_dump(skb, expr) || - nla_put_string(skb, NFTA_HASH_SET_NAME, priv->map->name)) - return -1; - - return 0; -} - static struct nft_expr_type nft_hash_type; static const struct nft_expr_ops nft_jhash_ops = { .type = &nft_hash_type, @@ -300,14 +199,6 @@ static const struct nft_expr_ops nft_jhash_ops = { .dump = nft_jhash_dump, }; -static const struct nft_expr_ops nft_jhash_map_ops = { - .type = &nft_hash_type, - .size = NFT_EXPR_SIZE(sizeof(struct nft_jhash)), - .eval = nft_jhash_map_eval, - .init = nft_jhash_map_init, - .dump = nft_jhash_map_dump, -}; - static const struct nft_expr_ops nft_symhash_ops = { .type = &nft_hash_type, .size = NFT_EXPR_SIZE(sizeof(struct nft_symhash)), @@ -316,14 +207,6 @@ static const struct nft_expr_ops nft_symhash_ops = { .dump = nft_symhash_dump, }; -static const struct nft_expr_ops nft_symhash_map_ops = { - .type = &nft_hash_type, - .size = NFT_EXPR_SIZE(sizeof(struct nft_symhash)), - .eval = nft_symhash_map_eval, - .init = nft_symhash_map_init, - .dump = nft_symhash_map_dump, -}; - static const struct nft_expr_ops * nft_hash_select_ops(const struct nft_ctx *ctx, const struct nlattr * const tb[]) @@ -336,12 +219,8 @@ nft_hash_select_ops(const struct nft_ctx *ctx, type = ntohl(nla_get_be32(tb[NFTA_HASH_TYPE])); switch (type) { case NFT_HASH_SYM: - if (tb[NFTA_HASH_SET_NAME]) - return &nft_symhash_map_ops; return &nft_symhash_ops; case NFT_HASH_JENKINS: - if (tb[NFTA_HASH_SET_NAME]) - return &nft_jhash_map_ops; return &nft_jhash_ops; default: break; diff --git a/net/netfilter/nft_immediate.c b/net/netfilter/nft_immediate.c index 0777a93211e2..5ec43124cbca 100644 --- a/net/netfilter/nft_immediate.c +++ b/net/netfilter/nft_immediate.c @@ -17,9 +17,9 @@ #include <net/netfilter/nf_tables_core.h> #include <net/netfilter/nf_tables.h> -static void nft_immediate_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) +void nft_immediate_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) { const struct nft_immediate_expr *priv = nft_expr_priv(expr); @@ -72,10 +72,14 @@ static void nft_immediate_activate(const struct nft_ctx *ctx, } static void nft_immediate_deactivate(const struct nft_ctx *ctx, - const struct nft_expr *expr) + const struct nft_expr *expr, + enum nft_trans_phase phase) { const struct nft_immediate_expr *priv = nft_expr_priv(expr); + if (phase == NFT_TRANS_COMMIT) + return; + return nft_data_release(&priv->data, nft_dreg_to_type(priv->dreg)); } diff --git a/net/netfilter/nft_lookup.c b/net/netfilter/nft_lookup.c index 227b2b15a19c..161c3451a747 100644 --- a/net/netfilter/nft_lookup.c +++ b/net/netfilter/nft_lookup.c @@ -121,20 +121,21 @@ static int nft_lookup_init(const struct nft_ctx *ctx, return 0; } -static void nft_lookup_activate(const struct nft_ctx *ctx, - const struct nft_expr *expr) +static void nft_lookup_deactivate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + enum nft_trans_phase phase) { struct nft_lookup *priv = nft_expr_priv(expr); - nf_tables_rebind_set(ctx, priv->set, &priv->binding); + nf_tables_deactivate_set(ctx, priv->set, &priv->binding, phase); } -static void nft_lookup_deactivate(const struct nft_ctx *ctx, - const struct nft_expr *expr) +static void nft_lookup_activate(const struct nft_ctx *ctx, + const struct nft_expr *expr) { struct nft_lookup *priv = nft_expr_priv(expr); - nf_tables_unbind_set(ctx, priv->set, &priv->binding); + priv->set->use++; } static void nft_lookup_destroy(const struct nft_ctx *ctx, diff --git a/net/netfilter/nft_masq.c b/net/netfilter/nft_masq.c index 9d8655bc1bea..bee156eaa400 100644 --- a/net/netfilter/nft_masq.c +++ b/net/netfilter/nft_masq.c @@ -14,18 +14,24 @@ #include <linux/netfilter/nf_tables.h> #include <net/netfilter/nf_tables.h> #include <net/netfilter/nf_nat.h> -#include <net/netfilter/nft_masq.h> +#include <net/netfilter/ipv4/nf_nat_masquerade.h> +#include <net/netfilter/ipv6/nf_nat_masquerade.h> -const struct nla_policy nft_masq_policy[NFTA_MASQ_MAX + 1] = { +struct nft_masq { + u32 flags; + enum nft_registers sreg_proto_min:8; + enum nft_registers sreg_proto_max:8; +}; + +static const struct nla_policy nft_masq_policy[NFTA_MASQ_MAX + 1] = { [NFTA_MASQ_FLAGS] = { .type = NLA_U32 }, [NFTA_MASQ_REG_PROTO_MIN] = { .type = NLA_U32 }, [NFTA_MASQ_REG_PROTO_MAX] = { .type = NLA_U32 }, }; -EXPORT_SYMBOL_GPL(nft_masq_policy); -int nft_masq_validate(const struct nft_ctx *ctx, - const struct nft_expr *expr, - const struct nft_data **data) +static int nft_masq_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) { int err; @@ -36,11 +42,10 @@ int nft_masq_validate(const struct nft_ctx *ctx, return nft_chain_validate_hooks(ctx->chain, (1 << NF_INET_POST_ROUTING)); } -EXPORT_SYMBOL_GPL(nft_masq_validate); -int nft_masq_init(const struct nft_ctx *ctx, - const struct nft_expr *expr, - const struct nlattr * const tb[]) +static int nft_masq_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) { u32 plen = FIELD_SIZEOF(struct nf_nat_range, min_addr.all); struct nft_masq *priv = nft_expr_priv(expr); @@ -75,9 +80,8 @@ int nft_masq_init(const struct nft_ctx *ctx, return nf_ct_netns_get(ctx->net, ctx->family); } -EXPORT_SYMBOL_GPL(nft_masq_init); -int nft_masq_dump(struct sk_buff *skb, const struct nft_expr *expr) +static int nft_masq_dump(struct sk_buff *skb, const struct nft_expr *expr) { const struct nft_masq *priv = nft_expr_priv(expr); @@ -98,7 +102,157 @@ int nft_masq_dump(struct sk_buff *skb, const struct nft_expr *expr) nla_put_failure: return -1; } -EXPORT_SYMBOL_GPL(nft_masq_dump); + +static void nft_masq_ipv4_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_masq *priv = nft_expr_priv(expr); + struct nf_nat_range2 range; + + memset(&range, 0, sizeof(range)); + range.flags = priv->flags; + if (priv->sreg_proto_min) { + range.min_proto.all = (__force __be16)nft_reg_load16( + ®s->data[priv->sreg_proto_min]); + range.max_proto.all = (__force __be16)nft_reg_load16( + ®s->data[priv->sreg_proto_max]); + } + regs->verdict.code = nf_nat_masquerade_ipv4(pkt->skb, nft_hook(pkt), + &range, nft_out(pkt)); +} + +static void +nft_masq_ipv4_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + nf_ct_netns_put(ctx->net, NFPROTO_IPV4); +} + +static struct nft_expr_type nft_masq_ipv4_type; +static const struct nft_expr_ops nft_masq_ipv4_ops = { + .type = &nft_masq_ipv4_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_masq)), + .eval = nft_masq_ipv4_eval, + .init = nft_masq_init, + .destroy = nft_masq_ipv4_destroy, + .dump = nft_masq_dump, + .validate = nft_masq_validate, +}; + +static struct nft_expr_type nft_masq_ipv4_type __read_mostly = { + .family = NFPROTO_IPV4, + .name = "masq", + .ops = &nft_masq_ipv4_ops, + .policy = nft_masq_policy, + .maxattr = NFTA_MASQ_MAX, + .owner = THIS_MODULE, +}; + +#ifdef CONFIG_NF_TABLES_IPV6 +static void nft_masq_ipv6_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_masq *priv = nft_expr_priv(expr); + struct nf_nat_range2 range; + + memset(&range, 0, sizeof(range)); + range.flags = priv->flags; + if (priv->sreg_proto_min) { + range.min_proto.all = (__force __be16)nft_reg_load16( + ®s->data[priv->sreg_proto_min]); + range.max_proto.all = (__force __be16)nft_reg_load16( + ®s->data[priv->sreg_proto_max]); + } + regs->verdict.code = nf_nat_masquerade_ipv6(pkt->skb, &range, + nft_out(pkt)); +} + +static void +nft_masq_ipv6_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + nf_ct_netns_put(ctx->net, NFPROTO_IPV6); +} + +static struct nft_expr_type nft_masq_ipv6_type; +static const struct nft_expr_ops nft_masq_ipv6_ops = { + .type = &nft_masq_ipv6_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_masq)), + .eval = nft_masq_ipv6_eval, + .init = nft_masq_init, + .destroy = nft_masq_ipv6_destroy, + .dump = nft_masq_dump, + .validate = nft_masq_validate, +}; + +static struct nft_expr_type nft_masq_ipv6_type __read_mostly = { + .family = NFPROTO_IPV6, + .name = "masq", + .ops = &nft_masq_ipv6_ops, + .policy = nft_masq_policy, + .maxattr = NFTA_MASQ_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_masq_module_init_ipv6(void) +{ + int ret = nft_register_expr(&nft_masq_ipv6_type); + + if (ret) + return ret; + + ret = nf_nat_masquerade_ipv6_register_notifier(); + if (ret < 0) + nft_unregister_expr(&nft_masq_ipv6_type); + + return ret; +} + +static void nft_masq_module_exit_ipv6(void) +{ + nft_unregister_expr(&nft_masq_ipv6_type); + nf_nat_masquerade_ipv6_unregister_notifier(); +} +#else +static inline int nft_masq_module_init_ipv6(void) { return 0; } +static inline void nft_masq_module_exit_ipv6(void) {} +#endif + +static int __init nft_masq_module_init(void) +{ + int ret; + + ret = nft_masq_module_init_ipv6(); + if (ret < 0) + return ret; + + ret = nft_register_expr(&nft_masq_ipv4_type); + if (ret < 0) { + nft_masq_module_exit_ipv6(); + return ret; + } + + ret = nf_nat_masquerade_ipv4_register_notifier(); + if (ret < 0) { + nft_masq_module_exit_ipv6(); + nft_unregister_expr(&nft_masq_ipv4_type); + return ret; + } + + return ret; +} + +static void __exit nft_masq_module_exit(void) +{ + nft_masq_module_exit_ipv6(); + nft_unregister_expr(&nft_masq_ipv4_type); + nf_nat_masquerade_ipv4_unregister_notifier(); +} + +module_init(nft_masq_module_init); +module_exit(nft_masq_module_exit); MODULE_LICENSE("GPL"); MODULE_AUTHOR("Arturo Borrero Gonzalez <arturo@debian.org>"); +MODULE_ALIAS_NFT_AF_EXPR(AF_INET6, "masq"); +MODULE_ALIAS_NFT_AF_EXPR(AF_INET, "masq"); diff --git a/net/netfilter/nft_meta.c b/net/netfilter/nft_meta.c index 6df486c5ebd3..987d2d6ce624 100644 --- a/net/netfilter/nft_meta.c +++ b/net/netfilter/nft_meta.c @@ -244,6 +244,16 @@ void nft_meta_get_eval(const struct nft_expr *expr, strncpy((char *)dest, p->br->dev->name, IFNAMSIZ); return; #endif + case NFT_META_IIFKIND: + if (in == NULL || in->rtnl_link_ops == NULL) + goto err; + strncpy((char *)dest, in->rtnl_link_ops->kind, IFNAMSIZ); + break; + case NFT_META_OIFKIND: + if (out == NULL || out->rtnl_link_ops == NULL) + goto err; + strncpy((char *)dest, out->rtnl_link_ops->kind, IFNAMSIZ); + break; default: WARN_ON(1); goto err; @@ -340,6 +350,8 @@ static int nft_meta_get_init(const struct nft_ctx *ctx, break; case NFT_META_IIFNAME: case NFT_META_OIFNAME: + case NFT_META_IIFKIND: + case NFT_META_OIFKIND: len = IFNAMSIZ; break; case NFT_META_PRANDOM: diff --git a/net/netfilter/nft_nat.c b/net/netfilter/nft_nat.c index c15807d10b91..e93aed9bda88 100644 --- a/net/netfilter/nft_nat.c +++ b/net/netfilter/nft_nat.c @@ -21,9 +21,7 @@ #include <linux/netfilter/nf_tables.h> #include <net/netfilter/nf_conntrack.h> #include <net/netfilter/nf_nat.h> -#include <net/netfilter/nf_nat_core.h> #include <net/netfilter/nf_tables.h> -#include <net/netfilter/nf_nat_l3proto.h> #include <net/ip.h> struct nft_nat { diff --git a/net/netfilter/nft_objref.c b/net/netfilter/nft_objref.c index a3185ca2a3a9..457a9ceb46af 100644 --- a/net/netfilter/nft_objref.c +++ b/net/netfilter/nft_objref.c @@ -38,7 +38,8 @@ static int nft_objref_init(const struct nft_ctx *ctx, return -EINVAL; objtype = ntohl(nla_get_be32(tb[NFTA_OBJREF_IMM_TYPE])); - obj = nft_obj_lookup(ctx->table, tb[NFTA_OBJREF_IMM_NAME], objtype, + obj = nft_obj_lookup(ctx->net, ctx->table, + tb[NFTA_OBJREF_IMM_NAME], objtype, genmask); if (IS_ERR(obj)) return -ENOENT; @@ -53,7 +54,7 @@ static int nft_objref_dump(struct sk_buff *skb, const struct nft_expr *expr) { const struct nft_object *obj = nft_objref_priv(expr); - if (nla_put_string(skb, NFTA_OBJREF_IMM_NAME, obj->name) || + if (nla_put_string(skb, NFTA_OBJREF_IMM_NAME, obj->key.name) || nla_put_be32(skb, NFTA_OBJREF_IMM_TYPE, htonl(obj->ops->type->type))) goto nla_put_failure; @@ -155,20 +156,21 @@ nla_put_failure: return -1; } -static void nft_objref_map_activate(const struct nft_ctx *ctx, - const struct nft_expr *expr) +static void nft_objref_map_deactivate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + enum nft_trans_phase phase) { struct nft_objref_map *priv = nft_expr_priv(expr); - nf_tables_rebind_set(ctx, priv->set, &priv->binding); + nf_tables_deactivate_set(ctx, priv->set, &priv->binding, phase); } -static void nft_objref_map_deactivate(const struct nft_ctx *ctx, - const struct nft_expr *expr) +static void nft_objref_map_activate(const struct nft_ctx *ctx, + const struct nft_expr *expr) { struct nft_objref_map *priv = nft_expr_priv(expr); - nf_tables_unbind_set(ctx, priv->set, &priv->binding); + priv->set->use++; } static void nft_objref_map_destroy(const struct nft_ctx *ctx, diff --git a/net/netfilter/nft_payload.c b/net/netfilter/nft_payload.c index e110b0ebbf58..54e15de4b79a 100644 --- a/net/netfilter/nft_payload.c +++ b/net/netfilter/nft_payload.c @@ -70,9 +70,9 @@ nft_payload_copy_vlan(u32 *d, const struct sk_buff *skb, u8 offset, u8 len) return skb_copy_bits(skb, offset + mac_off, dst_u8, len) == 0; } -static void nft_payload_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) +void nft_payload_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) { const struct nft_payload *priv = nft_expr_priv(expr); const struct sk_buff *skb = pkt->skb; diff --git a/net/netfilter/nft_quota.c b/net/netfilter/nft_quota.c index 0ed124a93fcf..354cde67bca9 100644 --- a/net/netfilter/nft_quota.c +++ b/net/netfilter/nft_quota.c @@ -61,7 +61,7 @@ static void nft_quota_obj_eval(struct nft_object *obj, if (overquota && !test_and_set_bit(NFT_QUOTA_DEPLETED_BIT, &priv->flags)) - nft_obj_notify(nft_net(pkt), obj->table, obj, 0, 0, + nft_obj_notify(nft_net(pkt), obj->key.table, obj, 0, 0, NFT_MSG_NEWOBJ, nft_pf(pkt), 0, GFP_ATOMIC); } diff --git a/net/netfilter/nft_range.c b/net/netfilter/nft_range.c index cedb96c3619f..529ac8acb19d 100644 --- a/net/netfilter/nft_range.c +++ b/net/netfilter/nft_range.c @@ -23,9 +23,8 @@ struct nft_range_expr { enum nft_range_ops op:8; }; -static void nft_range_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) +void nft_range_eval(const struct nft_expr *expr, + struct nft_regs *regs, const struct nft_pktinfo *pkt) { const struct nft_range_expr *priv = nft_expr_priv(expr); int d1, d2; diff --git a/net/netfilter/nft_redir.c b/net/netfilter/nft_redir.c index c64cbe78dee7..f8092926f704 100644 --- a/net/netfilter/nft_redir.c +++ b/net/netfilter/nft_redir.c @@ -13,19 +13,24 @@ #include <linux/netfilter.h> #include <linux/netfilter/nf_tables.h> #include <net/netfilter/nf_nat.h> +#include <net/netfilter/nf_nat_redirect.h> #include <net/netfilter/nf_tables.h> -#include <net/netfilter/nft_redir.h> -const struct nla_policy nft_redir_policy[NFTA_REDIR_MAX + 1] = { +struct nft_redir { + enum nft_registers sreg_proto_min:8; + enum nft_registers sreg_proto_max:8; + u16 flags; +}; + +static const struct nla_policy nft_redir_policy[NFTA_REDIR_MAX + 1] = { [NFTA_REDIR_REG_PROTO_MIN] = { .type = NLA_U32 }, [NFTA_REDIR_REG_PROTO_MAX] = { .type = NLA_U32 }, [NFTA_REDIR_FLAGS] = { .type = NLA_U32 }, }; -EXPORT_SYMBOL_GPL(nft_redir_policy); -int nft_redir_validate(const struct nft_ctx *ctx, - const struct nft_expr *expr, - const struct nft_data **data) +static int nft_redir_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) { int err; @@ -37,11 +42,10 @@ int nft_redir_validate(const struct nft_ctx *ctx, (1 << NF_INET_PRE_ROUTING) | (1 << NF_INET_LOCAL_OUT)); } -EXPORT_SYMBOL_GPL(nft_redir_validate); -int nft_redir_init(const struct nft_ctx *ctx, - const struct nft_expr *expr, - const struct nlattr * const tb[]) +static int nft_redir_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) { struct nft_redir *priv = nft_expr_priv(expr); unsigned int plen; @@ -77,7 +81,6 @@ int nft_redir_init(const struct nft_ctx *ctx, return nf_ct_netns_get(ctx->net, ctx->family); } -EXPORT_SYMBOL_GPL(nft_redir_init); int nft_redir_dump(struct sk_buff *skb, const struct nft_expr *expr) { @@ -101,7 +104,134 @@ int nft_redir_dump(struct sk_buff *skb, const struct nft_expr *expr) nla_put_failure: return -1; } -EXPORT_SYMBOL_GPL(nft_redir_dump); + +static void nft_redir_ipv4_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_redir *priv = nft_expr_priv(expr); + struct nf_nat_ipv4_multi_range_compat mr; + + memset(&mr, 0, sizeof(mr)); + if (priv->sreg_proto_min) { + mr.range[0].min.all = (__force __be16)nft_reg_load16( + ®s->data[priv->sreg_proto_min]); + mr.range[0].max.all = (__force __be16)nft_reg_load16( + ®s->data[priv->sreg_proto_max]); + mr.range[0].flags |= NF_NAT_RANGE_PROTO_SPECIFIED; + } + + mr.range[0].flags |= priv->flags; + + regs->verdict.code = nf_nat_redirect_ipv4(pkt->skb, &mr, nft_hook(pkt)); +} + +static void +nft_redir_ipv4_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + nf_ct_netns_put(ctx->net, NFPROTO_IPV4); +} + +static struct nft_expr_type nft_redir_ipv4_type; +static const struct nft_expr_ops nft_redir_ipv4_ops = { + .type = &nft_redir_ipv4_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_redir)), + .eval = nft_redir_ipv4_eval, + .init = nft_redir_init, + .destroy = nft_redir_ipv4_destroy, + .dump = nft_redir_dump, + .validate = nft_redir_validate, +}; + +static struct nft_expr_type nft_redir_ipv4_type __read_mostly = { + .family = NFPROTO_IPV4, + .name = "redir", + .ops = &nft_redir_ipv4_ops, + .policy = nft_redir_policy, + .maxattr = NFTA_REDIR_MAX, + .owner = THIS_MODULE, +}; + +#ifdef CONFIG_NF_TABLES_IPV6 +static void nft_redir_ipv6_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_redir *priv = nft_expr_priv(expr); + struct nf_nat_range2 range; + + memset(&range, 0, sizeof(range)); + if (priv->sreg_proto_min) { + range.min_proto.all = (__force __be16)nft_reg_load16( + ®s->data[priv->sreg_proto_min]); + range.max_proto.all = (__force __be16)nft_reg_load16( + ®s->data[priv->sreg_proto_max]); + range.flags |= NF_NAT_RANGE_PROTO_SPECIFIED; + } + + range.flags |= priv->flags; + + regs->verdict.code = + nf_nat_redirect_ipv6(pkt->skb, &range, nft_hook(pkt)); +} + +static void +nft_redir_ipv6_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + nf_ct_netns_put(ctx->net, NFPROTO_IPV6); +} + +static struct nft_expr_type nft_redir_ipv6_type; +static const struct nft_expr_ops nft_redir_ipv6_ops = { + .type = &nft_redir_ipv6_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_redir)), + .eval = nft_redir_ipv6_eval, + .init = nft_redir_init, + .destroy = nft_redir_ipv6_destroy, + .dump = nft_redir_dump, + .validate = nft_redir_validate, +}; + +static struct nft_expr_type nft_redir_ipv6_type __read_mostly = { + .family = NFPROTO_IPV6, + .name = "redir", + .ops = &nft_redir_ipv6_ops, + .policy = nft_redir_policy, + .maxattr = NFTA_REDIR_MAX, + .owner = THIS_MODULE, +}; +#endif + +static int __init nft_redir_module_init(void) +{ + int ret = nft_register_expr(&nft_redir_ipv4_type); + + if (ret) + return ret; + +#ifdef CONFIG_NF_TABLES_IPV6 + ret = nft_register_expr(&nft_redir_ipv6_type); + if (ret) { + nft_unregister_expr(&nft_redir_ipv4_type); + return ret; + } +#endif + + return ret; +} + +static void __exit nft_redir_module_exit(void) +{ + nft_unregister_expr(&nft_redir_ipv4_type); +#ifdef CONFIG_NF_TABLES_IPV6 + nft_unregister_expr(&nft_redir_ipv6_type); +#endif +} + +module_init(nft_redir_module_init); +module_exit(nft_redir_module_exit); MODULE_LICENSE("GPL"); MODULE_AUTHOR("Arturo Borrero Gonzalez <arturo@debian.org>"); +MODULE_ALIAS_NFT_AF_EXPR(AF_INET4, "redir"); +MODULE_ALIAS_NFT_AF_EXPR(AF_INET6, "redir"); diff --git a/net/netfilter/nft_rt.c b/net/netfilter/nft_rt.c index f35fa33913ae..c48daed5c46b 100644 --- a/net/netfilter/nft_rt.c +++ b/net/netfilter/nft_rt.c @@ -53,9 +53,9 @@ static u16 get_tcpmss(const struct nft_pktinfo *pkt, const struct dst_entry *skb return mtu - minlen; } -static void nft_rt_get_eval(const struct nft_expr *expr, - struct nft_regs *regs, - const struct nft_pktinfo *pkt) +void nft_rt_get_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) { const struct nft_rt *priv = nft_expr_priv(expr); const struct sk_buff *skb = pkt->skb; diff --git a/net/netfilter/nft_set_hash.c b/net/netfilter/nft_set_hash.c index 339a9dd1c832..03df08801e28 100644 --- a/net/netfilter/nft_set_hash.c +++ b/net/netfilter/nft_set_hash.c @@ -442,15 +442,6 @@ static void *nft_hash_get(const struct net *net, const struct nft_set *set, return ERR_PTR(-ENOENT); } -/* nft_hash_select_ops() makes sure key size can be either 2 or 4 bytes . */ -static inline u32 nft_hash_key(const u32 *key, u32 klen) -{ - if (klen == 4) - return *key; - - return *(u16 *)key; -} - static bool nft_hash_lookup_fast(const struct net *net, const struct nft_set *set, const u32 *key, const struct nft_set_ext **ext) @@ -460,11 +451,11 @@ static bool nft_hash_lookup_fast(const struct net *net, const struct nft_hash_elem *he; u32 hash, k1, k2; - k1 = nft_hash_key(key, set->klen); + k1 = *key; hash = jhash_1word(k1, priv->seed); hash = reciprocal_scale(hash, priv->buckets); hlist_for_each_entry_rcu(he, &priv->table[hash], node) { - k2 = nft_hash_key(nft_set_ext_key(&he->ext)->data, set->klen); + k2 = *(u32 *)nft_set_ext_key(&he->ext)->data; if (k1 == k2 && nft_set_elem_active(&he->ext, genmask)) { *ext = &he->ext; @@ -474,6 +465,23 @@ static bool nft_hash_lookup_fast(const struct net *net, return false; } +static u32 nft_jhash(const struct nft_set *set, const struct nft_hash *priv, + const struct nft_set_ext *ext) +{ + const struct nft_data *key = nft_set_ext_key(ext); + u32 hash, k1; + + if (set->klen == 4) { + k1 = *(u32 *)key; + hash = jhash_1word(k1, priv->seed); + } else { + hash = jhash(key, set->klen, priv->seed); + } + hash = reciprocal_scale(hash, priv->buckets); + + return hash; +} + static int nft_hash_insert(const struct net *net, const struct nft_set *set, const struct nft_set_elem *elem, struct nft_set_ext **ext) @@ -483,8 +491,7 @@ static int nft_hash_insert(const struct net *net, const struct nft_set *set, u8 genmask = nft_genmask_next(net); u32 hash; - hash = jhash(nft_set_ext_key(&this->ext), set->klen, priv->seed); - hash = reciprocal_scale(hash, priv->buckets); + hash = nft_jhash(set, priv, &this->ext); hlist_for_each_entry(he, &priv->table[hash], node) { if (!memcmp(nft_set_ext_key(&this->ext), nft_set_ext_key(&he->ext), set->klen) && @@ -523,10 +530,9 @@ static void *nft_hash_deactivate(const struct net *net, u8 genmask = nft_genmask_next(net); u32 hash; - hash = jhash(nft_set_ext_key(&this->ext), set->klen, priv->seed); - hash = reciprocal_scale(hash, priv->buckets); + hash = nft_jhash(set, priv, &this->ext); hlist_for_each_entry(he, &priv->table[hash], node) { - if (!memcmp(nft_set_ext_key(&this->ext), &elem->key.val, + if (!memcmp(nft_set_ext_key(&he->ext), &elem->key.val, set->klen) && nft_set_elem_active(&he->ext, genmask)) { nft_set_elem_change_active(net, set, &he->ext); diff --git a/net/netfilter/nft_tunnel.c b/net/netfilter/nft_tunnel.c index 3a15f219e4e7..b113fcac94e1 100644 --- a/net/netfilter/nft_tunnel.c +++ b/net/netfilter/nft_tunnel.c @@ -15,6 +15,7 @@ struct nft_tunnel { enum nft_tunnel_keys key:8; enum nft_registers dreg:8; + enum nft_tunnel_mode mode:8; }; static void nft_tunnel_get_eval(const struct nft_expr *expr, @@ -29,14 +30,32 @@ static void nft_tunnel_get_eval(const struct nft_expr *expr, switch (priv->key) { case NFT_TUNNEL_PATH: - nft_reg_store8(dest, !!tun_info); + if (!tun_info) { + nft_reg_store8(dest, false); + return; + } + if (priv->mode == NFT_TUNNEL_MODE_NONE || + (priv->mode == NFT_TUNNEL_MODE_RX && + !(tun_info->mode & IP_TUNNEL_INFO_TX)) || + (priv->mode == NFT_TUNNEL_MODE_TX && + (tun_info->mode & IP_TUNNEL_INFO_TX))) + nft_reg_store8(dest, true); + else + nft_reg_store8(dest, false); break; case NFT_TUNNEL_ID: if (!tun_info) { regs->verdict.code = NFT_BREAK; return; } - *dest = ntohl(tunnel_id_to_key32(tun_info->key.tun_id)); + if (priv->mode == NFT_TUNNEL_MODE_NONE || + (priv->mode == NFT_TUNNEL_MODE_RX && + !(tun_info->mode & IP_TUNNEL_INFO_TX)) || + (priv->mode == NFT_TUNNEL_MODE_TX && + (tun_info->mode & IP_TUNNEL_INFO_TX))) + *dest = ntohl(tunnel_id_to_key32(tun_info->key.tun_id)); + else + regs->verdict.code = NFT_BREAK; break; default: WARN_ON(1); @@ -47,6 +66,7 @@ static void nft_tunnel_get_eval(const struct nft_expr *expr, static const struct nla_policy nft_tunnel_policy[NFTA_TUNNEL_MAX + 1] = { [NFTA_TUNNEL_KEY] = { .type = NLA_U32 }, [NFTA_TUNNEL_DREG] = { .type = NLA_U32 }, + [NFTA_TUNNEL_MODE] = { .type = NLA_U32 }, }; static int nft_tunnel_get_init(const struct nft_ctx *ctx, @@ -74,6 +94,14 @@ static int nft_tunnel_get_init(const struct nft_ctx *ctx, priv->dreg = nft_parse_register(tb[NFTA_TUNNEL_DREG]); + if (tb[NFTA_TUNNEL_MODE]) { + priv->mode = ntohl(nla_get_be32(tb[NFTA_TUNNEL_MODE])); + if (priv->mode > NFT_TUNNEL_MODE_MAX) + return -EOPNOTSUPP; + } else { + priv->mode = NFT_TUNNEL_MODE_NONE; + } + return nft_validate_register_store(ctx, priv->dreg, NULL, NFT_DATA_VALUE, len); } @@ -87,6 +115,8 @@ static int nft_tunnel_get_dump(struct sk_buff *skb, goto nla_put_failure; if (nft_dump_register(skb, NFTA_TUNNEL_DREG, priv->dreg)) goto nla_put_failure; + if (nla_put_be32(skb, NFTA_TUNNEL_MODE, htonl(priv->mode))) + goto nla_put_failure; return 0; nla_put_failure: @@ -376,6 +406,13 @@ static int nft_tunnel_obj_init(const struct nft_ctx *ctx, return -ENOMEM; memcpy(&md->u.tun_info, &info, sizeof(info)); +#ifdef CONFIG_DST_CACHE + err = dst_cache_init(&md->u.tun_info.dst_cache, GFP_KERNEL); + if (err < 0) { + metadata_dst_free(md); + return err; + } +#endif ip_tunnel_info_opts_set(&md->u.tun_info, &priv->opts.u, priv->opts.len, priv->opts.flags); priv->md = md; diff --git a/net/netfilter/utils.c b/net/netfilter/utils.c index e8da9a9bba73..06dc55590441 100644 --- a/net/netfilter/utils.c +++ b/net/netfilter/utils.c @@ -162,7 +162,7 @@ EXPORT_SYMBOL_GPL(nf_checksum_partial); int nf_route(struct net *net, struct dst_entry **dst, struct flowi *fl, bool strict, unsigned short family) { - const struct nf_ipv6_ops *v6ops; + const struct nf_ipv6_ops *v6ops __maybe_unused; int ret = 0; switch (family) { @@ -170,9 +170,7 @@ int nf_route(struct net *net, struct dst_entry **dst, struct flowi *fl, ret = nf_ip_route(net, dst, fl, strict); break; case AF_INET6: - v6ops = rcu_dereference(nf_ipv6_ops); - if (v6ops) - ret = v6ops->route(net, dst, fl, strict); + ret = nf_ip6_route(net, dst, fl, strict); break; } @@ -180,6 +178,25 @@ int nf_route(struct net *net, struct dst_entry **dst, struct flowi *fl, } EXPORT_SYMBOL_GPL(nf_route); +static int nf_ip_reroute(struct sk_buff *skb, const struct nf_queue_entry *entry) +{ +#ifdef CONFIG_INET + const struct ip_rt_info *rt_info = nf_queue_entry_reroute(entry); + + if (entry->state.hook == NF_INET_LOCAL_OUT) { + const struct iphdr *iph = ip_hdr(skb); + + if (!(iph->tos == rt_info->tos && + skb->mark == rt_info->mark && + iph->daddr == rt_info->daddr && + iph->saddr == rt_info->saddr)) + return ip_route_me_harder(entry->state.net, skb, + RTN_UNSPEC); + } +#endif + return 0; +} + int nf_reroute(struct sk_buff *skb, struct nf_queue_entry *entry) { const struct nf_ipv6_ops *v6ops; diff --git a/net/netfilter/x_tables.c b/net/netfilter/x_tables.c index aecadd471e1d..e5e5c64df8d1 100644 --- a/net/netfilter/x_tables.c +++ b/net/netfilter/x_tables.c @@ -461,7 +461,7 @@ int xt_check_proc_name(const char *name, unsigned int size) EXPORT_SYMBOL(xt_check_proc_name); int xt_check_match(struct xt_mtchk_param *par, - unsigned int size, u_int8_t proto, bool inv_proto) + unsigned int size, u16 proto, bool inv_proto) { int ret; @@ -984,7 +984,7 @@ bool xt_find_jump_offset(const unsigned int *offsets, EXPORT_SYMBOL(xt_find_jump_offset); int xt_check_target(struct xt_tgchk_param *par, - unsigned int size, u_int8_t proto, bool inv_proto) + unsigned int size, u16 proto, bool inv_proto) { int ret; @@ -1899,7 +1899,7 @@ static int __init xt_init(void) seqcount_init(&per_cpu(xt_recseq, i)); } - xt = kmalloc_array(NFPROTO_NUMPROTO, sizeof(struct xt_af), GFP_KERNEL); + xt = kcalloc(NFPROTO_NUMPROTO, sizeof(struct xt_af), GFP_KERNEL); if (!xt) return -ENOMEM; diff --git a/net/netfilter/xt_CT.c b/net/netfilter/xt_CT.c index 2c7a4b80206f..0fa863f57575 100644 --- a/net/netfilter/xt_CT.c +++ b/net/netfilter/xt_CT.c @@ -159,7 +159,7 @@ xt_ct_set_timeout(struct nf_conn *ct, const struct xt_tgchk_param *par, /* Make sure the timeout policy matches any existing protocol tracker, * otherwise default to generic. */ - l4proto = __nf_ct_l4proto_find(proto); + l4proto = nf_ct_l4proto_find(proto); if (timeout->l4proto->l4proto != l4proto->l4proto) { ret = -EINVAL; pr_info_ratelimited("Timeout policy `%s' can only be used by L%d protocol number %d\n", diff --git a/net/netfilter/xt_IDLETIMER.c b/net/netfilter/xt_IDLETIMER.c index eb4cbd244c3d..5f9b37e12801 100644 --- a/net/netfilter/xt_IDLETIMER.c +++ b/net/netfilter/xt_IDLETIMER.c @@ -41,19 +41,13 @@ #include <linux/workqueue.h> #include <linux/sysfs.h> -struct idletimer_tg_attr { - struct attribute attr; - ssize_t (*show)(struct kobject *kobj, - struct attribute *attr, char *buf); -}; - struct idletimer_tg { struct list_head entry; struct timer_list timer; struct work_struct work; struct kobject *kobj; - struct idletimer_tg_attr attr; + struct device_attribute attr; unsigned int refcnt; }; @@ -76,15 +70,15 @@ struct idletimer_tg *__idletimer_tg_find_by_label(const char *label) return NULL; } -static ssize_t idletimer_tg_show(struct kobject *kobj, struct attribute *attr, - char *buf) +static ssize_t idletimer_tg_show(struct device *dev, + struct device_attribute *attr, char *buf) { struct idletimer_tg *timer; unsigned long expires = 0; mutex_lock(&list_mutex); - timer = __idletimer_tg_find_by_label(attr->name); + timer = __idletimer_tg_find_by_label(attr->attr.name); if (timer) expires = timer->timer.expires; diff --git a/net/netfilter/xt_addrtype.c b/net/netfilter/xt_addrtype.c index 89e281b3bfc2..29987ff03621 100644 --- a/net/netfilter/xt_addrtype.c +++ b/net/netfilter/xt_addrtype.c @@ -36,7 +36,6 @@ MODULE_ALIAS("ip6t_addrtype"); static u32 match_lookup_rt6(struct net *net, const struct net_device *dev, const struct in6_addr *addr, u16 mask) { - const struct nf_ipv6_ops *v6ops; struct flowi6 flow; struct rt6_info *rt; u32 ret = 0; @@ -47,18 +46,13 @@ static u32 match_lookup_rt6(struct net *net, const struct net_device *dev, if (dev) flow.flowi6_oif = dev->ifindex; - v6ops = nf_get_ipv6_ops(); - if (v6ops) { - if (dev && (mask & XT_ADDRTYPE_LOCAL)) { - if (v6ops->chk_addr(net, addr, dev, true)) - ret = XT_ADDRTYPE_LOCAL; - } - route_err = v6ops->route(net, (struct dst_entry **)&rt, - flowi6_to_flowi(&flow), false); - } else { - route_err = 1; + if (dev && (mask & XT_ADDRTYPE_LOCAL)) { + if (nf_ipv6_chk_addr(net, addr, dev, true)) + ret = XT_ADDRTYPE_LOCAL; } + route_err = nf_ip6_route(net, (struct dst_entry **)&rt, + flowi6_to_flowi(&flow), false); if (route_err) return XT_ADDRTYPE_UNREACHABLE; diff --git a/net/netfilter/xt_nat.c b/net/netfilter/xt_nat.c index ac91170fc8c8..61eabd171186 100644 --- a/net/netfilter/xt_nat.c +++ b/net/netfilter/xt_nat.c @@ -14,7 +14,7 @@ #include <linux/skbuff.h> #include <linux/netfilter.h> #include <linux/netfilter/x_tables.h> -#include <net/netfilter/nf_nat_core.h> +#include <net/netfilter/nf_nat.h> static int xt_nat_checkentry_v0(const struct xt_tgchk_param *par) { diff --git a/net/netfilter/xt_physdev.c b/net/netfilter/xt_physdev.c index 4034d70bff39..b2e39cb6a590 100644 --- a/net/netfilter/xt_physdev.c +++ b/net/netfilter/xt_physdev.c @@ -96,8 +96,7 @@ match_outdev: static int physdev_mt_check(const struct xt_mtchk_param *par) { const struct xt_physdev_info *info = par->matchinfo; - - br_netfilter_enable(); + static bool brnf_probed __read_mostly; if (!(info->bitmask & XT_PHYSDEV_OP_MASK) || info->bitmask & ~XT_PHYSDEV_OP_MASK) @@ -111,6 +110,12 @@ static int physdev_mt_check(const struct xt_mtchk_param *par) if (par->hook_mask & (1 << NF_INET_LOCAL_OUT)) return -EINVAL; } + + if (!brnf_probed) { + brnf_probed = true; + request_module("br_netfilter"); + } + return 0; } diff --git a/net/netfilter/xt_recent.c b/net/netfilter/xt_recent.c index f44de4bc2100..1664d2ec8b2f 100644 --- a/net/netfilter/xt_recent.c +++ b/net/netfilter/xt_recent.c @@ -337,7 +337,6 @@ static int recent_mt_check(const struct xt_mtchk_param *par, unsigned int nstamp_mask; unsigned int i; int ret = -EINVAL; - size_t sz; net_get_random_once(&hash_rnd, sizeof(hash_rnd)); @@ -387,8 +386,7 @@ static int recent_mt_check(const struct xt_mtchk_param *par, goto out; } - sz = sizeof(*t) + sizeof(t->iphash[0]) * ip_list_hash_size; - t = kvzalloc(sz, GFP_KERNEL); + t = kvzalloc(struct_size(t, iphash, ip_list_hash_size), GFP_KERNEL); if (t == NULL) { ret = -ENOMEM; goto out; diff --git a/net/netlabel/netlabel_kapi.c b/net/netlabel/netlabel_kapi.c index ea7c67050792..ee3e5b6471a6 100644 --- a/net/netlabel/netlabel_kapi.c +++ b/net/netlabel/netlabel_kapi.c @@ -903,7 +903,8 @@ int netlbl_bitmap_walk(const unsigned char *bitmap, u32 bitmap_len, (state == 0 && (byte & bitmask) == 0)) return bit_spot; - bit_spot++; + if (++bit_spot >= bitmap_len) + return -1; bitmask >>= 1; if (bitmask == 0) { byte = bitmap[++byte_offset]; diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 3c023d6120f6..f28e937320a3 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -1371,6 +1371,14 @@ int netlink_has_listeners(struct sock *sk, unsigned int group) } EXPORT_SYMBOL_GPL(netlink_has_listeners); +bool netlink_strict_get_check(struct sk_buff *skb) +{ + const struct netlink_sock *nlk = nlk_sk(NETLINK_CB(skb).sk); + + return nlk->flags & NETLINK_F_STRICT_CHK; +} +EXPORT_SYMBOL_GPL(netlink_strict_get_check); + static int netlink_broadcast_deliver(struct sock *sk, struct sk_buff *skb) { struct netlink_sock *nlk = nlk_sk(sk); @@ -2541,15 +2549,7 @@ struct nl_seq_iter { static int netlink_walk_start(struct nl_seq_iter *iter) { - int err; - - err = rhashtable_walk_init(&nl_table[iter->link].hash, &iter->hti, - GFP_KERNEL); - if (err) { - iter->link = MAX_LINKS; - return err; - } - + rhashtable_walk_enter(&nl_table[iter->link].hash, &iter->hti); rhashtable_walk_start(&iter->hti); return 0; diff --git a/net/netrom/nr_timer.c b/net/netrom/nr_timer.c index cbd51ed5a2d7..908e53ab47a4 100644 --- a/net/netrom/nr_timer.c +++ b/net/netrom/nr_timer.c @@ -52,21 +52,21 @@ void nr_start_t1timer(struct sock *sk) { struct nr_sock *nr = nr_sk(sk); - mod_timer(&nr->t1timer, jiffies + nr->t1); + sk_reset_timer(sk, &nr->t1timer, jiffies + nr->t1); } void nr_start_t2timer(struct sock *sk) { struct nr_sock *nr = nr_sk(sk); - mod_timer(&nr->t2timer, jiffies + nr->t2); + sk_reset_timer(sk, &nr->t2timer, jiffies + nr->t2); } void nr_start_t4timer(struct sock *sk) { struct nr_sock *nr = nr_sk(sk); - mod_timer(&nr->t4timer, jiffies + nr->t4); + sk_reset_timer(sk, &nr->t4timer, jiffies + nr->t4); } void nr_start_idletimer(struct sock *sk) @@ -74,37 +74,37 @@ void nr_start_idletimer(struct sock *sk) struct nr_sock *nr = nr_sk(sk); if (nr->idle > 0) - mod_timer(&nr->idletimer, jiffies + nr->idle); + sk_reset_timer(sk, &nr->idletimer, jiffies + nr->idle); } void nr_start_heartbeat(struct sock *sk) { - mod_timer(&sk->sk_timer, jiffies + 5 * HZ); + sk_reset_timer(sk, &sk->sk_timer, jiffies + 5 * HZ); } void nr_stop_t1timer(struct sock *sk) { - del_timer(&nr_sk(sk)->t1timer); + sk_stop_timer(sk, &nr_sk(sk)->t1timer); } void nr_stop_t2timer(struct sock *sk) { - del_timer(&nr_sk(sk)->t2timer); + sk_stop_timer(sk, &nr_sk(sk)->t2timer); } void nr_stop_t4timer(struct sock *sk) { - del_timer(&nr_sk(sk)->t4timer); + sk_stop_timer(sk, &nr_sk(sk)->t4timer); } void nr_stop_idletimer(struct sock *sk) { - del_timer(&nr_sk(sk)->idletimer); + sk_stop_timer(sk, &nr_sk(sk)->idletimer); } void nr_stop_heartbeat(struct sock *sk) { - del_timer(&sk->sk_timer); + sk_stop_timer(sk, &sk->sk_timer); } int nr_t1timer_running(struct sock *sk) diff --git a/net/nfc/llcp_commands.c b/net/nfc/llcp_commands.c index 6a196e438b6c..d1fc019e932e 100644 --- a/net/nfc/llcp_commands.c +++ b/net/nfc/llcp_commands.c @@ -419,6 +419,10 @@ int nfc_llcp_send_connect(struct nfc_llcp_sock *sock) sock->service_name, sock->service_name_len, &service_name_tlv_length); + if (!service_name_tlv) { + err = -ENOMEM; + goto error_tlv; + } size += service_name_tlv_length; } @@ -429,9 +433,17 @@ int nfc_llcp_send_connect(struct nfc_llcp_sock *sock) miux_tlv = nfc_llcp_build_tlv(LLCP_TLV_MIUX, (u8 *)&miux, 0, &miux_tlv_length); + if (!miux_tlv) { + err = -ENOMEM; + goto error_tlv; + } size += miux_tlv_length; rw_tlv = nfc_llcp_build_tlv(LLCP_TLV_RW, &rw, 0, &rw_tlv_length); + if (!rw_tlv) { + err = -ENOMEM; + goto error_tlv; + } size += rw_tlv_length; pr_debug("SKB size %d SN length %zu\n", size, sock->service_name_len); @@ -484,9 +496,17 @@ int nfc_llcp_send_cc(struct nfc_llcp_sock *sock) miux_tlv = nfc_llcp_build_tlv(LLCP_TLV_MIUX, (u8 *)&miux, 0, &miux_tlv_length); + if (!miux_tlv) { + err = -ENOMEM; + goto error_tlv; + } size += miux_tlv_length; rw_tlv = nfc_llcp_build_tlv(LLCP_TLV_RW, &rw, 0, &rw_tlv_length); + if (!rw_tlv) { + err = -ENOMEM; + goto error_tlv; + } size += rw_tlv_length; skb = llcp_allocate_pdu(sock, LLCP_PDU_CC, size); diff --git a/net/nfc/llcp_core.c b/net/nfc/llcp_core.c index ef4026a23e80..4fa015208aab 100644 --- a/net/nfc/llcp_core.c +++ b/net/nfc/llcp_core.c @@ -532,10 +532,10 @@ static u8 nfc_llcp_reserve_sdp_ssap(struct nfc_llcp_local *local) static int nfc_llcp_build_gb(struct nfc_llcp_local *local) { - u8 *gb_cur, *version_tlv, version, version_length; - u8 *lto_tlv, lto_length; - u8 *wks_tlv, wks_length; - u8 *miux_tlv, miux_length; + u8 *gb_cur, version, version_length; + u8 lto_length, wks_length, miux_length; + u8 *version_tlv = NULL, *lto_tlv = NULL, + *wks_tlv = NULL, *miux_tlv = NULL; __be16 wks = cpu_to_be16(local->local_wks); u8 gb_len = 0; int ret = 0; @@ -543,17 +543,33 @@ static int nfc_llcp_build_gb(struct nfc_llcp_local *local) version = LLCP_VERSION_11; version_tlv = nfc_llcp_build_tlv(LLCP_TLV_VERSION, &version, 1, &version_length); + if (!version_tlv) { + ret = -ENOMEM; + goto out; + } gb_len += version_length; lto_tlv = nfc_llcp_build_tlv(LLCP_TLV_LTO, &local->lto, 1, <o_length); + if (!lto_tlv) { + ret = -ENOMEM; + goto out; + } gb_len += lto_length; pr_debug("Local wks 0x%lx\n", local->local_wks); wks_tlv = nfc_llcp_build_tlv(LLCP_TLV_WKS, (u8 *)&wks, 2, &wks_length); + if (!wks_tlv) { + ret = -ENOMEM; + goto out; + } gb_len += wks_length; miux_tlv = nfc_llcp_build_tlv(LLCP_TLV_MIUX, (u8 *)&local->miux, 0, &miux_length); + if (!miux_tlv) { + ret = -ENOMEM; + goto out; + } gb_len += miux_length; gb_len += ARRAY_SIZE(llcp_magic); diff --git a/net/openvswitch/Kconfig b/net/openvswitch/Kconfig index 89da9512ec1e..ac1cc6e38170 100644 --- a/net/openvswitch/Kconfig +++ b/net/openvswitch/Kconfig @@ -8,8 +8,6 @@ config OPENVSWITCH depends on !NF_CONNTRACK || \ (NF_CONNTRACK && ((!NF_DEFRAG_IPV6 || NF_DEFRAG_IPV6) && \ (!NF_NAT || NF_NAT) && \ - (!NF_NAT_IPV4 || NF_NAT_IPV4) && \ - (!NF_NAT_IPV6 || NF_NAT_IPV6) && \ (!NETFILTER_CONNCOUNT || NETFILTER_CONNCOUNT))) select LIBCRC32C select MPLS diff --git a/net/openvswitch/conntrack.c b/net/openvswitch/conntrack.c index cd94f925495a..1b6896896fff 100644 --- a/net/openvswitch/conntrack.c +++ b/net/openvswitch/conntrack.c @@ -29,9 +29,7 @@ #include <net/ipv6_frag.h> #ifdef CONFIG_NF_NAT_NEEDED -#include <linux/netfilter/nf_nat.h> -#include <net/netfilter/nf_nat_core.h> -#include <net/netfilter/nf_nat_l3proto.h> +#include <net/netfilter/nf_nat.h> #endif #include "datapath.h" @@ -622,7 +620,7 @@ ovs_ct_find_existing(struct net *net, const struct nf_conntrack_zone *zone, if (natted) { struct nf_conntrack_tuple inverse; - if (!nf_ct_invert_tuplepr(&inverse, &tuple)) { + if (!nf_ct_invert_tuple(&inverse, &tuple)) { pr_debug("ovs_ct_find_existing: Inversion failed!\n"); return NULL; } @@ -745,14 +743,14 @@ static int ovs_ct_nat_execute(struct sk_buff *skb, struct nf_conn *ct, switch (ctinfo) { case IP_CT_RELATED: case IP_CT_RELATED_REPLY: - if (IS_ENABLED(CONFIG_NF_NAT_IPV4) && + if (IS_ENABLED(CONFIG_NF_NAT) && skb->protocol == htons(ETH_P_IP) && ip_hdr(skb)->protocol == IPPROTO_ICMP) { if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo, hooknum)) err = NF_DROP; goto push; - } else if (IS_ENABLED(CONFIG_NF_NAT_IPV6) && + } else if (IS_ENABLED(CONFIG_IPV6) && skb->protocol == htons(ETH_P_IPV6)) { __be16 frag_off; u8 nexthdr = ipv6_hdr(skb)->nexthdr; @@ -1673,7 +1671,7 @@ static bool ovs_ct_nat_to_attr(const struct ovs_conntrack_info *info, } if (info->range.flags & NF_NAT_RANGE_MAP_IPS) { - if (IS_ENABLED(CONFIG_NF_NAT_IPV4) && + if (IS_ENABLED(CONFIG_NF_NAT) && info->family == NFPROTO_IPV4) { if (nla_put_in_addr(skb, OVS_NAT_ATTR_IP_MIN, info->range.min_addr.ip) || @@ -1682,7 +1680,7 @@ static bool ovs_ct_nat_to_attr(const struct ovs_conntrack_info *info, (nla_put_in_addr(skb, OVS_NAT_ATTR_IP_MAX, info->range.max_addr.ip)))) return false; - } else if (IS_ENABLED(CONFIG_NF_NAT_IPV6) && + } else if (IS_ENABLED(CONFIG_IPV6) && info->family == NFPROTO_IPV6) { if (nla_put_in6_addr(skb, OVS_NAT_ATTR_IP_MIN, &info->range.min_addr.in6) || diff --git a/net/openvswitch/flow.h b/net/openvswitch/flow.h index ba01fc4270bd..5b8e5bd7457b 100644 --- a/net/openvswitch/flow.h +++ b/net/openvswitch/flow.h @@ -30,7 +30,6 @@ #include <linux/in6.h> #include <linux/jiffies.h> #include <linux/time.h> -#include <linux/flex_array.h> #include <linux/cpumask.h> #include <net/inet_ecn.h> #include <net/ip_tunnels.h> diff --git a/net/openvswitch/flow_netlink.h b/net/openvswitch/flow_netlink.h index 6657606b2b47..66f9553758a5 100644 --- a/net/openvswitch/flow_netlink.h +++ b/net/openvswitch/flow_netlink.h @@ -30,7 +30,6 @@ #include <linux/in6.h> #include <linux/jiffies.h> #include <linux/time.h> -#include <linux/flex_array.h> #include <net/inet_ecn.h> #include <net/ip_tunnels.h> diff --git a/net/openvswitch/flow_table.c b/net/openvswitch/flow_table.c index 80ea2a71852e..cfb0098c9a01 100644 --- a/net/openvswitch/flow_table.c +++ b/net/openvswitch/flow_table.c @@ -111,29 +111,6 @@ int ovs_flow_tbl_count(const struct flow_table *table) return table->count; } -static struct flex_array *alloc_buckets(unsigned int n_buckets) -{ - struct flex_array *buckets; - int i, err; - - buckets = flex_array_alloc(sizeof(struct hlist_head), - n_buckets, GFP_KERNEL); - if (!buckets) - return NULL; - - err = flex_array_prealloc(buckets, 0, n_buckets, GFP_KERNEL); - if (err) { - flex_array_free(buckets); - return NULL; - } - - for (i = 0; i < n_buckets; i++) - INIT_HLIST_HEAD((struct hlist_head *) - flex_array_get(buckets, i)); - - return buckets; -} - static void flow_free(struct sw_flow *flow) { int cpu; @@ -168,31 +145,30 @@ void ovs_flow_free(struct sw_flow *flow, bool deferred) flow_free(flow); } -static void free_buckets(struct flex_array *buckets) -{ - flex_array_free(buckets); -} - - static void __table_instance_destroy(struct table_instance *ti) { - free_buckets(ti->buckets); + kvfree(ti->buckets); kfree(ti); } static struct table_instance *table_instance_alloc(int new_size) { struct table_instance *ti = kmalloc(sizeof(*ti), GFP_KERNEL); + int i; if (!ti) return NULL; - ti->buckets = alloc_buckets(new_size); - + ti->buckets = kvmalloc_array(new_size, sizeof(struct hlist_head), + GFP_KERNEL); if (!ti->buckets) { kfree(ti); return NULL; } + + for (i = 0; i < new_size; i++) + INIT_HLIST_HEAD(&ti->buckets[i]); + ti->n_buckets = new_size; ti->node_ver = 0; ti->keep_flows = false; @@ -249,7 +225,7 @@ static void table_instance_destroy(struct table_instance *ti, for (i = 0; i < ti->n_buckets; i++) { struct sw_flow *flow; - struct hlist_head *head = flex_array_get(ti->buckets, i); + struct hlist_head *head = &ti->buckets[i]; struct hlist_node *n; int ver = ti->node_ver; int ufid_ver = ufid_ti->node_ver; @@ -294,7 +270,7 @@ struct sw_flow *ovs_flow_tbl_dump_next(struct table_instance *ti, ver = ti->node_ver; while (*bucket < ti->n_buckets) { i = 0; - head = flex_array_get(ti->buckets, *bucket); + head = &ti->buckets[*bucket]; hlist_for_each_entry_rcu(flow, head, flow_table.node[ver]) { if (i < *last) { i++; @@ -313,8 +289,7 @@ struct sw_flow *ovs_flow_tbl_dump_next(struct table_instance *ti, static struct hlist_head *find_bucket(struct table_instance *ti, u32 hash) { hash = jhash_1word(hash, ti->hash_seed); - return flex_array_get(ti->buckets, - (hash & (ti->n_buckets - 1))); + return &ti->buckets[hash & (ti->n_buckets - 1)]; } static void table_instance_insert(struct table_instance *ti, @@ -347,9 +322,7 @@ static void flow_table_copy_flows(struct table_instance *old, /* Insert in new table. */ for (i = 0; i < old->n_buckets; i++) { struct sw_flow *flow; - struct hlist_head *head; - - head = flex_array_get(old->buckets, i); + struct hlist_head *head = &old->buckets[i]; if (ufid) hlist_for_each_entry(flow, head, diff --git a/net/openvswitch/flow_table.h b/net/openvswitch/flow_table.h index 2dd9900f533d..de5ec6cf5174 100644 --- a/net/openvswitch/flow_table.h +++ b/net/openvswitch/flow_table.h @@ -29,7 +29,6 @@ #include <linux/in6.h> #include <linux/jiffies.h> #include <linux/time.h> -#include <linux/flex_array.h> #include <net/inet_ecn.h> #include <net/ip_tunnels.h> @@ -37,7 +36,7 @@ #include "flow.h" struct table_instance { - struct flex_array *buckets; + struct hlist_head *buckets; unsigned int n_buckets; struct rcu_head rcu; int node_ver; diff --git a/net/openvswitch/meter.c b/net/openvswitch/meter.c index c038e021a591..43849d752a1e 100644 --- a/net/openvswitch/meter.c +++ b/net/openvswitch/meter.c @@ -206,8 +206,7 @@ static struct dp_meter *dp_meter_create(struct nlattr **a) return ERR_PTR(-EINVAL); /* Allocate and set up the meter before locking anything. */ - meter = kzalloc(n_bands * sizeof(struct dp_meter_band) + - sizeof(*meter), GFP_KERNEL); + meter = kzalloc(struct_size(meter, bands, n_bands), GFP_KERNEL); if (!meter) return ERR_PTR(-ENOMEM); diff --git a/net/packet/af_packet.c b/net/packet/af_packet.c index 3b1a78906bc0..8376bc1c1508 100644 --- a/net/packet/af_packet.c +++ b/net/packet/af_packet.c @@ -1850,6 +1850,15 @@ oom: return 0; } +static void packet_parse_headers(struct sk_buff *skb, struct socket *sock) +{ + if (!skb->protocol && sock->type == SOCK_RAW) { + skb_reset_mac_header(skb); + skb->protocol = dev_parse_header_protocol(skb); + } + + skb_probe_transport_header(skb); +} /* * Output a raw packet to a device layer. This bypasses all the other @@ -1970,7 +1979,7 @@ retry: if (unlikely(extra_len == 4)) skb->no_fcs = 1; - skb_probe_transport_header(skb, 0); + packet_parse_headers(skb, sock); dev_queue_xmit(skb); rcu_read_unlock(); @@ -2404,15 +2413,6 @@ static void tpacket_destruct_skb(struct sk_buff *skb) sock_wfree(skb); } -static void tpacket_set_protocol(const struct net_device *dev, - struct sk_buff *skb) -{ - if (dev->type == ARPHRD_ETHER) { - skb_reset_mac_header(skb); - skb->protocol = eth_hdr(skb)->h_proto; - } -} - static int __packet_snd_vnet_parse(struct virtio_net_hdr *vnet_hdr, size_t len) { if ((vnet_hdr->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) && @@ -2483,8 +2483,6 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb, return err; if (!dev_validate_header(dev, skb->data, hdrlen)) return -EINVAL; - if (!skb->protocol) - tpacket_set_protocol(dev, skb); data += hdrlen; to_write -= hdrlen; @@ -2519,7 +2517,7 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb, len = ((to_write > len_max) ? len_max : to_write); } - skb_probe_transport_header(skb, 0); + packet_parse_headers(skb, sock); return tp_len; } @@ -2925,7 +2923,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len) virtio_net_hdr_set_proto(skb, &vnet_hdr); } - skb_probe_transport_header(skb, reserve); + packet_parse_headers(skb, sock); if (unlikely(extra_len == 4)) skb->no_fcs = 1; @@ -4292,7 +4290,7 @@ static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u, rb->frames_per_block = req->tp_block_size / req->tp_frame_size; if (unlikely(rb->frames_per_block == 0)) goto out; - if (unlikely(req->tp_block_size > UINT_MAX / req->tp_block_nr)) + if (unlikely(rb->frames_per_block > UINT_MAX / req->tp_block_nr)) goto out; if (unlikely((rb->frames_per_block * req->tp_block_nr) != req->tp_frame_nr)) diff --git a/net/phonet/pep.c b/net/phonet/pep.c index 9fc76b19cd3c..db3473540303 100644 --- a/net/phonet/pep.c +++ b/net/phonet/pep.c @@ -132,7 +132,7 @@ static int pep_indicate(struct sock *sk, u8 id, u8 code, ph->utid = 0; ph->message_id = id; ph->pipe_handle = pn->pipe_handle; - ph->data[0] = code; + ph->error_code = code; return pn_skb_send(sk, skb, NULL); } @@ -153,7 +153,7 @@ static int pipe_handler_request(struct sock *sk, u8 id, u8 code, ph->utid = id; /* whatever */ ph->message_id = id; ph->pipe_handle = pn->pipe_handle; - ph->data[0] = code; + ph->error_code = code; return pn_skb_send(sk, skb, NULL); } @@ -208,7 +208,7 @@ static int pep_ctrlreq_error(struct sock *sk, struct sk_buff *oskb, u8 code, struct pnpipehdr *ph; struct sockaddr_pn dst; u8 data[4] = { - oph->data[0], /* PEP type */ + oph->pep_type, /* PEP type */ code, /* error code, at an unusual offset */ PAD, PAD, }; @@ -221,7 +221,7 @@ static int pep_ctrlreq_error(struct sock *sk, struct sk_buff *oskb, u8 code, ph->utid = oph->utid; ph->message_id = PNS_PEP_CTRL_RESP; ph->pipe_handle = oph->pipe_handle; - ph->data[0] = oph->data[1]; /* CTRL id */ + ph->data0 = oph->data[0]; /* CTRL id */ pn_skb_get_src_sockaddr(oskb, &dst); return pn_skb_send(sk, skb, &dst); @@ -272,17 +272,17 @@ static int pipe_rcv_status(struct sock *sk, struct sk_buff *skb) return -EINVAL; hdr = pnp_hdr(skb); - if (hdr->data[0] != PN_PEP_TYPE_COMMON) { + if (hdr->pep_type != PN_PEP_TYPE_COMMON) { net_dbg_ratelimited("Phonet unknown PEP type: %u\n", - (unsigned int)hdr->data[0]); + (unsigned int)hdr->pep_type); return -EOPNOTSUPP; } - switch (hdr->data[1]) { + switch (hdr->data[0]) { case PN_PEP_IND_FLOW_CONTROL: switch (pn->tx_fc) { case PN_LEGACY_FLOW_CONTROL: - switch (hdr->data[4]) { + switch (hdr->data[3]) { case PEP_IND_BUSY: atomic_set(&pn->tx_credits, 0); break; @@ -292,7 +292,7 @@ static int pipe_rcv_status(struct sock *sk, struct sk_buff *skb) } break; case PN_ONE_CREDIT_FLOW_CONTROL: - if (hdr->data[4] == PEP_IND_READY) + if (hdr->data[3] == PEP_IND_READY) atomic_set(&pn->tx_credits, wake = 1); break; } @@ -301,12 +301,12 @@ static int pipe_rcv_status(struct sock *sk, struct sk_buff *skb) case PN_PEP_IND_ID_MCFC_GRANT_CREDITS: if (pn->tx_fc != PN_MULTI_CREDIT_FLOW_CONTROL) break; - atomic_add(wake = hdr->data[4], &pn->tx_credits); + atomic_add(wake = hdr->data[3], &pn->tx_credits); break; default: net_dbg_ratelimited("Phonet unknown PEP indication: %u\n", - (unsigned int)hdr->data[1]); + (unsigned int)hdr->data[0]); return -EOPNOTSUPP; } if (wake) @@ -318,7 +318,7 @@ static int pipe_rcv_created(struct sock *sk, struct sk_buff *skb) { struct pep_sock *pn = pep_sk(sk); struct pnpipehdr *hdr = pnp_hdr(skb); - u8 n_sb = hdr->data[0]; + u8 n_sb = hdr->data0; pn->rx_fc = pn->tx_fc = PN_LEGACY_FLOW_CONTROL; __skb_pull(skb, sizeof(*hdr)); @@ -506,7 +506,7 @@ static int pep_connresp_rcv(struct sock *sk, struct sk_buff *skb) return -ECONNREFUSED; /* Parse sub-blocks */ - n_sb = hdr->data[4]; + n_sb = hdr->data[3]; while (n_sb > 0) { u8 type, buf[6], len = sizeof(buf); const u8 *data = pep_get_sb(skb, &type, &len, buf); @@ -739,7 +739,7 @@ static int pipe_do_remove(struct sock *sk) ph->utid = 0; ph->message_id = PNS_PIPE_REMOVE_REQ; ph->pipe_handle = pn->pipe_handle; - ph->data[0] = PAD; + ph->data0 = PAD; return pn_skb_send(sk, skb, NULL); } @@ -817,7 +817,7 @@ static struct sock *pep_sock_accept(struct sock *sk, int flags, int *errp, peer_type = hdr->other_pep_type << 8; /* Parse sub-blocks (options) */ - n_sb = hdr->data[4]; + n_sb = hdr->data[3]; while (n_sb > 0) { u8 type, buf[1], len = sizeof(buf); const u8 *data = pep_get_sb(skb, &type, &len, buf); @@ -1109,7 +1109,7 @@ static int pipe_skb_send(struct sock *sk, struct sk_buff *skb) ph->utid = 0; if (pn->aligned) { ph->message_id = PNS_PIPE_ALIGNED_DATA; - ph->data[0] = 0; /* padding */ + ph->data0 = 0; /* padding */ } else ph->message_id = PNS_PIPE_DATA; ph->pipe_handle = pn->pipe_handle; diff --git a/net/qrtr/qrtr.c b/net/qrtr/qrtr.c index 86e1e37eb4e8..b37e6e0a1026 100644 --- a/net/qrtr/qrtr.c +++ b/net/qrtr/qrtr.c @@ -15,6 +15,7 @@ #include <linux/netlink.h> #include <linux/qrtr.h> #include <linux/termios.h> /* For TIOCINQ/OUTQ */ +#include <linux/numa.h> #include <net/sock.h> @@ -101,7 +102,7 @@ static inline struct qrtr_sock *qrtr_sk(struct sock *sk) return container_of(sk, struct qrtr_sock, sk); } -static unsigned int qrtr_local_nid = -1; +static unsigned int qrtr_local_nid = NUMA_NO_NODE; /* for node ids */ static RADIX_TREE(qrtr_nodes, GFP_KERNEL); diff --git a/net/rds/af_rds.c b/net/rds/af_rds.c index 65387e1e6964..d6cc97fbbbb0 100644 --- a/net/rds/af_rds.c +++ b/net/rds/af_rds.c @@ -254,7 +254,40 @@ static __poll_t rds_poll(struct file *file, struct socket *sock, static int rds_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) { - return -ENOIOCTLCMD; + struct rds_sock *rs = rds_sk_to_rs(sock->sk); + rds_tos_t utos, tos = 0; + + switch (cmd) { + case SIOCRDSSETTOS: + if (get_user(utos, (rds_tos_t __user *)arg)) + return -EFAULT; + + if (rs->rs_transport && + rs->rs_transport->get_tos_map) + tos = rs->rs_transport->get_tos_map(utos); + else + return -ENOIOCTLCMD; + + spin_lock_bh(&rds_sock_lock); + if (rs->rs_tos || rs->rs_conn) { + spin_unlock_bh(&rds_sock_lock); + return -EINVAL; + } + rs->rs_tos = tos; + spin_unlock_bh(&rds_sock_lock); + break; + case SIOCRDSGETTOS: + spin_lock_bh(&rds_sock_lock); + tos = rs->rs_tos; + spin_unlock_bh(&rds_sock_lock); + if (put_user(tos, (rds_tos_t __user *)arg)) + return -EFAULT; + break; + default: + return -ENOIOCTLCMD; + } + + return 0; } static int rds_cancel_sent_to(struct rds_sock *rs, char __user *optval, @@ -348,7 +381,7 @@ static int rds_set_transport(struct rds_sock *rs, char __user *optval, } static int rds_enable_recvtstamp(struct sock *sk, char __user *optval, - int optlen) + int optlen, int optname) { int val, valbool; @@ -360,6 +393,9 @@ static int rds_enable_recvtstamp(struct sock *sk, char __user *optval, valbool = val ? 1 : 0; + if (optname == SO_TIMESTAMP_NEW) + sock_set_flag(sk, SOCK_TSTAMP_NEW); + if (valbool) sock_set_flag(sk, SOCK_RCVTSTAMP); else @@ -430,9 +466,10 @@ static int rds_setsockopt(struct socket *sock, int level, int optname, ret = rds_set_transport(rs, optval, optlen); release_sock(sock->sk); break; - case SO_TIMESTAMP: + case SO_TIMESTAMP_OLD: + case SO_TIMESTAMP_NEW: lock_sock(sock->sk); - ret = rds_enable_recvtstamp(sock->sk, optval, optlen); + ret = rds_enable_recvtstamp(sock->sk, optval, optlen, optname); release_sock(sock->sk); break; case SO_RDS_MSG_RXPATH_LATENCY: @@ -646,6 +683,8 @@ static int __rds_create(struct socket *sock, struct sock *sk, int protocol) spin_lock_init(&rs->rs_rdma_lock); rs->rs_rdma_keys = RB_ROOT; rs->rs_rx_traces = 0; + rs->rs_tos = 0; + rs->rs_conn = NULL; spin_lock_bh(&rds_sock_lock); list_add_tail(&rs->rs_item, &rds_sock_list); diff --git a/net/rds/bind.c b/net/rds/bind.c index 762d2c6788a3..17c9d9f0c848 100644 --- a/net/rds/bind.c +++ b/net/rds/bind.c @@ -78,10 +78,10 @@ struct rds_sock *rds_find_bound(const struct in6_addr *addr, __be16 port, __rds_create_bind_key(key, addr, port, scope_id); rcu_read_lock(); rs = rhashtable_lookup(&bind_hash_table, key, ht_parms); - if (rs && !sock_flag(rds_rs_to_sk(rs), SOCK_DEAD)) - rds_sock_addref(rs); - else + if (rs && (sock_flag(rds_rs_to_sk(rs), SOCK_DEAD) || + !refcount_inc_not_zero(&rds_rs_to_sk(rs)->sk_refcnt))) rs = NULL; + rcu_read_unlock(); rdsdebug("returning rs %p for %pI6c:%u\n", rs, addr, diff --git a/net/rds/connection.c b/net/rds/connection.c index 3bd2f4a5a30d..7ea134f9a825 100644 --- a/net/rds/connection.c +++ b/net/rds/connection.c @@ -84,7 +84,7 @@ static struct rds_connection *rds_conn_lookup(struct net *net, const struct in6_addr *laddr, const struct in6_addr *faddr, struct rds_transport *trans, - int dev_if) + u8 tos, int dev_if) { struct rds_connection *conn, *ret = NULL; @@ -92,6 +92,7 @@ static struct rds_connection *rds_conn_lookup(struct net *net, if (ipv6_addr_equal(&conn->c_faddr, faddr) && ipv6_addr_equal(&conn->c_laddr, laddr) && conn->c_trans == trans && + conn->c_tos == tos && net == rds_conn_net(conn) && conn->c_dev_if == dev_if) { ret = conn; @@ -139,6 +140,7 @@ static void __rds_conn_path_init(struct rds_connection *conn, atomic_set(&cp->cp_state, RDS_CONN_DOWN); cp->cp_send_gen = 0; cp->cp_reconnect_jiffies = 0; + cp->cp_conn->c_proposed_version = RDS_PROTOCOL_VERSION; INIT_DELAYED_WORK(&cp->cp_send_w, rds_send_worker); INIT_DELAYED_WORK(&cp->cp_recv_w, rds_recv_worker); INIT_DELAYED_WORK(&cp->cp_conn_w, rds_connect_worker); @@ -159,7 +161,7 @@ static struct rds_connection *__rds_conn_create(struct net *net, const struct in6_addr *laddr, const struct in6_addr *faddr, struct rds_transport *trans, - gfp_t gfp, + gfp_t gfp, u8 tos, int is_outgoing, int dev_if) { @@ -171,7 +173,7 @@ static struct rds_connection *__rds_conn_create(struct net *net, int npaths = (trans->t_mp_capable ? RDS_MPATH_WORKERS : 1); rcu_read_lock(); - conn = rds_conn_lookup(net, head, laddr, faddr, trans, dev_if); + conn = rds_conn_lookup(net, head, laddr, faddr, trans, tos, dev_if); if (conn && conn->c_loopback && conn->c_trans != &rds_loop_transport && @@ -205,6 +207,7 @@ static struct rds_connection *__rds_conn_create(struct net *net, conn->c_isv6 = !ipv6_addr_v4mapped(laddr); conn->c_faddr = *faddr; conn->c_dev_if = dev_if; + conn->c_tos = tos; #if IS_ENABLED(CONFIG_IPV6) /* If the local address is link local, set c_bound_if to be the @@ -297,7 +300,7 @@ static struct rds_connection *__rds_conn_create(struct net *net, struct rds_connection *found; found = rds_conn_lookup(net, head, laddr, faddr, trans, - dev_if); + tos, dev_if); if (found) { struct rds_conn_path *cp; int i; @@ -332,10 +335,10 @@ out: struct rds_connection *rds_conn_create(struct net *net, const struct in6_addr *laddr, const struct in6_addr *faddr, - struct rds_transport *trans, gfp_t gfp, - int dev_if) + struct rds_transport *trans, u8 tos, + gfp_t gfp, int dev_if) { - return __rds_conn_create(net, laddr, faddr, trans, gfp, 0, dev_if); + return __rds_conn_create(net, laddr, faddr, trans, gfp, tos, 0, dev_if); } EXPORT_SYMBOL_GPL(rds_conn_create); @@ -343,9 +346,9 @@ struct rds_connection *rds_conn_create_outgoing(struct net *net, const struct in6_addr *laddr, const struct in6_addr *faddr, struct rds_transport *trans, - gfp_t gfp, int dev_if) + u8 tos, gfp_t gfp, int dev_if) { - return __rds_conn_create(net, laddr, faddr, trans, gfp, 1, dev_if); + return __rds_conn_create(net, laddr, faddr, trans, gfp, tos, 1, dev_if); } EXPORT_SYMBOL_GPL(rds_conn_create_outgoing); diff --git a/net/rds/ib.c b/net/rds/ib.c index 9d7b7586f240..2da9b75bad16 100644 --- a/net/rds/ib.c +++ b/net/rds/ib.c @@ -301,6 +301,7 @@ static int rds_ib_conn_info_visitor(struct rds_connection *conn, iinfo->src_addr = conn->c_laddr.s6_addr32[3]; iinfo->dst_addr = conn->c_faddr.s6_addr32[3]; + iinfo->tos = conn->c_tos; memset(&iinfo->src_gid, 0, sizeof(iinfo->src_gid)); memset(&iinfo->dst_gid, 0, sizeof(iinfo->dst_gid)); @@ -514,6 +515,15 @@ void rds_ib_exit(void) rds_ib_mr_exit(); } +static u8 rds_ib_get_tos_map(u8 tos) +{ + /* 1:1 user to transport map for RDMA transport. + * In future, if custom map is desired, hook can export + * user configurable map. + */ + return tos; +} + struct rds_transport rds_ib_transport = { .laddr_check = rds_ib_laddr_check, .xmit_path_complete = rds_ib_xmit_path_complete, @@ -536,6 +546,7 @@ struct rds_transport rds_ib_transport = { .sync_mr = rds_ib_sync_mr, .free_mr = rds_ib_free_mr, .flush_mrs = rds_ib_flush_mrs, + .get_tos_map = rds_ib_get_tos_map, .t_owner = THIS_MODULE, .t_name = "infiniband", .t_unloading = rds_ib_is_unloading, diff --git a/net/rds/ib.h b/net/rds/ib.h index 71ff356ee702..67a715b076ca 100644 --- a/net/rds/ib.h +++ b/net/rds/ib.h @@ -67,7 +67,9 @@ struct rds_ib_conn_priv_cmn { u8 ricpc_protocol_major; u8 ricpc_protocol_minor; __be16 ricpc_protocol_minor_mask; /* bitmask */ - __be32 ricpc_reserved1; + u8 ricpc_dp_toss; + u8 ripc_reserved1; + __be16 ripc_reserved2; __be64 ricpc_ack_seq; __be32 ricpc_credit; /* non-zero enables flow ctl */ }; @@ -331,10 +333,8 @@ static inline void rds_ib_dma_sync_sg_for_cpu(struct ib_device *dev, unsigned int i; for_each_sg(sglist, sg, sg_dma_len, i) { - ib_dma_sync_single_for_cpu(dev, - ib_sg_dma_address(dev, sg), - ib_sg_dma_len(dev, sg), - direction); + ib_dma_sync_single_for_cpu(dev, sg_dma_address(sg), + sg_dma_len(sg), direction); } } #define ib_dma_sync_sg_for_cpu rds_ib_dma_sync_sg_for_cpu @@ -348,10 +348,8 @@ static inline void rds_ib_dma_sync_sg_for_device(struct ib_device *dev, unsigned int i; for_each_sg(sglist, sg, sg_dma_len, i) { - ib_dma_sync_single_for_device(dev, - ib_sg_dma_address(dev, sg), - ib_sg_dma_len(dev, sg), - direction); + ib_dma_sync_single_for_device(dev, sg_dma_address(sg), + sg_dma_len(sg), direction); } } #define ib_dma_sync_sg_for_device rds_ib_dma_sync_sg_for_device diff --git a/net/rds/ib_cm.c b/net/rds/ib_cm.c index bfbb31f0c7fd..66c6eb56072b 100644 --- a/net/rds/ib_cm.c +++ b/net/rds/ib_cm.c @@ -133,23 +133,24 @@ void rds_ib_cm_connect_complete(struct rds_connection *conn, struct rdma_cm_even rds_ib_set_flow_control(conn, be32_to_cpu(credit)); } - if (conn->c_version < RDS_PROTOCOL(3, 1)) { - pr_notice("RDS/IB: Connection <%pI6c,%pI6c> version %u.%u no longer supported\n", - &conn->c_laddr, &conn->c_faddr, - RDS_PROTOCOL_MAJOR(conn->c_version), - RDS_PROTOCOL_MINOR(conn->c_version)); - set_bit(RDS_DESTROY_PENDING, &conn->c_path[0].cp_flags); - rds_conn_destroy(conn); - return; - } else { - pr_notice("RDS/IB: %s conn connected <%pI6c,%pI6c> version %u.%u%s\n", - ic->i_active_side ? "Active" : "Passive", - &conn->c_laddr, &conn->c_faddr, - RDS_PROTOCOL_MAJOR(conn->c_version), - RDS_PROTOCOL_MINOR(conn->c_version), - ic->i_flowctl ? ", flow control" : ""); + if (conn->c_version < RDS_PROTOCOL_VERSION) { + if (conn->c_version != RDS_PROTOCOL_COMPAT_VERSION) { + pr_notice("RDS/IB: Connection <%pI6c,%pI6c> version %u.%u no longer supported\n", + &conn->c_laddr, &conn->c_faddr, + RDS_PROTOCOL_MAJOR(conn->c_version), + RDS_PROTOCOL_MINOR(conn->c_version)); + rds_conn_destroy(conn); + return; + } } + pr_notice("RDS/IB: %s conn connected <%pI6c,%pI6c,%d> version %u.%u%s\n", + ic->i_active_side ? "Active" : "Passive", + &conn->c_laddr, &conn->c_faddr, conn->c_tos, + RDS_PROTOCOL_MAJOR(conn->c_version), + RDS_PROTOCOL_MINOR(conn->c_version), + ic->i_flowctl ? ", flow control" : ""); + atomic_set(&ic->i_cq_quiesce, 0); /* Init rings and fill recv. this needs to wait until protocol @@ -184,6 +185,7 @@ void rds_ib_cm_connect_complete(struct rds_connection *conn, struct rdma_cm_even NULL); } + conn->c_proposed_version = conn->c_version; rds_connect_complete(conn); } @@ -220,6 +222,7 @@ static void rds_ib_cm_fill_conn_param(struct rds_connection *conn, cpu_to_be16(RDS_IB_SUPPORTED_PROTOCOLS); dp->ricp_v6.dp_ack_seq = cpu_to_be64(rds_ib_piggyb_ack(ic)); + dp->ricp_v6.dp_cmn.ricpc_dp_toss = conn->c_tos; conn_param->private_data = &dp->ricp_v6; conn_param->private_data_len = sizeof(dp->ricp_v6); @@ -234,6 +237,7 @@ static void rds_ib_cm_fill_conn_param(struct rds_connection *conn, cpu_to_be16(RDS_IB_SUPPORTED_PROTOCOLS); dp->ricp_v4.dp_ack_seq = cpu_to_be64(rds_ib_piggyb_ack(ic)); + dp->ricp_v4.dp_cmn.ricpc_dp_toss = conn->c_tos; conn_param->private_data = &dp->ricp_v4; conn_param->private_data_len = sizeof(dp->ricp_v4); @@ -389,10 +393,9 @@ static void rds_ib_qp_event_handler(struct ib_event *event, void *data) rdma_notify(ic->i_cm_id, IB_EVENT_COMM_EST); break; default: - rdsdebug("Fatal QP Event %u (%s) " - "- connection %pI6c->%pI6c, reconnecting\n", - event->event, ib_event_msg(event->event), - &conn->c_laddr, &conn->c_faddr); + rdsdebug("Fatal QP Event %u (%s) - connection %pI6c->%pI6c, reconnecting\n", + event->event, ib_event_msg(event->event), + &conn->c_laddr, &conn->c_faddr); rds_conn_drop(conn); break; } @@ -660,13 +663,16 @@ static u32 rds_ib_protocol_compatible(struct rdma_cm_event *event, bool isv6) /* Even if len is crap *now* I still want to check it. -ASG */ if (event->param.conn.private_data_len < data_len || major == 0) - return RDS_PROTOCOL_3_0; + return RDS_PROTOCOL_4_0; common = be16_to_cpu(mask) & RDS_IB_SUPPORTED_PROTOCOLS; - if (major == 3 && common) { - version = RDS_PROTOCOL_3_0; + if (major == 4 && common) { + version = RDS_PROTOCOL_4_0; while ((common >>= 1) != 0) version++; + } else if (RDS_PROTOCOL_COMPAT_VERSION == + RDS_PROTOCOL(major, minor)) { + version = RDS_PROTOCOL_COMPAT_VERSION; } else { if (isv6) printk_ratelimited(KERN_NOTICE "RDS: Connection from %pI6c using incompatible protocol version %u.%u\n", @@ -729,8 +735,10 @@ int rds_ib_cm_handle_connect(struct rdma_cm_id *cm_id, /* Check whether the remote protocol version matches ours. */ version = rds_ib_protocol_compatible(event, isv6); - if (!version) + if (!version) { + err = RDS_RDMA_REJ_INCOMPAT; goto out; + } dp = event->param.conn.private_data; if (isv6) { @@ -771,15 +779,16 @@ int rds_ib_cm_handle_connect(struct rdma_cm_id *cm_id, daddr6 = &d_mapped_addr; } - rdsdebug("saddr %pI6c daddr %pI6c RDSv%u.%u lguid 0x%llx fguid " - "0x%llx\n", saddr6, daddr6, - RDS_PROTOCOL_MAJOR(version), RDS_PROTOCOL_MINOR(version), + rdsdebug("saddr %pI6c daddr %pI6c RDSv%u.%u lguid 0x%llx fguid 0x%llx, tos:%d\n", + saddr6, daddr6, RDS_PROTOCOL_MAJOR(version), + RDS_PROTOCOL_MINOR(version), (unsigned long long)be64_to_cpu(lguid), - (unsigned long long)be64_to_cpu(fguid)); + (unsigned long long)be64_to_cpu(fguid), dp_cmn->ricpc_dp_toss); /* RDS/IB is not currently netns aware, thus init_net */ conn = rds_conn_create(&init_net, daddr6, saddr6, - &rds_ib_transport, GFP_KERNEL, ifindex); + &rds_ib_transport, dp_cmn->ricpc_dp_toss, + GFP_KERNEL, ifindex); if (IS_ERR(conn)) { rdsdebug("rds_conn_create failed (%ld)\n", PTR_ERR(conn)); conn = NULL; @@ -846,7 +855,7 @@ out: if (conn) mutex_unlock(&conn->c_cm_lock); if (err) - rdma_reject(cm_id, NULL, 0); + rdma_reject(cm_id, &err, sizeof(int)); return destroy; } @@ -861,7 +870,7 @@ int rds_ib_cm_initiate_connect(struct rdma_cm_id *cm_id, bool isv6) /* If the peer doesn't do protocol negotiation, we must * default to RDSv3.0 */ - rds_ib_set_protocol(conn, RDS_PROTOCOL_3_0); + rds_ib_set_protocol(conn, RDS_PROTOCOL_4_1); ic->i_flowctl = rds_ib_sysctl_flow_control; /* advertise flow control */ ret = rds_ib_setup_qp(conn); @@ -870,7 +879,8 @@ int rds_ib_cm_initiate_connect(struct rdma_cm_id *cm_id, bool isv6) goto out; } - rds_ib_cm_fill_conn_param(conn, &conn_param, &dp, RDS_PROTOCOL_VERSION, + rds_ib_cm_fill_conn_param(conn, &conn_param, &dp, + conn->c_proposed_version, UINT_MAX, UINT_MAX, isv6); ret = rdma_connect(cm_id, &conn_param); if (ret) diff --git a/net/rds/ib_fmr.c b/net/rds/ib_fmr.c index e0f70c4051b6..31cf37da4510 100644 --- a/net/rds/ib_fmr.c +++ b/net/rds/ib_fmr.c @@ -108,8 +108,8 @@ static int rds_ib_map_fmr(struct rds_ib_device *rds_ibdev, page_cnt = 0; for (i = 0; i < sg_dma_len; ++i) { - unsigned int dma_len = ib_sg_dma_len(dev, &scat[i]); - u64 dma_addr = ib_sg_dma_address(dev, &scat[i]); + unsigned int dma_len = sg_dma_len(&scat[i]); + u64 dma_addr = sg_dma_address(&scat[i]); if (dma_addr & ~PAGE_MASK) { if (i > 0) { @@ -148,8 +148,8 @@ static int rds_ib_map_fmr(struct rds_ib_device *rds_ibdev, page_cnt = 0; for (i = 0; i < sg_dma_len; ++i) { - unsigned int dma_len = ib_sg_dma_len(dev, &scat[i]); - u64 dma_addr = ib_sg_dma_address(dev, &scat[i]); + unsigned int dma_len = sg_dma_len(&scat[i]); + u64 dma_addr = sg_dma_address(&scat[i]); for (j = 0; j < dma_len; j += PAGE_SIZE) dma_pages[page_cnt++] = diff --git a/net/rds/ib_frmr.c b/net/rds/ib_frmr.c index 6431a023ac89..688dcd68d4ea 100644 --- a/net/rds/ib_frmr.c +++ b/net/rds/ib_frmr.c @@ -181,8 +181,8 @@ static int rds_ib_map_frmr(struct rds_ib_device *rds_ibdev, ret = -EINVAL; for (i = 0; i < ibmr->sg_dma_len; ++i) { - unsigned int dma_len = ib_sg_dma_len(dev, &ibmr->sg[i]); - u64 dma_addr = ib_sg_dma_address(dev, &ibmr->sg[i]); + unsigned int dma_len = sg_dma_len(&ibmr->sg[i]); + u64 dma_addr = sg_dma_address(&ibmr->sg[i]); frmr->sg_byte_len += dma_len; if (dma_addr & ~PAGE_MASK) { diff --git a/net/rds/ib_recv.c b/net/rds/ib_recv.c index 2f16146e4ec9..70559854837e 100644 --- a/net/rds/ib_recv.c +++ b/net/rds/ib_recv.c @@ -346,8 +346,8 @@ static int rds_ib_recv_refill_one(struct rds_connection *conn, sge->length = sizeof(struct rds_header); sge = &recv->r_sge[1]; - sge->addr = ib_sg_dma_address(ic->i_cm_id->device, &recv->r_frag->f_sg); - sge->length = ib_sg_dma_len(ic->i_cm_id->device, &recv->r_frag->f_sg); + sge->addr = sg_dma_address(&recv->r_frag->f_sg); + sge->length = sg_dma_len(&recv->r_frag->f_sg); ret = 0; out: @@ -409,9 +409,7 @@ void rds_ib_recv_refill(struct rds_connection *conn, int prefill, gfp_t gfp) rdsdebug("recv %p ibinc %p page %p addr %lu\n", recv, recv->r_ibinc, sg_page(&recv->r_frag->f_sg), - (long) ib_sg_dma_address( - ic->i_cm_id->device, - &recv->r_frag->f_sg)); + (long)sg_dma_address(&recv->r_frag->f_sg)); /* XXX when can this fail? */ ret = ib_post_recv(ic->i_cm_id->qp, &recv->r_wr, NULL); @@ -986,9 +984,9 @@ void rds_ib_recv_cqe_handler(struct rds_ib_connection *ic, } else { /* We expect errors as the qp is drained during shutdown */ if (rds_conn_up(conn) || rds_conn_connecting(conn)) - rds_ib_conn_error(conn, "recv completion on <%pI6c,%pI6c> had status %u (%s), disconnecting and reconnecting\n", + rds_ib_conn_error(conn, "recv completion on <%pI6c,%pI6c, %d> had status %u (%s), disconnecting and reconnecting\n", &conn->c_laddr, &conn->c_faddr, - wc->status, + conn->c_tos, wc->status, ib_wc_status_msg(wc->status)); } diff --git a/net/rds/ib_send.c b/net/rds/ib_send.c index 4e0c36acf866..18f2341202f8 100644 --- a/net/rds/ib_send.c +++ b/net/rds/ib_send.c @@ -305,8 +305,9 @@ void rds_ib_send_cqe_handler(struct rds_ib_connection *ic, struct ib_wc *wc) /* We expect errors as the qp is drained during shutdown */ if (wc->status != IB_WC_SUCCESS && rds_conn_up(conn)) { - rds_ib_conn_error(conn, "send completion on <%pI6c,%pI6c> had status %u (%s), disconnecting and reconnecting\n", - &conn->c_laddr, &conn->c_faddr, wc->status, + rds_ib_conn_error(conn, "send completion on <%pI6c,%pI6c,%d> had status %u (%s), disconnecting and reconnecting\n", + &conn->c_laddr, &conn->c_faddr, + conn->c_tos, wc->status, ib_wc_status_msg(wc->status)); } } @@ -645,16 +646,16 @@ int rds_ib_xmit(struct rds_connection *conn, struct rds_message *rm, if (i < work_alloc && scat != &rm->data.op_sg[rm->data.op_count]) { len = min(RDS_FRAG_SIZE, - ib_sg_dma_len(dev, scat) - rm->data.op_dmaoff); + sg_dma_len(scat) - rm->data.op_dmaoff); send->s_wr.num_sge = 2; - send->s_sge[1].addr = ib_sg_dma_address(dev, scat); + send->s_sge[1].addr = sg_dma_address(scat); send->s_sge[1].addr += rm->data.op_dmaoff; send->s_sge[1].length = len; bytes_sent += len; rm->data.op_dmaoff += len; - if (rm->data.op_dmaoff == ib_sg_dma_len(dev, scat)) { + if (rm->data.op_dmaoff == sg_dma_len(scat)) { scat++; rm->data.op_dmasg++; rm->data.op_dmaoff = 0; @@ -808,8 +809,8 @@ int rds_ib_xmit_atomic(struct rds_connection *conn, struct rm_atomic_op *op) } /* Convert our struct scatterlist to struct ib_sge */ - send->s_sge[0].addr = ib_sg_dma_address(ic->i_cm_id->device, op->op_sg); - send->s_sge[0].length = ib_sg_dma_len(ic->i_cm_id->device, op->op_sg); + send->s_sge[0].addr = sg_dma_address(op->op_sg); + send->s_sge[0].length = sg_dma_len(op->op_sg); send->s_sge[0].lkey = ic->i_pd->local_dma_lkey; rdsdebug("rva %Lx rpa %Lx len %u\n", op->op_remote_addr, @@ -921,9 +922,8 @@ int rds_ib_xmit_rdma(struct rds_connection *conn, struct rm_rdma_op *op) for (j = 0; j < send->s_rdma_wr.wr.num_sge && scat != &op->op_sg[op->op_count]; j++) { - len = ib_sg_dma_len(ic->i_cm_id->device, scat); - send->s_sge[j].addr = - ib_sg_dma_address(ic->i_cm_id->device, scat); + len = sg_dma_len(scat); + send->s_sge[j].addr = sg_dma_address(scat); send->s_sge[j].length = len; send->s_sge[j].lkey = ic->i_pd->local_dma_lkey; diff --git a/net/rds/rdma_transport.c b/net/rds/rdma_transport.c index 6b0f57c83a2a..46bce8389066 100644 --- a/net/rds/rdma_transport.c +++ b/net/rds/rdma_transport.c @@ -51,6 +51,8 @@ static int rds_rdma_cm_event_handler_cmn(struct rdma_cm_id *cm_id, struct rds_connection *conn = cm_id->context; struct rds_transport *trans; int ret = 0; + int *err; + u8 len; rdsdebug("conn %p id %p handling event %u (%s)\n", conn, cm_id, event->event, rdma_event_msg(event->event)); @@ -81,6 +83,7 @@ static int rds_rdma_cm_event_handler_cmn(struct rdma_cm_id *cm_id, break; case RDMA_CM_EVENT_ADDR_RESOLVED: + rdma_set_service_type(cm_id, conn->c_tos); /* XXX do we need to clean up if this fails? */ ret = rdma_resolve_route(cm_id, RDS_RDMA_RESOLVE_TIMEOUT_MS); @@ -106,8 +109,19 @@ static int rds_rdma_cm_event_handler_cmn(struct rdma_cm_id *cm_id, break; case RDMA_CM_EVENT_REJECTED: + if (!conn) + break; + err = (int *)rdma_consumer_reject_data(cm_id, event, &len); + if (!err || (err && ((*err) == RDS_RDMA_REJ_INCOMPAT))) { + pr_warn("RDS/RDMA: conn <%pI6c, %pI6c> rejected, dropping connection\n", + &conn->c_laddr, &conn->c_faddr); + conn->c_proposed_version = RDS_PROTOCOL_COMPAT_VERSION; + conn->c_tos = 0; + rds_conn_drop(conn); + } rdsdebug("Connection rejected: %s\n", rdma_reject_msg(cm_id, event->status)); + break; /* FALLTHROUGH */ case RDMA_CM_EVENT_ADDR_ERROR: case RDMA_CM_EVENT_ROUTE_ERROR: diff --git a/net/rds/rdma_transport.h b/net/rds/rdma_transport.h index 200d3134aaae..bfafd4a6d827 100644 --- a/net/rds/rdma_transport.h +++ b/net/rds/rdma_transport.h @@ -11,6 +11,12 @@ #define RDS_RDMA_RESOLVE_TIMEOUT_MS 5000 +/* Below reject reason is for legacy interoperability issue with non-linux + * RDS endpoints where older version incompatibility is conveyed via value 1. + * For future version(s), proper encoded reject reason should be be used. + */ +#define RDS_RDMA_REJ_INCOMPAT 1 + int rds_rdma_conn_connect(struct rds_connection *conn); int rds_rdma_cm_event_handler(struct rdma_cm_id *cm_id, struct rdma_cm_event *event); diff --git a/net/rds/rds.h b/net/rds/rds.h index 4ffe100ff5e6..0d8f67cadd74 100644 --- a/net/rds/rds.h +++ b/net/rds/rds.h @@ -19,10 +19,13 @@ */ #define RDS_PROTOCOL_3_0 0x0300 #define RDS_PROTOCOL_3_1 0x0301 +#define RDS_PROTOCOL_4_0 0x0400 +#define RDS_PROTOCOL_4_1 0x0401 #define RDS_PROTOCOL_VERSION RDS_PROTOCOL_3_1 #define RDS_PROTOCOL_MAJOR(v) ((v) >> 8) #define RDS_PROTOCOL_MINOR(v) ((v) & 255) #define RDS_PROTOCOL(maj, min) (((maj) << 8) | min) +#define RDS_PROTOCOL_COMPAT_VERSION RDS_PROTOCOL_3_1 /* The following ports, 16385, 18634, 18635, are registered with IANA as * the ports to be used for RDS over TCP and UDP. Currently, only RDS over @@ -151,9 +154,13 @@ struct rds_connection { struct rds_cong_map *c_fcong; /* Protocol version */ + unsigned int c_proposed_version; unsigned int c_version; possible_net_t c_net; + /* TOS */ + u8 c_tos; + struct list_head c_map_item; unsigned long c_map_queued; @@ -567,6 +574,7 @@ struct rds_transport { void (*free_mr)(void *trans_private, int invalidate); void (*flush_mrs)(void); bool (*t_unloading)(struct rds_connection *conn); + u8 (*get_tos_map)(u8 tos); }; /* Bind hash table key length. It is the sum of the size of a struct @@ -648,6 +656,7 @@ struct rds_sock { u8 rs_rx_traces; u8 rs_rx_trace[RDS_MSG_RX_DGRAM_TRACE_MAX]; struct rds_msg_zcopy_queue rs_zcookie_queue; + u8 rs_tos; }; static inline struct rds_sock *rds_sk_to_rs(const struct sock *sk) @@ -756,13 +765,14 @@ void rds_conn_exit(void); struct rds_connection *rds_conn_create(struct net *net, const struct in6_addr *laddr, const struct in6_addr *faddr, - struct rds_transport *trans, gfp_t gfp, + struct rds_transport *trans, + u8 tos, gfp_t gfp, int dev_if); struct rds_connection *rds_conn_create_outgoing(struct net *net, const struct in6_addr *laddr, const struct in6_addr *faddr, struct rds_transport *trans, - gfp_t gfp, int dev_if); + u8 tos, gfp_t gfp, int dev_if); void rds_conn_shutdown(struct rds_conn_path *cpath); void rds_conn_destroy(struct rds_connection *conn); void rds_conn_drop(struct rds_connection *conn); diff --git a/net/rds/recv.c b/net/rds/recv.c index 727639dac8a7..853de4876088 100644 --- a/net/rds/recv.c +++ b/net/rds/recv.c @@ -549,9 +549,21 @@ static int rds_cmsg_recv(struct rds_incoming *inc, struct msghdr *msg, if ((inc->i_rx_tstamp != 0) && sock_flag(rds_rs_to_sk(rs), SOCK_RCVTSTAMP)) { - struct timeval tv = ktime_to_timeval(inc->i_rx_tstamp); - ret = put_cmsg(msg, SOL_SOCKET, SCM_TIMESTAMP, - sizeof(tv), &tv); + struct __kernel_old_timeval tv = ns_to_kernel_old_timeval(inc->i_rx_tstamp); + + if (!sock_flag(rds_rs_to_sk(rs), SOCK_TSTAMP_NEW)) { + ret = put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMP_OLD, + sizeof(tv), &tv); + } else { + struct __kernel_sock_timeval sk_tv; + + sk_tv.tv_sec = tv.tv_sec; + sk_tv.tv_usec = tv.tv_usec; + + ret = put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMP_NEW, + sizeof(sk_tv), &sk_tv); + } + if (ret) goto out; } @@ -770,6 +782,7 @@ void rds_inc_info_copy(struct rds_incoming *inc, minfo.seq = be64_to_cpu(inc->i_hdr.h_sequence); minfo.len = be32_to_cpu(inc->i_hdr.h_len); + minfo.tos = inc->i_conn->c_tos; if (flip) { minfo.laddr = daddr; diff --git a/net/rds/send.c b/net/rds/send.c index fd8b687d5c05..166dd578c1cc 100644 --- a/net/rds/send.c +++ b/net/rds/send.c @@ -1277,12 +1277,13 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len) /* rds_conn_create has a spinlock that runs with IRQ off. * Caching the conn in the socket helps a lot. */ - if (rs->rs_conn && ipv6_addr_equal(&rs->rs_conn->c_faddr, &daddr)) + if (rs->rs_conn && ipv6_addr_equal(&rs->rs_conn->c_faddr, &daddr) && + rs->rs_tos == rs->rs_conn->c_tos) { conn = rs->rs_conn; - else { + } else { conn = rds_conn_create_outgoing(sock_net(sock->sk), &rs->rs_bound_addr, &daddr, - rs->rs_transport, + rs->rs_transport, rs->rs_tos, sock->sk->sk_allocation, scope_id); if (IS_ERR(conn)) { diff --git a/net/rds/tcp.c b/net/rds/tcp.c index c16f0a362c32..fd2694174607 100644 --- a/net/rds/tcp.c +++ b/net/rds/tcp.c @@ -267,6 +267,7 @@ static void rds_tcp_tc_info(struct socket *rds_sock, unsigned int len, tsinfo.last_sent_nxt = tc->t_last_sent_nxt; tsinfo.last_expected_una = tc->t_last_expected_una; tsinfo.last_seen_una = tc->t_last_seen_una; + tsinfo.tos = tc->t_cpath->cp_conn->c_tos; rds_info_copy(iter, &tsinfo, sizeof(tsinfo)); } @@ -452,6 +453,12 @@ static void rds_tcp_destroy_conns(void) static void rds_tcp_exit(void); +static u8 rds_tcp_get_tos_map(u8 tos) +{ + /* all user tos mapped to default 0 for TCP transport */ + return 0; +} + struct rds_transport rds_tcp_transport = { .laddr_check = rds_tcp_laddr_check, .xmit_path_prepare = rds_tcp_xmit_path_prepare, @@ -466,6 +473,7 @@ struct rds_transport rds_tcp_transport = { .inc_free = rds_tcp_inc_free, .stats_info_copy = rds_tcp_stats_info_copy, .exit = rds_tcp_exit, + .get_tos_map = rds_tcp_get_tos_map, .t_owner = THIS_MODULE, .t_name = "tcp", .t_type = RDS_TRANS_TCP, diff --git a/net/rds/tcp_listen.c b/net/rds/tcp_listen.c index c12203f646da..810a3a49e947 100644 --- a/net/rds/tcp_listen.c +++ b/net/rds/tcp_listen.c @@ -200,7 +200,7 @@ int rds_tcp_accept_one(struct socket *sock) conn = rds_conn_create(sock_net(sock->sk), my_addr, peer_addr, - &rds_tcp_transport, GFP_KERNEL, dev_if); + &rds_tcp_transport, 0, GFP_KERNEL, dev_if); if (IS_ERR(conn)) { ret = PTR_ERR(conn); diff --git a/net/rds/threads.c b/net/rds/threads.c index e64f9e4c3cda..32dc50f0a303 100644 --- a/net/rds/threads.c +++ b/net/rds/threads.c @@ -93,6 +93,7 @@ void rds_connect_path_complete(struct rds_conn_path *cp, int curr) queue_delayed_work(rds_wq, &cp->cp_recv_w, 0); } rcu_read_unlock(); + cp->cp_conn->c_proposed_version = RDS_PROTOCOL_VERSION; } EXPORT_SYMBOL_GPL(rds_connect_path_complete); diff --git a/net/rose/af_rose.c b/net/rose/af_rose.c index d00a0ef39a56..c96f63ffe31e 100644 --- a/net/rose/af_rose.c +++ b/net/rose/af_rose.c @@ -689,8 +689,10 @@ static int rose_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) rose->source_call = user->call; ax25_uid_put(user); } else { - if (ax25_uid_policy && !capable(CAP_NET_BIND_SERVICE)) + if (ax25_uid_policy && !capable(CAP_NET_BIND_SERVICE)) { + dev_put(dev); return -EACCES; + } rose->source_call = *source; } diff --git a/net/rose/rose_route.c b/net/rose/rose_route.c index 77e9f85a2c92..f2ff21d7df08 100644 --- a/net/rose/rose_route.c +++ b/net/rose/rose_route.c @@ -850,6 +850,7 @@ void rose_link_device_down(struct net_device *dev) /* * Route a frame to an appropriate AX.25 connection. + * A NULL ax25_cb indicates an internally generated frame. */ int rose_route_frame(struct sk_buff *skb, ax25_cb *ax25) { @@ -867,6 +868,10 @@ int rose_route_frame(struct sk_buff *skb, ax25_cb *ax25) if (skb->len < ROSE_MIN_LEN) return res; + + if (!ax25) + return rose_loopback_queue(skb, NULL); + frametype = skb->data[2]; lci = ((skb->data[0] << 8) & 0xF00) + ((skb->data[1] << 0) & 0x0FF); if (frametype == ROSE_CALL_REQUEST && diff --git a/net/rxrpc/conn_client.c b/net/rxrpc/conn_client.c index b2adfa825363..83797b3949e2 100644 --- a/net/rxrpc/conn_client.c +++ b/net/rxrpc/conn_client.c @@ -353,7 +353,7 @@ static int rxrpc_get_client_conn(struct rxrpc_sock *rx, * normally have to take channel_lock but we do this before anyone else * can see the connection. */ - list_add_tail(&call->chan_wait_link, &candidate->waiting_calls); + list_add(&call->chan_wait_link, &candidate->waiting_calls); if (cp->exclusive) { call->conn = candidate; @@ -432,7 +432,7 @@ found_extant_conn: call->conn = conn; call->security_ix = conn->security_ix; call->service_id = conn->service_id; - list_add(&call->chan_wait_link, &conn->waiting_calls); + list_add_tail(&call->chan_wait_link, &conn->waiting_calls); spin_unlock(&conn->channel_lock); _leave(" = 0 [extant %d]", conn->debug_id); return 0; @@ -704,6 +704,7 @@ int rxrpc_connect_call(struct rxrpc_sock *rx, ret = rxrpc_wait_for_channel(call, gfp); if (ret < 0) { + trace_rxrpc_client(call->conn, ret, rxrpc_client_chan_wait_failed); rxrpc_disconnect_client_call(call); goto out; } @@ -774,16 +775,22 @@ static void rxrpc_set_client_reap_timer(struct rxrpc_net *rxnet) */ void rxrpc_disconnect_client_call(struct rxrpc_call *call) { - unsigned int channel = call->cid & RXRPC_CHANNELMASK; struct rxrpc_connection *conn = call->conn; - struct rxrpc_channel *chan = &conn->channels[channel]; + struct rxrpc_channel *chan = NULL; struct rxrpc_net *rxnet = conn->params.local->rxnet; + unsigned int channel = -1; + u32 cid; + spin_lock(&conn->channel_lock); + + cid = call->cid; + if (cid) { + channel = cid & RXRPC_CHANNELMASK; + chan = &conn->channels[channel]; + } trace_rxrpc_client(conn, channel, rxrpc_client_chan_disconnect); call->conn = NULL; - spin_lock(&conn->channel_lock); - /* Calls that have never actually been assigned a channel can simply be * discarded. If the conn didn't get used either, it will follow * immediately unless someone else grabs it in the meantime. @@ -807,7 +814,10 @@ void rxrpc_disconnect_client_call(struct rxrpc_call *call) goto out; } - ASSERTCMP(rcu_access_pointer(chan->call), ==, call); + if (rcu_access_pointer(chan->call) != call) { + spin_unlock(&conn->channel_lock); + BUG(); + } /* If a client call was exposed to the world, we save the result for * retransmission. diff --git a/net/rxrpc/local_object.c b/net/rxrpc/local_object.c index 0906e51d3cfb..15cf42d5b53a 100644 --- a/net/rxrpc/local_object.c +++ b/net/rxrpc/local_object.c @@ -202,7 +202,7 @@ static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net) /* We want receive timestamps. */ opt = 1; - ret = kernel_setsockopt(local->socket, SOL_SOCKET, SO_TIMESTAMPNS, + ret = kernel_setsockopt(local->socket, SOL_SOCKET, SO_TIMESTAMPNS_OLD, (char *)&opt, sizeof(opt)); if (ret < 0) { _debug("setsockopt failed"); diff --git a/net/rxrpc/recvmsg.c b/net/rxrpc/recvmsg.c index eaf19ebaa964..3f7bb11f3290 100644 --- a/net/rxrpc/recvmsg.c +++ b/net/rxrpc/recvmsg.c @@ -596,6 +596,7 @@ error_requeue_call: } error_no_call: release_sock(&rx->sk); +error_trace: trace_rxrpc_recvmsg(call, rxrpc_recvmsg_return, 0, 0, 0, ret); return ret; @@ -604,7 +605,7 @@ wait_interrupted: wait_error: finish_wait(sk_sleep(&rx->sk), &wait); call = NULL; - goto error_no_call; + goto error_trace; } /** diff --git a/net/sched/act_api.c b/net/sched/act_api.c index d4b8355737d8..aecf1bf233c8 100644 --- a/net/sched/act_api.c +++ b/net/sched/act_api.c @@ -543,7 +543,7 @@ int tcf_register_action(struct tc_action_ops *act, write_lock(&act_mod_lock); list_for_each_entry(a, &act_base, head) { - if (act->type == a->type || (strcmp(act->kind, a->kind) == 0)) { + if (act->id == a->id || (strcmp(act->kind, a->kind) == 0)) { write_unlock(&act_mod_lock); unregister_pernet_subsys(ops); return -EEXIST; diff --git a/net/sched/act_bpf.c b/net/sched/act_bpf.c index c7633843e223..aa5c38d11a30 100644 --- a/net/sched/act_bpf.c +++ b/net/sched/act_bpf.c @@ -396,7 +396,7 @@ static int tcf_bpf_search(struct net *net, struct tc_action **a, u32 index) static struct tc_action_ops act_bpf_ops __read_mostly = { .kind = "bpf", - .type = TCA_ACT_BPF, + .id = TCA_ID_BPF, .owner = THIS_MODULE, .act = tcf_bpf_act, .dump = tcf_bpf_dump, diff --git a/net/sched/act_connmark.c b/net/sched/act_connmark.c index 8475913f2070..5d24993cccfe 100644 --- a/net/sched/act_connmark.c +++ b/net/sched/act_connmark.c @@ -204,7 +204,7 @@ static int tcf_connmark_search(struct net *net, struct tc_action **a, u32 index) static struct tc_action_ops act_connmark_ops = { .kind = "connmark", - .type = TCA_ACT_CONNMARK, + .id = TCA_ID_CONNMARK, .owner = THIS_MODULE, .act = tcf_connmark_act, .dump = tcf_connmark_dump, diff --git a/net/sched/act_csum.c b/net/sched/act_csum.c index 3dc25b7806d7..c79aca29505e 100644 --- a/net/sched/act_csum.c +++ b/net/sched/act_csum.c @@ -559,8 +559,11 @@ static int tcf_csum_act(struct sk_buff *skb, const struct tc_action *a, struct tcf_result *res) { struct tcf_csum *p = to_tcf_csum(a); + bool orig_vlan_tag_present = false; + unsigned int vlan_hdr_count = 0; struct tcf_csum_params *params; u32 update_flags; + __be16 protocol; int action; params = rcu_dereference_bh(p->params); @@ -573,7 +576,9 @@ static int tcf_csum_act(struct sk_buff *skb, const struct tc_action *a, goto drop; update_flags = params->update_flags; - switch (tc_skb_protocol(skb)) { + protocol = tc_skb_protocol(skb); +again: + switch (protocol) { case cpu_to_be16(ETH_P_IP): if (!tcf_csum_ipv4(skb, update_flags)) goto drop; @@ -582,13 +587,35 @@ static int tcf_csum_act(struct sk_buff *skb, const struct tc_action *a, if (!tcf_csum_ipv6(skb, update_flags)) goto drop; break; + case cpu_to_be16(ETH_P_8021AD): /* fall through */ + case cpu_to_be16(ETH_P_8021Q): + if (skb_vlan_tag_present(skb) && !orig_vlan_tag_present) { + protocol = skb->protocol; + orig_vlan_tag_present = true; + } else { + struct vlan_hdr *vlan = (struct vlan_hdr *)skb->data; + + protocol = vlan->h_vlan_encapsulated_proto; + skb_pull(skb, VLAN_HLEN); + skb_reset_network_header(skb); + vlan_hdr_count++; + } + goto again; + } + +out: + /* Restore the skb for the pulled VLAN tags */ + while (vlan_hdr_count--) { + skb_push(skb, VLAN_HLEN); + skb_reset_network_header(skb); } return action; drop: qstats_drop_inc(this_cpu_ptr(p->common.cpu_qstats)); - return TC_ACT_SHOT; + action = TC_ACT_SHOT; + goto out; } static int tcf_csum_dump(struct sk_buff *skb, struct tc_action *a, int bind, @@ -660,7 +687,7 @@ static size_t tcf_csum_get_fill_size(const struct tc_action *act) static struct tc_action_ops act_csum_ops = { .kind = "csum", - .type = TCA_ACT_CSUM, + .id = TCA_ID_CSUM, .owner = THIS_MODULE, .act = tcf_csum_act, .dump = tcf_csum_dump, diff --git a/net/sched/act_gact.c b/net/sched/act_gact.c index b61c20ebb314..93da0004e9f4 100644 --- a/net/sched/act_gact.c +++ b/net/sched/act_gact.c @@ -253,7 +253,7 @@ static size_t tcf_gact_get_fill_size(const struct tc_action *act) static struct tc_action_ops act_gact_ops = { .kind = "gact", - .type = TCA_ACT_GACT, + .id = TCA_ID_GACT, .owner = THIS_MODULE, .act = tcf_gact_act, .stats_update = tcf_gact_stats_update, diff --git a/net/sched/act_ife.c b/net/sched/act_ife.c index 30b63fa23ee2..9b1f2b3990ee 100644 --- a/net/sched/act_ife.c +++ b/net/sched/act_ife.c @@ -864,7 +864,7 @@ static int tcf_ife_search(struct net *net, struct tc_action **a, u32 index) static struct tc_action_ops act_ife_ops = { .kind = "ife", - .type = TCA_ACT_IFE, + .id = TCA_ID_IFE, .owner = THIS_MODULE, .act = tcf_ife_act, .dump = tcf_ife_dump, diff --git a/net/sched/act_ipt.c b/net/sched/act_ipt.c index 8af6c11d2482..98f5b6ea77b4 100644 --- a/net/sched/act_ipt.c +++ b/net/sched/act_ipt.c @@ -199,8 +199,7 @@ err3: err2: kfree(tname); err1: - if (ret == ACT_P_CREATED) - tcf_idr_release(*a, bind); + tcf_idr_release(*a, bind); return err; } @@ -338,7 +337,7 @@ static int tcf_ipt_search(struct net *net, struct tc_action **a, u32 index) static struct tc_action_ops act_ipt_ops = { .kind = "ipt", - .type = TCA_ACT_IPT, + .id = TCA_ID_IPT, .owner = THIS_MODULE, .act = tcf_ipt_act, .dump = tcf_ipt_dump, @@ -387,7 +386,7 @@ static int tcf_xt_search(struct net *net, struct tc_action **a, u32 index) static struct tc_action_ops act_xt_ops = { .kind = "xt", - .type = TCA_ACT_XT, + .id = TCA_ID_XT, .owner = THIS_MODULE, .act = tcf_ipt_act, .dump = tcf_ipt_dump, diff --git a/net/sched/act_mirred.c b/net/sched/act_mirred.c index c8cf4d10c435..6692fd054617 100644 --- a/net/sched/act_mirred.c +++ b/net/sched/act_mirred.c @@ -400,7 +400,7 @@ static void tcf_mirred_put_dev(struct net_device *dev) static struct tc_action_ops act_mirred_ops = { .kind = "mirred", - .type = TCA_ACT_MIRRED, + .id = TCA_ID_MIRRED, .owner = THIS_MODULE, .act = tcf_mirred_act, .stats_update = tcf_stats_update, diff --git a/net/sched/act_nat.c b/net/sched/act_nat.c index c5c1e23add77..543eab9193f1 100644 --- a/net/sched/act_nat.c +++ b/net/sched/act_nat.c @@ -304,7 +304,7 @@ static int tcf_nat_search(struct net *net, struct tc_action **a, u32 index) static struct tc_action_ops act_nat_ops = { .kind = "nat", - .type = TCA_ACT_NAT, + .id = TCA_ID_NAT, .owner = THIS_MODULE, .act = tcf_nat_act, .dump = tcf_nat_dump, diff --git a/net/sched/act_pedit.c b/net/sched/act_pedit.c index 2b372a06b432..a80373878df7 100644 --- a/net/sched/act_pedit.c +++ b/net/sched/act_pedit.c @@ -406,7 +406,7 @@ static int tcf_pedit_dump(struct sk_buff *skb, struct tc_action *a, struct tcf_t t; int s; - s = sizeof(*opt) + p->tcfp_nkeys * sizeof(struct tc_pedit_key); + s = struct_size(opt, keys, p->tcfp_nkeys); /* netlink spinlocks held above us - must use ATOMIC */ opt = kzalloc(s, GFP_ATOMIC); @@ -470,7 +470,7 @@ static int tcf_pedit_search(struct net *net, struct tc_action **a, u32 index) static struct tc_action_ops act_pedit_ops = { .kind = "pedit", - .type = TCA_ACT_PEDIT, + .id = TCA_ID_PEDIT, .owner = THIS_MODULE, .act = tcf_pedit_act, .dump = tcf_pedit_dump, diff --git a/net/sched/act_police.c b/net/sched/act_police.c index ec8ec55e0fe8..8271a6263824 100644 --- a/net/sched/act_police.c +++ b/net/sched/act_police.c @@ -366,7 +366,7 @@ MODULE_LICENSE("GPL"); static struct tc_action_ops act_police_ops = { .kind = "police", - .type = TCA_ID_POLICE, + .id = TCA_ID_POLICE, .owner = THIS_MODULE, .act = tcf_police_act, .dump = tcf_police_dump, diff --git a/net/sched/act_sample.c b/net/sched/act_sample.c index 1a0c682fd734..203e399e5c85 100644 --- a/net/sched/act_sample.c +++ b/net/sched/act_sample.c @@ -233,7 +233,7 @@ static int tcf_sample_search(struct net *net, struct tc_action **a, u32 index) static struct tc_action_ops act_sample_ops = { .kind = "sample", - .type = TCA_ACT_SAMPLE, + .id = TCA_ID_SAMPLE, .owner = THIS_MODULE, .act = tcf_sample_act, .dump = tcf_sample_dump, diff --git a/net/sched/act_simple.c b/net/sched/act_simple.c index 902957beceb3..d54cb608dbaf 100644 --- a/net/sched/act_simple.c +++ b/net/sched/act_simple.c @@ -19,8 +19,6 @@ #include <net/netlink.h> #include <net/pkt_sched.h> -#define TCA_ACT_SIMP 22 - #include <linux/tc_act/tc_defact.h> #include <net/tc_act/tc_defact.h> @@ -197,7 +195,7 @@ static int tcf_simp_search(struct net *net, struct tc_action **a, u32 index) static struct tc_action_ops act_simp_ops = { .kind = "simple", - .type = TCA_ACT_SIMP, + .id = TCA_ID_SIMP, .owner = THIS_MODULE, .act = tcf_simp_act, .dump = tcf_simp_dump, diff --git a/net/sched/act_skbedit.c b/net/sched/act_skbedit.c index 64dba3708fce..65879500b688 100644 --- a/net/sched/act_skbedit.c +++ b/net/sched/act_skbedit.c @@ -189,8 +189,7 @@ static int tcf_skbedit_init(struct net *net, struct nlattr *nla, params_new = kzalloc(sizeof(*params_new), GFP_KERNEL); if (unlikely(!params_new)) { - if (ret == ACT_P_CREATED) - tcf_idr_release(*a, bind); + tcf_idr_release(*a, bind); return -ENOMEM; } @@ -305,7 +304,7 @@ static int tcf_skbedit_search(struct net *net, struct tc_action **a, u32 index) static struct tc_action_ops act_skbedit_ops = { .kind = "skbedit", - .type = TCA_ACT_SKBEDIT, + .id = TCA_ID_SKBEDIT, .owner = THIS_MODULE, .act = tcf_skbedit_act, .dump = tcf_skbedit_dump, diff --git a/net/sched/act_skbmod.c b/net/sched/act_skbmod.c index 59710a183bd3..7bac1d78e7a3 100644 --- a/net/sched/act_skbmod.c +++ b/net/sched/act_skbmod.c @@ -260,7 +260,7 @@ static int tcf_skbmod_search(struct net *net, struct tc_action **a, u32 index) static struct tc_action_ops act_skbmod_ops = { .kind = "skbmod", - .type = TCA_ACT_SKBMOD, + .id = TCA_ACT_SKBMOD, .owner = THIS_MODULE, .act = tcf_skbmod_act, .dump = tcf_skbmod_dump, diff --git a/net/sched/act_tunnel_key.c b/net/sched/act_tunnel_key.c index 8b43fe0130f7..7c6591b991d5 100644 --- a/net/sched/act_tunnel_key.c +++ b/net/sched/act_tunnel_key.c @@ -203,6 +203,7 @@ static void tunnel_key_release_params(struct tcf_tunnel_key_params *p) return; if (p->tcft_action == TCA_TUNNEL_KEY_ACT_SET) dst_release(&p->tcft_enc_metadata->dst); + kfree_rcu(p, rcu); } @@ -321,6 +322,12 @@ static int tunnel_key_init(struct net *net, struct nlattr *nla, goto err_out; } +#ifdef CONFIG_DST_CACHE + ret = dst_cache_init(&metadata->u.tun_info.dst_cache, GFP_KERNEL); + if (ret) + goto release_tun_meta; +#endif + if (opts_len) { ret = tunnel_key_opts_set(tb[TCA_TUNNEL_KEY_ENC_OPTS], &metadata->u.tun_info, @@ -377,7 +384,8 @@ static int tunnel_key_init(struct net *net, struct nlattr *nla, return ret; release_tun_meta: - dst_release(&metadata->dst); + if (metadata) + dst_release(&metadata->dst); err_out: if (exists) @@ -563,7 +571,7 @@ static int tunnel_key_search(struct net *net, struct tc_action **a, u32 index) static struct tc_action_ops act_tunnel_key_ops = { .kind = "tunnel_key", - .type = TCA_ACT_TUNNEL_KEY, + .id = TCA_ID_TUNNEL_KEY, .owner = THIS_MODULE, .act = tunnel_key_act, .dump = tunnel_key_dump, diff --git a/net/sched/act_vlan.c b/net/sched/act_vlan.c index 93fdaf707313..ac0061599225 100644 --- a/net/sched/act_vlan.c +++ b/net/sched/act_vlan.c @@ -297,7 +297,7 @@ static int tcf_vlan_search(struct net *net, struct tc_action **a, u32 index) static struct tc_action_ops act_vlan_ops = { .kind = "vlan", - .type = TCA_ACT_VLAN, + .id = TCA_ID_VLAN, .owner = THIS_MODULE, .act = tcf_vlan_act, .dump = tcf_vlan_dump, diff --git a/net/sched/cls_api.c b/net/sched/cls_api.c index e2b5cb2eb34e..dc10525e90e7 100644 --- a/net/sched/cls_api.c +++ b/net/sched/cls_api.c @@ -31,6 +31,13 @@ #include <net/netlink.h> #include <net/pkt_sched.h> #include <net/pkt_cls.h> +#include <net/tc_act/tc_pedit.h> +#include <net/tc_act/tc_mirred.h> +#include <net/tc_act/tc_vlan.h> +#include <net/tc_act/tc_tunnel_key.h> +#include <net/tc_act/tc_csum.h> +#include <net/tc_act/tc_gact.h> +#include <net/tc_act/tc_skbedit.h> extern const struct nla_policy rtm_tca_policy[TCA_MAX + 1]; @@ -61,7 +68,8 @@ static const struct tcf_proto_ops *__tcf_proto_lookup_ops(const char *kind) } static const struct tcf_proto_ops * -tcf_proto_lookup_ops(const char *kind, struct netlink_ext_ack *extack) +tcf_proto_lookup_ops(const char *kind, bool rtnl_held, + struct netlink_ext_ack *extack) { const struct tcf_proto_ops *ops; @@ -69,9 +77,11 @@ tcf_proto_lookup_ops(const char *kind, struct netlink_ext_ack *extack) if (ops) return ops; #ifdef CONFIG_MODULES - rtnl_unlock(); + if (rtnl_held) + rtnl_unlock(); request_module("cls_%s", kind); - rtnl_lock(); + if (rtnl_held) + rtnl_lock(); ops = __tcf_proto_lookup_ops(kind); /* We dropped the RTNL semaphore in order to perform * the module load. So, even if we succeeded in loading @@ -152,8 +162,26 @@ static inline u32 tcf_auto_prio(struct tcf_proto *tp) return TC_H_MAJ(first); } +static bool tcf_proto_is_unlocked(const char *kind) +{ + const struct tcf_proto_ops *ops; + bool ret; + + ops = tcf_proto_lookup_ops(kind, false, NULL); + /* On error return false to take rtnl lock. Proto lookup/create + * functions will perform lookup again and properly handle errors. + */ + if (IS_ERR(ops)) + return false; + + ret = !!(ops->flags & TCF_PROTO_OPS_DOIT_UNLOCKED); + module_put(ops->owner); + return ret; +} + static struct tcf_proto *tcf_proto_create(const char *kind, u32 protocol, u32 prio, struct tcf_chain *chain, + bool rtnl_held, struct netlink_ext_ack *extack) { struct tcf_proto *tp; @@ -163,7 +191,7 @@ static struct tcf_proto *tcf_proto_create(const char *kind, u32 protocol, if (!tp) return ERR_PTR(-ENOBUFS); - tp->ops = tcf_proto_lookup_ops(kind, extack); + tp->ops = tcf_proto_lookup_ops(kind, rtnl_held, extack); if (IS_ERR(tp->ops)) { err = PTR_ERR(tp->ops); goto errout; @@ -172,6 +200,8 @@ static struct tcf_proto *tcf_proto_create(const char *kind, u32 protocol, tp->protocol = protocol; tp->prio = prio; tp->chain = chain; + spin_lock_init(&tp->lock); + refcount_set(&tp->refcnt, 1); err = tp->ops->init(tp); if (err) { @@ -185,14 +215,80 @@ errout: return ERR_PTR(err); } -static void tcf_proto_destroy(struct tcf_proto *tp, +static void tcf_proto_get(struct tcf_proto *tp) +{ + refcount_inc(&tp->refcnt); +} + +static void tcf_chain_put(struct tcf_chain *chain); + +static void tcf_proto_destroy(struct tcf_proto *tp, bool rtnl_held, struct netlink_ext_ack *extack) { - tp->ops->destroy(tp, extack); + tp->ops->destroy(tp, rtnl_held, extack); + tcf_chain_put(tp->chain); module_put(tp->ops->owner); kfree_rcu(tp, rcu); } +static void tcf_proto_put(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + if (refcount_dec_and_test(&tp->refcnt)) + tcf_proto_destroy(tp, rtnl_held, extack); +} + +static int walker_check_empty(struct tcf_proto *tp, void *fh, + struct tcf_walker *arg) +{ + if (fh) { + arg->nonempty = true; + return -1; + } + return 0; +} + +static bool tcf_proto_is_empty(struct tcf_proto *tp, bool rtnl_held) +{ + struct tcf_walker walker = { .fn = walker_check_empty, }; + + if (tp->ops->walk) { + tp->ops->walk(tp, &walker, rtnl_held); + return !walker.nonempty; + } + return true; +} + +static bool tcf_proto_check_delete(struct tcf_proto *tp, bool rtnl_held) +{ + spin_lock(&tp->lock); + if (tcf_proto_is_empty(tp, rtnl_held)) + tp->deleting = true; + spin_unlock(&tp->lock); + return tp->deleting; +} + +static void tcf_proto_mark_delete(struct tcf_proto *tp) +{ + spin_lock(&tp->lock); + tp->deleting = true; + spin_unlock(&tp->lock); +} + +static bool tcf_proto_is_deleting(struct tcf_proto *tp) +{ + bool deleting; + + spin_lock(&tp->lock); + deleting = tp->deleting; + spin_unlock(&tp->lock); + + return deleting; +} + +#define ASSERT_BLOCK_LOCKED(block) \ + lockdep_assert_held(&(block)->lock) + struct tcf_filter_chain_list_item { struct list_head list; tcf_chain_head_change_t *chain_head_change; @@ -204,10 +300,13 @@ static struct tcf_chain *tcf_chain_create(struct tcf_block *block, { struct tcf_chain *chain; + ASSERT_BLOCK_LOCKED(block); + chain = kzalloc(sizeof(*chain), GFP_KERNEL); if (!chain) return NULL; list_add_tail(&chain->list, &block->chain_list); + mutex_init(&chain->filter_chain_lock); chain->block = block; chain->index = chain_index; chain->refcnt = 1; @@ -231,29 +330,59 @@ static void tcf_chain0_head_change(struct tcf_chain *chain, if (chain->index) return; + + mutex_lock(&block->lock); list_for_each_entry(item, &block->chain0.filter_chain_list, list) tcf_chain_head_change_item(item, tp_head); + mutex_unlock(&block->lock); } -static void tcf_chain_destroy(struct tcf_chain *chain) +/* Returns true if block can be safely freed. */ + +static bool tcf_chain_detach(struct tcf_chain *chain) { struct tcf_block *block = chain->block; + ASSERT_BLOCK_LOCKED(block); + list_del(&chain->list); if (!chain->index) block->chain0.chain = NULL; + + if (list_empty(&block->chain_list) && + refcount_read(&block->refcnt) == 0) + return true; + + return false; +} + +static void tcf_block_destroy(struct tcf_block *block) +{ + mutex_destroy(&block->lock); + kfree_rcu(block, rcu); +} + +static void tcf_chain_destroy(struct tcf_chain *chain, bool free_block) +{ + struct tcf_block *block = chain->block; + + mutex_destroy(&chain->filter_chain_lock); kfree(chain); - if (list_empty(&block->chain_list) && !refcount_read(&block->refcnt)) - kfree_rcu(block, rcu); + if (free_block) + tcf_block_destroy(block); } static void tcf_chain_hold(struct tcf_chain *chain) { + ASSERT_BLOCK_LOCKED(chain->block); + ++chain->refcnt; } static bool tcf_chain_held_by_acts_only(struct tcf_chain *chain) { + ASSERT_BLOCK_LOCKED(chain->block); + /* In case all the references are action references, this * chain should not be shown to the user. */ @@ -265,6 +394,8 @@ static struct tcf_chain *tcf_chain_lookup(struct tcf_block *block, { struct tcf_chain *chain; + ASSERT_BLOCK_LOCKED(block); + list_for_each_entry(chain, &block->chain_list, list) { if (chain->index == chain_index) return chain; @@ -279,31 +410,40 @@ static struct tcf_chain *__tcf_chain_get(struct tcf_block *block, u32 chain_index, bool create, bool by_act) { - struct tcf_chain *chain = tcf_chain_lookup(block, chain_index); + struct tcf_chain *chain = NULL; + bool is_first_reference; + mutex_lock(&block->lock); + chain = tcf_chain_lookup(block, chain_index); if (chain) { tcf_chain_hold(chain); } else { if (!create) - return NULL; + goto errout; chain = tcf_chain_create(block, chain_index); if (!chain) - return NULL; + goto errout; } if (by_act) ++chain->action_refcnt; + is_first_reference = chain->refcnt - chain->action_refcnt == 1; + mutex_unlock(&block->lock); /* Send notification only in case we got the first * non-action reference. Until then, the chain acts only as * a placeholder for actions pointing to it and user ought * not know about them. */ - if (chain->refcnt - chain->action_refcnt == 1 && !by_act) + if (is_first_reference && !by_act) tc_chain_notify(chain, NULL, 0, NLM_F_CREATE | NLM_F_EXCL, RTM_NEWCHAIN, false); return chain; + +errout: + mutex_unlock(&block->lock); + return chain; } static struct tcf_chain *tcf_chain_get(struct tcf_block *block, u32 chain_index, @@ -318,51 +458,91 @@ struct tcf_chain *tcf_chain_get_by_act(struct tcf_block *block, u32 chain_index) } EXPORT_SYMBOL(tcf_chain_get_by_act); -static void tc_chain_tmplt_del(struct tcf_chain *chain); +static void tc_chain_tmplt_del(const struct tcf_proto_ops *tmplt_ops, + void *tmplt_priv); +static int tc_chain_notify_delete(const struct tcf_proto_ops *tmplt_ops, + void *tmplt_priv, u32 chain_index, + struct tcf_block *block, struct sk_buff *oskb, + u32 seq, u16 flags, bool unicast); -static void __tcf_chain_put(struct tcf_chain *chain, bool by_act) +static void __tcf_chain_put(struct tcf_chain *chain, bool by_act, + bool explicitly_created) { + struct tcf_block *block = chain->block; + const struct tcf_proto_ops *tmplt_ops; + bool free_block = false; + unsigned int refcnt; + void *tmplt_priv; + + mutex_lock(&block->lock); + if (explicitly_created) { + if (!chain->explicitly_created) { + mutex_unlock(&block->lock); + return; + } + chain->explicitly_created = false; + } + if (by_act) chain->action_refcnt--; - chain->refcnt--; + + /* tc_chain_notify_delete can't be called while holding block lock. + * However, when block is unlocked chain can be changed concurrently, so + * save these to temporary variables. + */ + refcnt = --chain->refcnt; + tmplt_ops = chain->tmplt_ops; + tmplt_priv = chain->tmplt_priv; /* The last dropped non-action reference will trigger notification. */ - if (chain->refcnt - chain->action_refcnt == 0 && !by_act) - tc_chain_notify(chain, NULL, 0, 0, RTM_DELCHAIN, false); + if (refcnt - chain->action_refcnt == 0 && !by_act) { + tc_chain_notify_delete(tmplt_ops, tmplt_priv, chain->index, + block, NULL, 0, 0, false); + /* Last reference to chain, no need to lock. */ + chain->flushing = false; + } - if (chain->refcnt == 0) { - tc_chain_tmplt_del(chain); - tcf_chain_destroy(chain); + if (refcnt == 0) + free_block = tcf_chain_detach(chain); + mutex_unlock(&block->lock); + + if (refcnt == 0) { + tc_chain_tmplt_del(tmplt_ops, tmplt_priv); + tcf_chain_destroy(chain, free_block); } } static void tcf_chain_put(struct tcf_chain *chain) { - __tcf_chain_put(chain, false); + __tcf_chain_put(chain, false, false); } void tcf_chain_put_by_act(struct tcf_chain *chain) { - __tcf_chain_put(chain, true); + __tcf_chain_put(chain, true, false); } EXPORT_SYMBOL(tcf_chain_put_by_act); static void tcf_chain_put_explicitly_created(struct tcf_chain *chain) { - if (chain->explicitly_created) - tcf_chain_put(chain); + __tcf_chain_put(chain, false, true); } -static void tcf_chain_flush(struct tcf_chain *chain) +static void tcf_chain_flush(struct tcf_chain *chain, bool rtnl_held) { - struct tcf_proto *tp = rtnl_dereference(chain->filter_chain); + struct tcf_proto *tp, *tp_next; + mutex_lock(&chain->filter_chain_lock); + tp = tcf_chain_dereference(chain->filter_chain, chain); + RCU_INIT_POINTER(chain->filter_chain, NULL); tcf_chain0_head_change(chain, NULL); + chain->flushing = true; + mutex_unlock(&chain->filter_chain_lock); + while (tp) { - RCU_INIT_POINTER(chain->filter_chain, tp->next); - tcf_proto_destroy(tp, NULL); - tp = rtnl_dereference(chain->filter_chain); - tcf_chain_put(chain); + tp_next = rcu_dereference_protected(tp->next, 1); + tcf_proto_put(tp, rtnl_held, NULL); + tp = tp_next; } } @@ -684,8 +864,8 @@ tcf_chain0_head_change_cb_add(struct tcf_block *block, struct tcf_block_ext_info *ei, struct netlink_ext_ack *extack) { - struct tcf_chain *chain0 = block->chain0.chain; struct tcf_filter_chain_list_item *item; + struct tcf_chain *chain0; item = kmalloc(sizeof(*item), GFP_KERNEL); if (!item) { @@ -694,9 +874,32 @@ tcf_chain0_head_change_cb_add(struct tcf_block *block, } item->chain_head_change = ei->chain_head_change; item->chain_head_change_priv = ei->chain_head_change_priv; - if (chain0 && chain0->filter_chain) - tcf_chain_head_change_item(item, chain0->filter_chain); - list_add(&item->list, &block->chain0.filter_chain_list); + + mutex_lock(&block->lock); + chain0 = block->chain0.chain; + if (chain0) + tcf_chain_hold(chain0); + else + list_add(&item->list, &block->chain0.filter_chain_list); + mutex_unlock(&block->lock); + + if (chain0) { + struct tcf_proto *tp_head; + + mutex_lock(&chain0->filter_chain_lock); + + tp_head = tcf_chain_dereference(chain0->filter_chain, chain0); + if (tp_head) + tcf_chain_head_change_item(item, tp_head); + + mutex_lock(&block->lock); + list_add(&item->list, &block->chain0.filter_chain_list); + mutex_unlock(&block->lock); + + mutex_unlock(&chain0->filter_chain_lock); + tcf_chain_put(chain0); + } + return 0; } @@ -704,20 +907,23 @@ static void tcf_chain0_head_change_cb_del(struct tcf_block *block, struct tcf_block_ext_info *ei) { - struct tcf_chain *chain0 = block->chain0.chain; struct tcf_filter_chain_list_item *item; + mutex_lock(&block->lock); list_for_each_entry(item, &block->chain0.filter_chain_list, list) { if ((!ei->chain_head_change && !ei->chain_head_change_priv) || (item->chain_head_change == ei->chain_head_change && item->chain_head_change_priv == ei->chain_head_change_priv)) { - if (chain0) + if (block->chain0.chain) tcf_chain_head_change_item(item, NULL); list_del(&item->list); + mutex_unlock(&block->lock); + kfree(item); return; } } + mutex_unlock(&block->lock); WARN_ON(1); } @@ -764,6 +970,7 @@ static struct tcf_block *tcf_block_create(struct net *net, struct Qdisc *q, NL_SET_ERR_MSG(extack, "Memory allocation for block failed"); return ERR_PTR(-ENOMEM); } + mutex_init(&block->lock); INIT_LIST_HEAD(&block->chain_list); INIT_LIST_HEAD(&block->cb_list); INIT_LIST_HEAD(&block->owner_list); @@ -799,157 +1006,241 @@ static struct tcf_block *tcf_block_refcnt_get(struct net *net, u32 block_index) return block; } -static void tcf_block_flush_all_chains(struct tcf_block *block) +static struct tcf_chain * +__tcf_get_next_chain(struct tcf_block *block, struct tcf_chain *chain) { - struct tcf_chain *chain; + mutex_lock(&block->lock); + if (chain) + chain = list_is_last(&chain->list, &block->chain_list) ? + NULL : list_next_entry(chain, list); + else + chain = list_first_entry_or_null(&block->chain_list, + struct tcf_chain, list); - /* Hold a refcnt for all chains, so that they don't disappear - * while we are iterating. - */ - list_for_each_entry(chain, &block->chain_list, list) + /* skip all action-only chains */ + while (chain && tcf_chain_held_by_acts_only(chain)) + chain = list_is_last(&chain->list, &block->chain_list) ? + NULL : list_next_entry(chain, list); + + if (chain) tcf_chain_hold(chain); + mutex_unlock(&block->lock); - list_for_each_entry(chain, &block->chain_list, list) - tcf_chain_flush(chain); + return chain; } -static void tcf_block_put_all_chains(struct tcf_block *block) +/* Function to be used by all clients that want to iterate over all chains on + * block. It properly obtains block->lock and takes reference to chain before + * returning it. Users of this function must be tolerant to concurrent chain + * insertion/deletion or ensure that no concurrent chain modification is + * possible. Note that all netlink dump callbacks cannot guarantee to provide + * consistent dump because rtnl lock is released each time skb is filled with + * data and sent to user-space. + */ + +struct tcf_chain * +tcf_get_next_chain(struct tcf_block *block, struct tcf_chain *chain) { - struct tcf_chain *chain, *tmp; + struct tcf_chain *chain_next = __tcf_get_next_chain(block, chain); - /* At this point, all the chains should have refcnt >= 1. */ - list_for_each_entry_safe(chain, tmp, &block->chain_list, list) { - tcf_chain_put_explicitly_created(chain); + if (chain) tcf_chain_put(chain); - } + + return chain_next; } +EXPORT_SYMBOL(tcf_get_next_chain); -static void __tcf_block_put(struct tcf_block *block, struct Qdisc *q, - struct tcf_block_ext_info *ei) +static struct tcf_proto * +__tcf_get_next_proto(struct tcf_chain *chain, struct tcf_proto *tp) { - if (refcount_dec_and_test(&block->refcnt)) { - /* Flushing/putting all chains will cause the block to be - * deallocated when last chain is freed. However, if chain_list - * is empty, block has to be manually deallocated. After block - * reference counter reached 0, it is no longer possible to - * increment it or add new chains to block. - */ - bool free_block = list_empty(&block->chain_list); + u32 prio = 0; - if (tcf_block_shared(block)) - tcf_block_remove(block, block->net); - if (!free_block) - tcf_block_flush_all_chains(block); + ASSERT_RTNL(); + mutex_lock(&chain->filter_chain_lock); - if (q) - tcf_block_offload_unbind(block, q, ei); + if (!tp) { + tp = tcf_chain_dereference(chain->filter_chain, chain); + } else if (tcf_proto_is_deleting(tp)) { + /* 'deleting' flag is set and chain->filter_chain_lock was + * unlocked, which means next pointer could be invalid. Restart + * search. + */ + prio = tp->prio + 1; + tp = tcf_chain_dereference(chain->filter_chain, chain); - if (free_block) - kfree_rcu(block, rcu); - else - tcf_block_put_all_chains(block); - } else if (q) { - tcf_block_offload_unbind(block, q, ei); + for (; tp; tp = tcf_chain_dereference(tp->next, chain)) + if (!tp->deleting && tp->prio >= prio) + break; + } else { + tp = tcf_chain_dereference(tp->next, chain); } + + if (tp) + tcf_proto_get(tp); + + mutex_unlock(&chain->filter_chain_lock); + + return tp; } -static void tcf_block_refcnt_put(struct tcf_block *block) +/* Function to be used by all clients that want to iterate over all tp's on + * chain. Users of this function must be tolerant to concurrent tp + * insertion/deletion or ensure that no concurrent chain modification is + * possible. Note that all netlink dump callbacks cannot guarantee to provide + * consistent dump because rtnl lock is released each time skb is filled with + * data and sent to user-space. + */ + +struct tcf_proto * +tcf_get_next_proto(struct tcf_chain *chain, struct tcf_proto *tp, + bool rtnl_held) { - __tcf_block_put(block, NULL, NULL); + struct tcf_proto *tp_next = __tcf_get_next_proto(chain, tp); + + if (tp) + tcf_proto_put(tp, rtnl_held, NULL); + + return tp_next; } +EXPORT_SYMBOL(tcf_get_next_proto); -/* Find tcf block. - * Set q, parent, cl when appropriate. +static void tcf_block_flush_all_chains(struct tcf_block *block, bool rtnl_held) +{ + struct tcf_chain *chain; + + /* Last reference to block. At this point chains cannot be added or + * removed concurrently. + */ + for (chain = tcf_get_next_chain(block, NULL); + chain; + chain = tcf_get_next_chain(block, chain)) { + tcf_chain_put_explicitly_created(chain); + tcf_chain_flush(chain, rtnl_held); + } +} + +/* Lookup Qdisc and increments its reference counter. + * Set parent, if necessary. */ -static struct tcf_block *tcf_block_find(struct net *net, struct Qdisc **q, - u32 *parent, unsigned long *cl, - int ifindex, u32 block_index, - struct netlink_ext_ack *extack) +static int __tcf_qdisc_find(struct net *net, struct Qdisc **q, + u32 *parent, int ifindex, bool rtnl_held, + struct netlink_ext_ack *extack) { - struct tcf_block *block; + const struct Qdisc_class_ops *cops; + struct net_device *dev; int err = 0; - if (ifindex == TCM_IFINDEX_MAGIC_BLOCK) { - block = tcf_block_refcnt_get(net, block_index); - if (!block) { - NL_SET_ERR_MSG(extack, "Block of given index was not found"); - return ERR_PTR(-EINVAL); - } - } else { - const struct Qdisc_class_ops *cops; - struct net_device *dev; - - rcu_read_lock(); + if (ifindex == TCM_IFINDEX_MAGIC_BLOCK) + return 0; - /* Find link */ - dev = dev_get_by_index_rcu(net, ifindex); - if (!dev) { - rcu_read_unlock(); - return ERR_PTR(-ENODEV); - } + rcu_read_lock(); - /* Find qdisc */ - if (!*parent) { - *q = dev->qdisc; - *parent = (*q)->handle; - } else { - *q = qdisc_lookup_rcu(dev, TC_H_MAJ(*parent)); - if (!*q) { - NL_SET_ERR_MSG(extack, "Parent Qdisc doesn't exists"); - err = -EINVAL; - goto errout_rcu; - } - } + /* Find link */ + dev = dev_get_by_index_rcu(net, ifindex); + if (!dev) { + rcu_read_unlock(); + return -ENODEV; + } - *q = qdisc_refcount_inc_nz(*q); + /* Find qdisc */ + if (!*parent) { + *q = dev->qdisc; + *parent = (*q)->handle; + } else { + *q = qdisc_lookup_rcu(dev, TC_H_MAJ(*parent)); if (!*q) { NL_SET_ERR_MSG(extack, "Parent Qdisc doesn't exists"); err = -EINVAL; goto errout_rcu; } + } - /* Is it classful? */ - cops = (*q)->ops->cl_ops; - if (!cops) { - NL_SET_ERR_MSG(extack, "Qdisc not classful"); - err = -EINVAL; - goto errout_rcu; - } + *q = qdisc_refcount_inc_nz(*q); + if (!*q) { + NL_SET_ERR_MSG(extack, "Parent Qdisc doesn't exists"); + err = -EINVAL; + goto errout_rcu; + } - if (!cops->tcf_block) { - NL_SET_ERR_MSG(extack, "Class doesn't support blocks"); - err = -EOPNOTSUPP; - goto errout_rcu; - } + /* Is it classful? */ + cops = (*q)->ops->cl_ops; + if (!cops) { + NL_SET_ERR_MSG(extack, "Qdisc not classful"); + err = -EINVAL; + goto errout_qdisc; + } - /* At this point we know that qdisc is not noop_qdisc, - * which means that qdisc holds a reference to net_device - * and we hold a reference to qdisc, so it is safe to release - * rcu read lock. - */ - rcu_read_unlock(); + if (!cops->tcf_block) { + NL_SET_ERR_MSG(extack, "Class doesn't support blocks"); + err = -EOPNOTSUPP; + goto errout_qdisc; + } - /* Do we search for filter, attached to class? */ - if (TC_H_MIN(*parent)) { - *cl = cops->find(*q, *parent); - if (*cl == 0) { - NL_SET_ERR_MSG(extack, "Specified class doesn't exist"); - err = -ENOENT; - goto errout_qdisc; - } +errout_rcu: + /* At this point we know that qdisc is not noop_qdisc, + * which means that qdisc holds a reference to net_device + * and we hold a reference to qdisc, so it is safe to release + * rcu read lock. + */ + rcu_read_unlock(); + return err; + +errout_qdisc: + rcu_read_unlock(); + + if (rtnl_held) + qdisc_put(*q); + else + qdisc_put_unlocked(*q); + *q = NULL; + + return err; +} + +static int __tcf_qdisc_cl_find(struct Qdisc *q, u32 parent, unsigned long *cl, + int ifindex, struct netlink_ext_ack *extack) +{ + if (ifindex == TCM_IFINDEX_MAGIC_BLOCK) + return 0; + + /* Do we search for filter, attached to class? */ + if (TC_H_MIN(parent)) { + const struct Qdisc_class_ops *cops = q->ops->cl_ops; + + *cl = cops->find(q, parent); + if (*cl == 0) { + NL_SET_ERR_MSG(extack, "Specified class doesn't exist"); + return -ENOENT; } + } + + return 0; +} + +static struct tcf_block *__tcf_block_find(struct net *net, struct Qdisc *q, + unsigned long cl, int ifindex, + u32 block_index, + struct netlink_ext_ack *extack) +{ + struct tcf_block *block; - /* And the last stroke */ - block = cops->tcf_block(*q, *cl, extack); + if (ifindex == TCM_IFINDEX_MAGIC_BLOCK) { + block = tcf_block_refcnt_get(net, block_index); if (!block) { - err = -EINVAL; - goto errout_qdisc; + NL_SET_ERR_MSG(extack, "Block of given index was not found"); + return ERR_PTR(-EINVAL); } + } else { + const struct Qdisc_class_ops *cops = q->ops->cl_ops; + + block = cops->tcf_block(q, cl, extack); + if (!block) + return ERR_PTR(-EINVAL); + if (tcf_block_shared(block)) { NL_SET_ERR_MSG(extack, "This filter block is shared. Please use the block index to manipulate the filters"); - err = -EOPNOTSUPP; - goto errout_qdisc; + return ERR_PTR(-EOPNOTSUPP); } /* Always take reference to block in order to support execution @@ -962,24 +1253,91 @@ static struct tcf_block *tcf_block_find(struct net *net, struct Qdisc **q, } return block; +} + +static void __tcf_block_put(struct tcf_block *block, struct Qdisc *q, + struct tcf_block_ext_info *ei, bool rtnl_held) +{ + if (refcount_dec_and_mutex_lock(&block->refcnt, &block->lock)) { + /* Flushing/putting all chains will cause the block to be + * deallocated when last chain is freed. However, if chain_list + * is empty, block has to be manually deallocated. After block + * reference counter reached 0, it is no longer possible to + * increment it or add new chains to block. + */ + bool free_block = list_empty(&block->chain_list); + + mutex_unlock(&block->lock); + if (tcf_block_shared(block)) + tcf_block_remove(block, block->net); + + if (q) + tcf_block_offload_unbind(block, q, ei); + + if (free_block) + tcf_block_destroy(block); + else + tcf_block_flush_all_chains(block, rtnl_held); + } else if (q) { + tcf_block_offload_unbind(block, q, ei); + } +} + +static void tcf_block_refcnt_put(struct tcf_block *block, bool rtnl_held) +{ + __tcf_block_put(block, NULL, NULL, rtnl_held); +} + +/* Find tcf block. + * Set q, parent, cl when appropriate. + */ + +static struct tcf_block *tcf_block_find(struct net *net, struct Qdisc **q, + u32 *parent, unsigned long *cl, + int ifindex, u32 block_index, + struct netlink_ext_ack *extack) +{ + struct tcf_block *block; + int err = 0; + + ASSERT_RTNL(); + + err = __tcf_qdisc_find(net, q, parent, ifindex, true, extack); + if (err) + goto errout; + + err = __tcf_qdisc_cl_find(*q, *parent, cl, ifindex, extack); + if (err) + goto errout_qdisc; + + block = __tcf_block_find(net, *q, *cl, ifindex, block_index, extack); + if (IS_ERR(block)) { + err = PTR_ERR(block); + goto errout_qdisc; + } + + return block; -errout_rcu: - rcu_read_unlock(); errout_qdisc: - if (*q) { + if (*q) qdisc_put(*q); - *q = NULL; - } +errout: + *q = NULL; return ERR_PTR(err); } -static void tcf_block_release(struct Qdisc *q, struct tcf_block *block) +static void tcf_block_release(struct Qdisc *q, struct tcf_block *block, + bool rtnl_held) { if (!IS_ERR_OR_NULL(block)) - tcf_block_refcnt_put(block); + tcf_block_refcnt_put(block, rtnl_held); - if (q) - qdisc_put(q); + if (q) { + if (rtnl_held) + qdisc_put(q); + else + qdisc_put_unlocked(q); + } } struct tcf_block_owner_item { @@ -1087,7 +1445,7 @@ err_chain0_head_change_cb_add: tcf_block_owner_del(block, q, ei->binder_type); err_block_owner_add: err_block_insert: - tcf_block_refcnt_put(block); + tcf_block_refcnt_put(block, true); return err; } EXPORT_SYMBOL(tcf_block_get_ext); @@ -1124,7 +1482,7 @@ void tcf_block_put_ext(struct tcf_block *block, struct Qdisc *q, tcf_chain0_head_change_cb_del(block, ei); tcf_block_owner_del(block, q, ei->binder_type); - __tcf_block_put(block, q, ei); + __tcf_block_put(block, q, ei, true); } EXPORT_SYMBOL(tcf_block_put_ext); @@ -1181,13 +1539,19 @@ tcf_block_playback_offloads(struct tcf_block *block, tc_setup_cb_t *cb, void *cb_priv, bool add, bool offload_in_use, struct netlink_ext_ack *extack) { - struct tcf_chain *chain; - struct tcf_proto *tp; + struct tcf_chain *chain, *chain_prev; + struct tcf_proto *tp, *tp_prev; int err; - list_for_each_entry(chain, &block->chain_list, list) { - for (tp = rtnl_dereference(chain->filter_chain); tp; - tp = rtnl_dereference(tp->next)) { + for (chain = __tcf_get_next_chain(block, NULL); + chain; + chain_prev = chain, + chain = __tcf_get_next_chain(block, chain), + tcf_chain_put(chain_prev)) { + for (tp = __tcf_get_next_proto(chain, NULL); tp; + tp_prev = tp, + tp = __tcf_get_next_proto(chain, tp), + tcf_proto_put(tp_prev, true, NULL)) { if (tp->ops->reoffload) { err = tp->ops->reoffload(tp, add, cb, cb_priv, extack); @@ -1204,6 +1568,8 @@ tcf_block_playback_offloads(struct tcf_block *block, tc_setup_cb_t *cb, return 0; err_playback_remove: + tcf_proto_put(tp, true, NULL); + tcf_chain_put(chain); tcf_block_playback_offloads(block, cb, cb_priv, false, offload_in_use, extack); return err; @@ -1329,32 +1695,116 @@ struct tcf_chain_info { struct tcf_proto __rcu *next; }; -static struct tcf_proto *tcf_chain_tp_prev(struct tcf_chain_info *chain_info) +static struct tcf_proto *tcf_chain_tp_prev(struct tcf_chain *chain, + struct tcf_chain_info *chain_info) { - return rtnl_dereference(*chain_info->pprev); + return tcf_chain_dereference(*chain_info->pprev, chain); } -static void tcf_chain_tp_insert(struct tcf_chain *chain, - struct tcf_chain_info *chain_info, - struct tcf_proto *tp) +static int tcf_chain_tp_insert(struct tcf_chain *chain, + struct tcf_chain_info *chain_info, + struct tcf_proto *tp) { + if (chain->flushing) + return -EAGAIN; + if (*chain_info->pprev == chain->filter_chain) tcf_chain0_head_change(chain, tp); - RCU_INIT_POINTER(tp->next, tcf_chain_tp_prev(chain_info)); + tcf_proto_get(tp); + RCU_INIT_POINTER(tp->next, tcf_chain_tp_prev(chain, chain_info)); rcu_assign_pointer(*chain_info->pprev, tp); - tcf_chain_hold(chain); + + return 0; } static void tcf_chain_tp_remove(struct tcf_chain *chain, struct tcf_chain_info *chain_info, struct tcf_proto *tp) { - struct tcf_proto *next = rtnl_dereference(chain_info->next); + struct tcf_proto *next = tcf_chain_dereference(chain_info->next, chain); + tcf_proto_mark_delete(tp); if (tp == chain->filter_chain) tcf_chain0_head_change(chain, next); RCU_INIT_POINTER(*chain_info->pprev, next); - tcf_chain_put(chain); +} + +static struct tcf_proto *tcf_chain_tp_find(struct tcf_chain *chain, + struct tcf_chain_info *chain_info, + u32 protocol, u32 prio, + bool prio_allocate); + +/* Try to insert new proto. + * If proto with specified priority already exists, free new proto + * and return existing one. + */ + +static struct tcf_proto *tcf_chain_tp_insert_unique(struct tcf_chain *chain, + struct tcf_proto *tp_new, + u32 protocol, u32 prio, + bool rtnl_held) +{ + struct tcf_chain_info chain_info; + struct tcf_proto *tp; + int err = 0; + + mutex_lock(&chain->filter_chain_lock); + + tp = tcf_chain_tp_find(chain, &chain_info, + protocol, prio, false); + if (!tp) + err = tcf_chain_tp_insert(chain, &chain_info, tp_new); + mutex_unlock(&chain->filter_chain_lock); + + if (tp) { + tcf_proto_destroy(tp_new, rtnl_held, NULL); + tp_new = tp; + } else if (err) { + tcf_proto_destroy(tp_new, rtnl_held, NULL); + tp_new = ERR_PTR(err); + } + + return tp_new; +} + +static void tcf_chain_tp_delete_empty(struct tcf_chain *chain, + struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct tcf_chain_info chain_info; + struct tcf_proto *tp_iter; + struct tcf_proto **pprev; + struct tcf_proto *next; + + mutex_lock(&chain->filter_chain_lock); + + /* Atomically find and remove tp from chain. */ + for (pprev = &chain->filter_chain; + (tp_iter = tcf_chain_dereference(*pprev, chain)); + pprev = &tp_iter->next) { + if (tp_iter == tp) { + chain_info.pprev = pprev; + chain_info.next = tp_iter->next; + WARN_ON(tp_iter->deleting); + break; + } + } + /* Verify that tp still exists and no new filters were inserted + * concurrently. + * Mark tp for deletion if it is empty. + */ + if (!tp_iter || !tcf_proto_check_delete(tp, rtnl_held)) { + mutex_unlock(&chain->filter_chain_lock); + return; + } + + next = tcf_chain_dereference(chain_info.next, chain); + if (tp == chain->filter_chain) + tcf_chain0_head_change(chain, next); + RCU_INIT_POINTER(*chain_info.pprev, next); + mutex_unlock(&chain->filter_chain_lock); + + tcf_proto_put(tp, rtnl_held, extack); } static struct tcf_proto *tcf_chain_tp_find(struct tcf_chain *chain, @@ -1367,7 +1817,8 @@ static struct tcf_proto *tcf_chain_tp_find(struct tcf_chain *chain, /* Check the chain for existence of proto-tcf with this priority */ for (pprev = &chain->filter_chain; - (tp = rtnl_dereference(*pprev)); pprev = &tp->next) { + (tp = tcf_chain_dereference(*pprev, chain)); + pprev = &tp->next) { if (tp->prio >= prio) { if (tp->prio == prio) { if (prio_allocate || @@ -1380,14 +1831,20 @@ static struct tcf_proto *tcf_chain_tp_find(struct tcf_chain *chain, } } chain_info->pprev = pprev; - chain_info->next = tp ? tp->next : NULL; + if (tp) { + chain_info->next = tp->next; + tcf_proto_get(tp); + } else { + chain_info->next = NULL; + } return tp; } static int tcf_fill_node(struct net *net, struct sk_buff *skb, struct tcf_proto *tp, struct tcf_block *block, struct Qdisc *q, u32 parent, void *fh, - u32 portid, u32 seq, u16 flags, int event) + u32 portid, u32 seq, u16 flags, int event, + bool rtnl_held) { struct tcmsg *tcm; struct nlmsghdr *nlh; @@ -1415,7 +1872,8 @@ static int tcf_fill_node(struct net *net, struct sk_buff *skb, if (!fh) { tcm->tcm_handle = 0; } else { - if (tp->ops->dump && tp->ops->dump(net, tp, fh, skb, tcm) < 0) + if (tp->ops->dump && + tp->ops->dump(net, tp, fh, skb, tcm, rtnl_held) < 0) goto nla_put_failure; } nlh->nlmsg_len = skb_tail_pointer(skb) - b; @@ -1430,33 +1888,40 @@ nla_put_failure: static int tfilter_notify(struct net *net, struct sk_buff *oskb, struct nlmsghdr *n, struct tcf_proto *tp, struct tcf_block *block, struct Qdisc *q, - u32 parent, void *fh, int event, bool unicast) + u32 parent, void *fh, int event, bool unicast, + bool rtnl_held) { struct sk_buff *skb; u32 portid = oskb ? NETLINK_CB(oskb).portid : 0; + int err = 0; skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); if (!skb) return -ENOBUFS; if (tcf_fill_node(net, skb, tp, block, q, parent, fh, portid, - n->nlmsg_seq, n->nlmsg_flags, event) <= 0) { + n->nlmsg_seq, n->nlmsg_flags, event, + rtnl_held) <= 0) { kfree_skb(skb); return -EINVAL; } if (unicast) - return netlink_unicast(net->rtnl, skb, portid, MSG_DONTWAIT); + err = netlink_unicast(net->rtnl, skb, portid, MSG_DONTWAIT); + else + err = rtnetlink_send(skb, net, portid, RTNLGRP_TC, + n->nlmsg_flags & NLM_F_ECHO); - return rtnetlink_send(skb, net, portid, RTNLGRP_TC, - n->nlmsg_flags & NLM_F_ECHO); + if (err > 0) + err = 0; + return err; } static int tfilter_del_notify(struct net *net, struct sk_buff *oskb, struct nlmsghdr *n, struct tcf_proto *tp, struct tcf_block *block, struct Qdisc *q, u32 parent, void *fh, bool unicast, bool *last, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { struct sk_buff *skb; u32 portid = oskb ? NETLINK_CB(oskb).portid : 0; @@ -1467,39 +1932,50 @@ static int tfilter_del_notify(struct net *net, struct sk_buff *oskb, return -ENOBUFS; if (tcf_fill_node(net, skb, tp, block, q, parent, fh, portid, - n->nlmsg_seq, n->nlmsg_flags, RTM_DELTFILTER) <= 0) { + n->nlmsg_seq, n->nlmsg_flags, RTM_DELTFILTER, + rtnl_held) <= 0) { NL_SET_ERR_MSG(extack, "Failed to build del event notification"); kfree_skb(skb); return -EINVAL; } - err = tp->ops->delete(tp, fh, last, extack); + err = tp->ops->delete(tp, fh, last, rtnl_held, extack); if (err) { kfree_skb(skb); return err; } if (unicast) - return netlink_unicast(net->rtnl, skb, portid, MSG_DONTWAIT); - - err = rtnetlink_send(skb, net, portid, RTNLGRP_TC, - n->nlmsg_flags & NLM_F_ECHO); + err = netlink_unicast(net->rtnl, skb, portid, MSG_DONTWAIT); + else + err = rtnetlink_send(skb, net, portid, RTNLGRP_TC, + n->nlmsg_flags & NLM_F_ECHO); if (err < 0) NL_SET_ERR_MSG(extack, "Failed to send filter delete notification"); + + if (err > 0) + err = 0; return err; } static void tfilter_notify_chain(struct net *net, struct sk_buff *oskb, struct tcf_block *block, struct Qdisc *q, u32 parent, struct nlmsghdr *n, - struct tcf_chain *chain, int event) + struct tcf_chain *chain, int event, + bool rtnl_held) { struct tcf_proto *tp; - for (tp = rtnl_dereference(chain->filter_chain); - tp; tp = rtnl_dereference(tp->next)) + for (tp = tcf_get_next_proto(chain, NULL, rtnl_held); + tp; tp = tcf_get_next_proto(chain, tp, rtnl_held)) tfilter_notify(net, oskb, n, tp, block, - q, parent, NULL, event, false); + q, parent, NULL, event, false, rtnl_held); +} + +static void tfilter_put(struct tcf_proto *tp, void *fh) +{ + if (tp->ops->put && fh) + tp->ops->put(tp, fh); } static int tc_new_tfilter(struct sk_buff *skb, struct nlmsghdr *n, @@ -1522,6 +1998,7 @@ static int tc_new_tfilter(struct sk_buff *skb, struct nlmsghdr *n, void *fh; int err; int tp_created; + bool rtnl_held = false; if (!netlink_ns_capable(skb, net->user_ns, CAP_NET_ADMIN)) return -EPERM; @@ -1538,7 +2015,9 @@ replay: prio = TC_H_MAJ(t->tcm_info); prio_allocate = false; parent = t->tcm_parent; + tp = NULL; cl = 0; + block = NULL; if (prio == 0) { /* If no priority is provided by the user, @@ -1555,8 +2034,27 @@ replay: /* Find head of filter chain. */ - block = tcf_block_find(net, &q, &parent, &cl, - t->tcm_ifindex, t->tcm_block_index, extack); + err = __tcf_qdisc_find(net, &q, &parent, t->tcm_ifindex, false, extack); + if (err) + return err; + + /* Take rtnl mutex if rtnl_held was set to true on previous iteration, + * block is shared (no qdisc found), qdisc is not unlocked, classifier + * type is not specified, classifier is not unlocked. + */ + if (rtnl_held || + (q && !(q->ops->cl_ops->flags & QDISC_CLASS_OPS_DOIT_UNLOCKED)) || + !tca[TCA_KIND] || !tcf_proto_is_unlocked(nla_data(tca[TCA_KIND]))) { + rtnl_held = true; + rtnl_lock(); + } + + err = __tcf_qdisc_cl_find(q, parent, &cl, t->tcm_ifindex, extack); + if (err) + goto errout; + + block = __tcf_block_find(net, q, cl, t->tcm_ifindex, t->tcm_block_index, + extack); if (IS_ERR(block)) { err = PTR_ERR(block); goto errout; @@ -1575,40 +2073,62 @@ replay: goto errout; } + mutex_lock(&chain->filter_chain_lock); tp = tcf_chain_tp_find(chain, &chain_info, protocol, prio, prio_allocate); if (IS_ERR(tp)) { NL_SET_ERR_MSG(extack, "Filter with specified priority/protocol not found"); err = PTR_ERR(tp); - goto errout; + goto errout_locked; } if (tp == NULL) { + struct tcf_proto *tp_new = NULL; + + if (chain->flushing) { + err = -EAGAIN; + goto errout_locked; + } + /* Proto-tcf does not exist, create new one */ if (tca[TCA_KIND] == NULL || !protocol) { NL_SET_ERR_MSG(extack, "Filter kind and protocol must be specified"); err = -EINVAL; - goto errout; + goto errout_locked; } if (!(n->nlmsg_flags & NLM_F_CREATE)) { NL_SET_ERR_MSG(extack, "Need both RTM_NEWTFILTER and NLM_F_CREATE to create a new filter"); err = -ENOENT; - goto errout; + goto errout_locked; } if (prio_allocate) - prio = tcf_auto_prio(tcf_chain_tp_prev(&chain_info)); + prio = tcf_auto_prio(tcf_chain_tp_prev(chain, + &chain_info)); - tp = tcf_proto_create(nla_data(tca[TCA_KIND]), - protocol, prio, chain, extack); + mutex_unlock(&chain->filter_chain_lock); + tp_new = tcf_proto_create(nla_data(tca[TCA_KIND]), + protocol, prio, chain, rtnl_held, + extack); + if (IS_ERR(tp_new)) { + err = PTR_ERR(tp_new); + goto errout_tp; + } + + tp_created = 1; + tp = tcf_chain_tp_insert_unique(chain, tp_new, protocol, prio, + rtnl_held); if (IS_ERR(tp)) { err = PTR_ERR(tp); - goto errout; + goto errout_tp; } - tp_created = 1; - } else if (tca[TCA_KIND] && nla_strcmp(tca[TCA_KIND], tp->ops->kind)) { + } else { + mutex_unlock(&chain->filter_chain_lock); + } + + if (tca[TCA_KIND] && nla_strcmp(tca[TCA_KIND], tp->ops->kind)) { NL_SET_ERR_MSG(extack, "Specified filter kind does not match existing one"); err = -EINVAL; goto errout; @@ -1623,6 +2143,7 @@ replay: goto errout; } } else if (n->nlmsg_flags & NLM_F_EXCL) { + tfilter_put(tp, fh); NL_SET_ERR_MSG(extack, "Filter already exists"); err = -EEXIST; goto errout; @@ -1636,25 +2157,41 @@ replay: err = tp->ops->change(net, skb, tp, cl, t->tcm_handle, tca, &fh, n->nlmsg_flags & NLM_F_CREATE ? TCA_ACT_NOREPLACE : TCA_ACT_REPLACE, - extack); + rtnl_held, extack); if (err == 0) { - if (tp_created) - tcf_chain_tp_insert(chain, &chain_info, tp); tfilter_notify(net, skb, n, tp, block, q, parent, fh, - RTM_NEWTFILTER, false); - } else { - if (tp_created) - tcf_proto_destroy(tp, NULL); + RTM_NEWTFILTER, false, rtnl_held); + tfilter_put(tp, fh); } errout: - if (chain) - tcf_chain_put(chain); - tcf_block_release(q, block); - if (err == -EAGAIN) + if (err && tp_created) + tcf_chain_tp_delete_empty(chain, tp, rtnl_held, NULL); +errout_tp: + if (chain) { + if (tp && !IS_ERR(tp)) + tcf_proto_put(tp, rtnl_held, NULL); + if (!tp_created) + tcf_chain_put(chain); + } + tcf_block_release(q, block, rtnl_held); + + if (rtnl_held) + rtnl_unlock(); + + if (err == -EAGAIN) { + /* Take rtnl lock in case EAGAIN is caused by concurrent flush + * of target chain. + */ + rtnl_held = true; /* Replay the request. */ goto replay; + } return err; + +errout_locked: + mutex_unlock(&chain->filter_chain_lock); + goto errout; } static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n, @@ -1670,11 +2207,12 @@ static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n, struct Qdisc *q = NULL; struct tcf_chain_info chain_info; struct tcf_chain *chain = NULL; - struct tcf_block *block; + struct tcf_block *block = NULL; struct tcf_proto *tp = NULL; unsigned long cl = 0; void *fh = NULL; int err; + bool rtnl_held = false; if (!netlink_ns_capable(skb, net->user_ns, CAP_NET_ADMIN)) return -EPERM; @@ -1695,8 +2233,27 @@ static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n, /* Find head of filter chain. */ - block = tcf_block_find(net, &q, &parent, &cl, - t->tcm_ifindex, t->tcm_block_index, extack); + err = __tcf_qdisc_find(net, &q, &parent, t->tcm_ifindex, false, extack); + if (err) + return err; + + /* Take rtnl mutex if flushing whole chain, block is shared (no qdisc + * found), qdisc is not unlocked, classifier type is not specified, + * classifier is not unlocked. + */ + if (!prio || + (q && !(q->ops->cl_ops->flags & QDISC_CLASS_OPS_DOIT_UNLOCKED)) || + !tca[TCA_KIND] || !tcf_proto_is_unlocked(nla_data(tca[TCA_KIND]))) { + rtnl_held = true; + rtnl_lock(); + } + + err = __tcf_qdisc_cl_find(q, parent, &cl, t->tcm_ifindex, extack); + if (err) + goto errout; + + block = __tcf_block_find(net, q, cl, t->tcm_ifindex, t->tcm_block_index, + extack); if (IS_ERR(block)) { err = PTR_ERR(block); goto errout; @@ -1724,56 +2281,69 @@ static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n, if (prio == 0) { tfilter_notify_chain(net, skb, block, q, parent, n, - chain, RTM_DELTFILTER); - tcf_chain_flush(chain); + chain, RTM_DELTFILTER, rtnl_held); + tcf_chain_flush(chain, rtnl_held); err = 0; goto errout; } + mutex_lock(&chain->filter_chain_lock); tp = tcf_chain_tp_find(chain, &chain_info, protocol, prio, false); if (!tp || IS_ERR(tp)) { NL_SET_ERR_MSG(extack, "Filter with specified priority/protocol not found"); err = tp ? PTR_ERR(tp) : -ENOENT; - goto errout; + goto errout_locked; } else if (tca[TCA_KIND] && nla_strcmp(tca[TCA_KIND], tp->ops->kind)) { NL_SET_ERR_MSG(extack, "Specified filter kind does not match existing one"); err = -EINVAL; + goto errout_locked; + } else if (t->tcm_handle == 0) { + tcf_chain_tp_remove(chain, &chain_info, tp); + mutex_unlock(&chain->filter_chain_lock); + + tcf_proto_put(tp, rtnl_held, NULL); + tfilter_notify(net, skb, n, tp, block, q, parent, fh, + RTM_DELTFILTER, false, rtnl_held); + err = 0; goto errout; } + mutex_unlock(&chain->filter_chain_lock); fh = tp->ops->get(tp, t->tcm_handle); if (!fh) { - if (t->tcm_handle == 0) { - tcf_chain_tp_remove(chain, &chain_info, tp); - tfilter_notify(net, skb, n, tp, block, q, parent, fh, - RTM_DELTFILTER, false); - tcf_proto_destroy(tp, extack); - err = 0; - } else { - NL_SET_ERR_MSG(extack, "Specified filter handle not found"); - err = -ENOENT; - } + NL_SET_ERR_MSG(extack, "Specified filter handle not found"); + err = -ENOENT; } else { bool last; err = tfilter_del_notify(net, skb, n, tp, block, q, parent, fh, false, &last, - extack); + rtnl_held, extack); + if (err) goto errout; - if (last) { - tcf_chain_tp_remove(chain, &chain_info, tp); - tcf_proto_destroy(tp, extack); - } + if (last) + tcf_chain_tp_delete_empty(chain, tp, rtnl_held, extack); } errout: - if (chain) + if (chain) { + if (tp && !IS_ERR(tp)) + tcf_proto_put(tp, rtnl_held, NULL); tcf_chain_put(chain); - tcf_block_release(q, block); + } + tcf_block_release(q, block, rtnl_held); + + if (rtnl_held) + rtnl_unlock(); + return err; + +errout_locked: + mutex_unlock(&chain->filter_chain_lock); + goto errout; } static int tc_get_tfilter(struct sk_buff *skb, struct nlmsghdr *n, @@ -1789,11 +2359,12 @@ static int tc_get_tfilter(struct sk_buff *skb, struct nlmsghdr *n, struct Qdisc *q = NULL; struct tcf_chain_info chain_info; struct tcf_chain *chain = NULL; - struct tcf_block *block; + struct tcf_block *block = NULL; struct tcf_proto *tp = NULL; unsigned long cl = 0; void *fh = NULL; int err; + bool rtnl_held = false; err = nlmsg_parse(n, sizeof(*t), tca, TCA_MAX, rtm_tca_policy, extack); if (err < 0) @@ -1811,8 +2382,26 @@ static int tc_get_tfilter(struct sk_buff *skb, struct nlmsghdr *n, /* Find head of filter chain. */ - block = tcf_block_find(net, &q, &parent, &cl, - t->tcm_ifindex, t->tcm_block_index, extack); + err = __tcf_qdisc_find(net, &q, &parent, t->tcm_ifindex, false, extack); + if (err) + return err; + + /* Take rtnl mutex if block is shared (no qdisc found), qdisc is not + * unlocked, classifier type is not specified, classifier is not + * unlocked. + */ + if ((q && !(q->ops->cl_ops->flags & QDISC_CLASS_OPS_DOIT_UNLOCKED)) || + !tca[TCA_KIND] || !tcf_proto_is_unlocked(nla_data(tca[TCA_KIND]))) { + rtnl_held = true; + rtnl_lock(); + } + + err = __tcf_qdisc_cl_find(q, parent, &cl, t->tcm_ifindex, extack); + if (err) + goto errout; + + block = __tcf_block_find(net, q, cl, t->tcm_ifindex, t->tcm_block_index, + extack); if (IS_ERR(block)) { err = PTR_ERR(block); goto errout; @@ -1831,8 +2420,10 @@ static int tc_get_tfilter(struct sk_buff *skb, struct nlmsghdr *n, goto errout; } + mutex_lock(&chain->filter_chain_lock); tp = tcf_chain_tp_find(chain, &chain_info, protocol, prio, false); + mutex_unlock(&chain->filter_chain_lock); if (!tp || IS_ERR(tp)) { NL_SET_ERR_MSG(extack, "Filter with specified priority/protocol not found"); err = tp ? PTR_ERR(tp) : -ENOENT; @@ -1850,15 +2441,23 @@ static int tc_get_tfilter(struct sk_buff *skb, struct nlmsghdr *n, err = -ENOENT; } else { err = tfilter_notify(net, skb, n, tp, block, q, parent, - fh, RTM_NEWTFILTER, true); + fh, RTM_NEWTFILTER, true, rtnl_held); if (err < 0) NL_SET_ERR_MSG(extack, "Failed to send filter notify message"); } + tfilter_put(tp, fh); errout: - if (chain) + if (chain) { + if (tp && !IS_ERR(tp)) + tcf_proto_put(tp, rtnl_held, NULL); tcf_chain_put(chain); - tcf_block_release(q, block); + } + tcf_block_release(q, block, rtnl_held); + + if (rtnl_held) + rtnl_unlock(); + return err; } @@ -1879,7 +2478,7 @@ static int tcf_node_dump(struct tcf_proto *tp, void *n, struct tcf_walker *arg) return tcf_fill_node(net, a->skb, tp, a->block, a->q, a->parent, n, NETLINK_CB(a->cb->skb).portid, a->cb->nlh->nlmsg_seq, NLM_F_MULTI, - RTM_NEWTFILTER); + RTM_NEWTFILTER, true); } static bool tcf_chain_dump(struct tcf_chain *chain, struct Qdisc *q, u32 parent, @@ -1889,11 +2488,15 @@ static bool tcf_chain_dump(struct tcf_chain *chain, struct Qdisc *q, u32 parent, struct net *net = sock_net(skb->sk); struct tcf_block *block = chain->block; struct tcmsg *tcm = nlmsg_data(cb->nlh); + struct tcf_proto *tp, *tp_prev; struct tcf_dump_args arg; - struct tcf_proto *tp; - for (tp = rtnl_dereference(chain->filter_chain); - tp; tp = rtnl_dereference(tp->next), (*p_index)++) { + for (tp = __tcf_get_next_proto(chain, NULL); + tp; + tp_prev = tp, + tp = __tcf_get_next_proto(chain, tp), + tcf_proto_put(tp_prev, true, NULL), + (*p_index)++) { if (*p_index < index_start) continue; if (TC_H_MAJ(tcm->tcm_info) && @@ -1909,9 +2512,8 @@ static bool tcf_chain_dump(struct tcf_chain *chain, struct Qdisc *q, u32 parent, if (tcf_fill_node(net, skb, tp, block, q, parent, NULL, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, NLM_F_MULTI, - RTM_NEWTFILTER) <= 0) - return false; - + RTM_NEWTFILTER, true) <= 0) + goto errout; cb->args[1] = 1; } if (!tp->ops->walk) @@ -1926,23 +2528,27 @@ static bool tcf_chain_dump(struct tcf_chain *chain, struct Qdisc *q, u32 parent, arg.w.skip = cb->args[1] - 1; arg.w.count = 0; arg.w.cookie = cb->args[2]; - tp->ops->walk(tp, &arg.w); + tp->ops->walk(tp, &arg.w, true); cb->args[2] = arg.w.cookie; cb->args[1] = arg.w.count + 1; if (arg.w.stop) - return false; + goto errout; } return true; + +errout: + tcf_proto_put(tp, true, NULL); + return false; } /* called with RTNL */ static int tc_dump_tfilter(struct sk_buff *skb, struct netlink_callback *cb) { + struct tcf_chain *chain, *chain_prev; struct net *net = sock_net(skb->sk); struct nlattr *tca[TCA_MAX + 1]; struct Qdisc *q = NULL; struct tcf_block *block; - struct tcf_chain *chain; struct tcmsg *tcm = nlmsg_data(cb->nlh); long index_start; long index; @@ -2006,19 +2612,24 @@ static int tc_dump_tfilter(struct sk_buff *skb, struct netlink_callback *cb) index_start = cb->args[0]; index = 0; - list_for_each_entry(chain, &block->chain_list, list) { + for (chain = __tcf_get_next_chain(block, NULL); + chain; + chain_prev = chain, + chain = __tcf_get_next_chain(block, chain), + tcf_chain_put(chain_prev)) { if (tca[TCA_CHAIN] && nla_get_u32(tca[TCA_CHAIN]) != chain->index) continue; if (!tcf_chain_dump(chain, q, parent, skb, cb, index_start, &index)) { + tcf_chain_put(chain); err = -EMSGSIZE; break; } } if (tcm->tcm_ifindex == TCM_IFINDEX_MAGIC_BLOCK) - tcf_block_refcnt_put(block); + tcf_block_refcnt_put(block, true); cb->args[0] = index; out: @@ -2028,8 +2639,10 @@ out: return skb->len; } -static int tc_chain_fill_node(struct tcf_chain *chain, struct net *net, - struct sk_buff *skb, struct tcf_block *block, +static int tc_chain_fill_node(const struct tcf_proto_ops *tmplt_ops, + void *tmplt_priv, u32 chain_index, + struct net *net, struct sk_buff *skb, + struct tcf_block *block, u32 portid, u32 seq, u16 flags, int event) { unsigned char *b = skb_tail_pointer(skb); @@ -2038,8 +2651,8 @@ static int tc_chain_fill_node(struct tcf_chain *chain, struct net *net, struct tcmsg *tcm; void *priv; - ops = chain->tmplt_ops; - priv = chain->tmplt_priv; + ops = tmplt_ops; + priv = tmplt_priv; nlh = nlmsg_put(skb, portid, seq, event, sizeof(*tcm), flags); if (!nlh) @@ -2057,7 +2670,7 @@ static int tc_chain_fill_node(struct tcf_chain *chain, struct net *net, tcm->tcm_block_index = block->index; } - if (nla_put_u32(skb, TCA_CHAIN, chain->index)) + if (nla_put_u32(skb, TCA_CHAIN, chain_index)) goto nla_put_failure; if (ops) { @@ -2083,18 +2696,50 @@ static int tc_chain_notify(struct tcf_chain *chain, struct sk_buff *oskb, struct tcf_block *block = chain->block; struct net *net = block->net; struct sk_buff *skb; + int err = 0; skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); if (!skb) return -ENOBUFS; - if (tc_chain_fill_node(chain, net, skb, block, portid, + if (tc_chain_fill_node(chain->tmplt_ops, chain->tmplt_priv, + chain->index, net, skb, block, portid, seq, flags, event) <= 0) { kfree_skb(skb); return -EINVAL; } if (unicast) + err = netlink_unicast(net->rtnl, skb, portid, MSG_DONTWAIT); + else + err = rtnetlink_send(skb, net, portid, RTNLGRP_TC, + flags & NLM_F_ECHO); + + if (err > 0) + err = 0; + return err; +} + +static int tc_chain_notify_delete(const struct tcf_proto_ops *tmplt_ops, + void *tmplt_priv, u32 chain_index, + struct tcf_block *block, struct sk_buff *oskb, + u32 seq, u16 flags, bool unicast) +{ + u32 portid = oskb ? NETLINK_CB(oskb).portid : 0; + struct net *net = block->net; + struct sk_buff *skb; + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + if (tc_chain_fill_node(tmplt_ops, tmplt_priv, chain_index, net, skb, + block, portid, seq, flags, RTM_DELCHAIN) <= 0) { + kfree_skb(skb); + return -EINVAL; + } + + if (unicast) return netlink_unicast(net->rtnl, skb, portid, MSG_DONTWAIT); return rtnetlink_send(skb, net, portid, RTNLGRP_TC, flags & NLM_F_ECHO); @@ -2111,7 +2756,7 @@ static int tc_chain_tmplt_add(struct tcf_chain *chain, struct net *net, if (!tca[TCA_KIND]) return 0; - ops = tcf_proto_lookup_ops(nla_data(tca[TCA_KIND]), extack); + ops = tcf_proto_lookup_ops(nla_data(tca[TCA_KIND]), true, extack); if (IS_ERR(ops)) return PTR_ERR(ops); if (!ops->tmplt_create || !ops->tmplt_destroy || !ops->tmplt_dump) { @@ -2129,16 +2774,15 @@ static int tc_chain_tmplt_add(struct tcf_chain *chain, struct net *net, return 0; } -static void tc_chain_tmplt_del(struct tcf_chain *chain) +static void tc_chain_tmplt_del(const struct tcf_proto_ops *tmplt_ops, + void *tmplt_priv) { - const struct tcf_proto_ops *ops = chain->tmplt_ops; - /* If template ops are set, no work to do for us. */ - if (!ops) + if (!tmplt_ops) return; - ops->tmplt_destroy(chain->tmplt_priv); - module_put(ops->owner); + tmplt_ops->tmplt_destroy(tmplt_priv); + module_put(tmplt_ops->owner); } /* Add/delete/get a chain */ @@ -2181,6 +2825,8 @@ replay: err = -EINVAL; goto errout_block; } + + mutex_lock(&block->lock); chain = tcf_chain_lookup(block, chain_index); if (n->nlmsg_type == RTM_NEWCHAIN) { if (chain) { @@ -2192,54 +2838,61 @@ replay: } else { NL_SET_ERR_MSG(extack, "Filter chain already exists"); err = -EEXIST; - goto errout_block; + goto errout_block_locked; } } else { if (!(n->nlmsg_flags & NLM_F_CREATE)) { NL_SET_ERR_MSG(extack, "Need both RTM_NEWCHAIN and NLM_F_CREATE to create a new chain"); err = -ENOENT; - goto errout_block; + goto errout_block_locked; } chain = tcf_chain_create(block, chain_index); if (!chain) { NL_SET_ERR_MSG(extack, "Failed to create filter chain"); err = -ENOMEM; - goto errout_block; + goto errout_block_locked; } } } else { if (!chain || tcf_chain_held_by_acts_only(chain)) { NL_SET_ERR_MSG(extack, "Cannot find specified filter chain"); err = -EINVAL; - goto errout_block; + goto errout_block_locked; } tcf_chain_hold(chain); } + if (n->nlmsg_type == RTM_NEWCHAIN) { + /* Modifying chain requires holding parent block lock. In case + * the chain was successfully added, take a reference to the + * chain. This ensures that an empty chain does not disappear at + * the end of this function. + */ + tcf_chain_hold(chain); + chain->explicitly_created = true; + } + mutex_unlock(&block->lock); + switch (n->nlmsg_type) { case RTM_NEWCHAIN: err = tc_chain_tmplt_add(chain, net, tca, extack); - if (err) + if (err) { + tcf_chain_put_explicitly_created(chain); goto errout; - /* In case the chain was successfully added, take a reference - * to the chain. This ensures that an empty chain - * does not disappear at the end of this function. - */ - tcf_chain_hold(chain); - chain->explicitly_created = true; + } + tc_chain_notify(chain, NULL, 0, NLM_F_CREATE | NLM_F_EXCL, RTM_NEWCHAIN, false); break; case RTM_DELCHAIN: tfilter_notify_chain(net, skb, block, q, parent, n, - chain, RTM_DELTFILTER); + chain, RTM_DELTFILTER, true); /* Flush the chain first as the user requested chain removal. */ - tcf_chain_flush(chain); + tcf_chain_flush(chain, true); /* In case the chain was successfully deleted, put a reference * to the chain previously taken during addition. */ tcf_chain_put_explicitly_created(chain); - chain->explicitly_created = false; break; case RTM_GETCHAIN: err = tc_chain_notify(chain, skb, n->nlmsg_seq, @@ -2256,11 +2909,15 @@ replay: errout: tcf_chain_put(chain); errout_block: - tcf_block_release(q, block); + tcf_block_release(q, block, true); if (err == -EAGAIN) /* Replay the request. */ goto replay; return err; + +errout_block_locked: + mutex_unlock(&block->lock); + goto errout_block; } /* called with RTNL */ @@ -2270,8 +2927,8 @@ static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb) struct nlattr *tca[TCA_MAX + 1]; struct Qdisc *q = NULL; struct tcf_block *block; - struct tcf_chain *chain; struct tcmsg *tcm = nlmsg_data(cb->nlh); + struct tcf_chain *chain; long index_start; long index; u32 parent; @@ -2334,6 +2991,7 @@ static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb) index_start = cb->args[0]; index = 0; + mutex_lock(&block->lock); list_for_each_entry(chain, &block->chain_list, list) { if ((tca[TCA_CHAIN] && nla_get_u32(tca[TCA_CHAIN]) != chain->index)) @@ -2344,7 +3002,8 @@ static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb) } if (tcf_chain_held_by_acts_only(chain)) continue; - err = tc_chain_fill_node(chain, net, skb, block, + err = tc_chain_fill_node(chain->tmplt_ops, chain->tmplt_priv, + chain->index, net, skb, block, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, NLM_F_MULTI, RTM_NEWCHAIN); @@ -2352,9 +3011,10 @@ static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb) break; index++; } + mutex_unlock(&block->lock); if (tcm->tcm_ifindex == TCM_IFINDEX_MAGIC_BLOCK) - tcf_block_refcnt_put(block); + tcf_block_refcnt_put(block, true); cb->args[0] = index; out: @@ -2376,7 +3036,7 @@ EXPORT_SYMBOL(tcf_exts_destroy); int tcf_exts_validate(struct net *net, struct tcf_proto *tp, struct nlattr **tb, struct nlattr *rate_tlv, struct tcf_exts *exts, bool ovr, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { #ifdef CONFIG_NET_CLS_ACT { @@ -2386,7 +3046,8 @@ int tcf_exts_validate(struct net *net, struct tcf_proto *tp, struct nlattr **tb, if (exts->police && tb[exts->police]) { act = tcf_action_init_1(net, tp, tb[exts->police], rate_tlv, "police", ovr, - TCA_ACT_BIND, true, extack); + TCA_ACT_BIND, rtnl_held, + extack); if (IS_ERR(act)) return PTR_ERR(act); @@ -2398,13 +3059,12 @@ int tcf_exts_validate(struct net *net, struct tcf_proto *tp, struct nlattr **tb, err = tcf_action_init(net, tp, tb[exts->action], rate_tlv, NULL, ovr, TCA_ACT_BIND, - exts->actions, &attr_size, true, - extack); + exts->actions, &attr_size, + rtnl_held, extack); if (err < 0) return err; exts->nr_actions = err; } - exts->net = net; } #else if ((exts->action && tb[exts->action]) || @@ -2515,6 +3175,114 @@ int tc_setup_cb_call(struct tcf_block *block, enum tc_setup_type type, } EXPORT_SYMBOL(tc_setup_cb_call); +int tc_setup_flow_action(struct flow_action *flow_action, + const struct tcf_exts *exts) +{ + const struct tc_action *act; + int i, j, k; + + if (!exts) + return 0; + + j = 0; + tcf_exts_for_each_action(i, act, exts) { + struct flow_action_entry *entry; + + entry = &flow_action->entries[j]; + if (is_tcf_gact_ok(act)) { + entry->id = FLOW_ACTION_ACCEPT; + } else if (is_tcf_gact_shot(act)) { + entry->id = FLOW_ACTION_DROP; + } else if (is_tcf_gact_trap(act)) { + entry->id = FLOW_ACTION_TRAP; + } else if (is_tcf_gact_goto_chain(act)) { + entry->id = FLOW_ACTION_GOTO; + entry->chain_index = tcf_gact_goto_chain_index(act); + } else if (is_tcf_mirred_egress_redirect(act)) { + entry->id = FLOW_ACTION_REDIRECT; + entry->dev = tcf_mirred_dev(act); + } else if (is_tcf_mirred_egress_mirror(act)) { + entry->id = FLOW_ACTION_MIRRED; + entry->dev = tcf_mirred_dev(act); + } else if (is_tcf_vlan(act)) { + switch (tcf_vlan_action(act)) { + case TCA_VLAN_ACT_PUSH: + entry->id = FLOW_ACTION_VLAN_PUSH; + entry->vlan.vid = tcf_vlan_push_vid(act); + entry->vlan.proto = tcf_vlan_push_proto(act); + entry->vlan.prio = tcf_vlan_push_prio(act); + break; + case TCA_VLAN_ACT_POP: + entry->id = FLOW_ACTION_VLAN_POP; + break; + case TCA_VLAN_ACT_MODIFY: + entry->id = FLOW_ACTION_VLAN_MANGLE; + entry->vlan.vid = tcf_vlan_push_vid(act); + entry->vlan.proto = tcf_vlan_push_proto(act); + entry->vlan.prio = tcf_vlan_push_prio(act); + break; + default: + goto err_out; + } + } else if (is_tcf_tunnel_set(act)) { + entry->id = FLOW_ACTION_TUNNEL_ENCAP; + entry->tunnel = tcf_tunnel_info(act); + } else if (is_tcf_tunnel_release(act)) { + entry->id = FLOW_ACTION_TUNNEL_DECAP; + entry->tunnel = tcf_tunnel_info(act); + } else if (is_tcf_pedit(act)) { + for (k = 0; k < tcf_pedit_nkeys(act); k++) { + switch (tcf_pedit_cmd(act, k)) { + case TCA_PEDIT_KEY_EX_CMD_SET: + entry->id = FLOW_ACTION_MANGLE; + break; + case TCA_PEDIT_KEY_EX_CMD_ADD: + entry->id = FLOW_ACTION_ADD; + break; + default: + goto err_out; + } + entry->mangle.htype = tcf_pedit_htype(act, k); + entry->mangle.mask = tcf_pedit_mask(act, k); + entry->mangle.val = tcf_pedit_val(act, k); + entry->mangle.offset = tcf_pedit_offset(act, k); + entry = &flow_action->entries[++j]; + } + } else if (is_tcf_csum(act)) { + entry->id = FLOW_ACTION_CSUM; + entry->csum_flags = tcf_csum_update_flags(act); + } else if (is_tcf_skbedit_mark(act)) { + entry->id = FLOW_ACTION_MARK; + entry->mark = tcf_skbedit_mark(act); + } else { + goto err_out; + } + + if (!is_tcf_pedit(act)) + j++; + } + return 0; +err_out: + return -EOPNOTSUPP; +} +EXPORT_SYMBOL(tc_setup_flow_action); + +unsigned int tcf_exts_num_actions(struct tcf_exts *exts) +{ + unsigned int num_acts = 0; + struct tc_action *act; + int i; + + tcf_exts_for_each_action(i, act, exts) { + if (is_tcf_pedit(act)) + num_acts += tcf_pedit_nkeys(act); + else + num_acts++; + } + return num_acts; +} +EXPORT_SYMBOL(tcf_exts_num_actions); + static __net_init int tcf_net_init(struct net *net) { struct tcf_net *tn = net_generic(net, tcf_net_id); @@ -2555,10 +3323,12 @@ static int __init tc_filter_init(void) if (err) goto err_rhash_setup_block_ht; - rtnl_register(PF_UNSPEC, RTM_NEWTFILTER, tc_new_tfilter, NULL, 0); - rtnl_register(PF_UNSPEC, RTM_DELTFILTER, tc_del_tfilter, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_NEWTFILTER, tc_new_tfilter, NULL, + RTNL_FLAG_DOIT_UNLOCKED); + rtnl_register(PF_UNSPEC, RTM_DELTFILTER, tc_del_tfilter, NULL, + RTNL_FLAG_DOIT_UNLOCKED); rtnl_register(PF_UNSPEC, RTM_GETTFILTER, tc_get_tfilter, - tc_dump_tfilter, 0); + tc_dump_tfilter, RTNL_FLAG_DOIT_UNLOCKED); rtnl_register(PF_UNSPEC, RTM_NEWCHAIN, tc_ctl_chain, NULL, 0); rtnl_register(PF_UNSPEC, RTM_DELCHAIN, tc_ctl_chain, NULL, 0); rtnl_register(PF_UNSPEC, RTM_GETCHAIN, tc_ctl_chain, diff --git a/net/sched/cls_basic.c b/net/sched/cls_basic.c index 6a5dce8baf19..687b0af67878 100644 --- a/net/sched/cls_basic.c +++ b/net/sched/cls_basic.c @@ -18,6 +18,7 @@ #include <linux/rtnetlink.h> #include <linux/skbuff.h> #include <linux/idr.h> +#include <linux/percpu.h> #include <net/netlink.h> #include <net/act_api.h> #include <net/pkt_cls.h> @@ -35,6 +36,7 @@ struct basic_filter { struct tcf_result res; struct tcf_proto *tp; struct list_head link; + struct tc_basic_pcnt __percpu *pf; struct rcu_work rwork; }; @@ -46,8 +48,10 @@ static int basic_classify(struct sk_buff *skb, const struct tcf_proto *tp, struct basic_filter *f; list_for_each_entry_rcu(f, &head->flist, link) { + __this_cpu_inc(f->pf->rcnt); if (!tcf_em_tree_match(skb, &f->ematches, NULL)) continue; + __this_cpu_inc(f->pf->rhit); *res = f->res; r = tcf_exts_exec(skb, &f->exts, res); if (r < 0) @@ -89,6 +93,7 @@ static void __basic_delete_filter(struct basic_filter *f) tcf_exts_destroy(&f->exts); tcf_em_tree_destroy(&f->ematches); tcf_exts_put_net(&f->exts); + free_percpu(f->pf); kfree(f); } @@ -102,7 +107,8 @@ static void basic_delete_filter_work(struct work_struct *work) rtnl_unlock(); } -static void basic_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) +static void basic_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) { struct basic_head *head = rtnl_dereference(tp->root); struct basic_filter *f, *n; @@ -121,7 +127,7 @@ static void basic_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) } static int basic_delete(struct tcf_proto *tp, void *arg, bool *last, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { struct basic_head *head = rtnl_dereference(tp->root); struct basic_filter *f = arg; @@ -148,7 +154,7 @@ static int basic_set_parms(struct net *net, struct tcf_proto *tp, { int err; - err = tcf_exts_validate(net, tp, tb, est, &f->exts, ovr, extack); + err = tcf_exts_validate(net, tp, tb, est, &f->exts, ovr, true, extack); if (err < 0) return err; @@ -168,7 +174,7 @@ static int basic_set_parms(struct net *net, struct tcf_proto *tp, static int basic_change(struct net *net, struct sk_buff *in_skb, struct tcf_proto *tp, unsigned long base, u32 handle, struct nlattr **tca, void **arg, bool ovr, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { int err; struct basic_head *head = rtnl_dereference(tp->root); @@ -193,7 +199,7 @@ static int basic_change(struct net *net, struct sk_buff *in_skb, if (!fnew) return -ENOBUFS; - err = tcf_exts_init(&fnew->exts, TCA_BASIC_ACT, TCA_BASIC_POLICE); + err = tcf_exts_init(&fnew->exts, net, TCA_BASIC_ACT, TCA_BASIC_POLICE); if (err < 0) goto errout; @@ -208,6 +214,11 @@ static int basic_change(struct net *net, struct sk_buff *in_skb, if (err) goto errout; fnew->handle = handle; + fnew->pf = alloc_percpu(struct tc_basic_pcnt); + if (!fnew->pf) { + err = -ENOMEM; + goto errout; + } err = basic_set_parms(net, tp, fnew, base, tb, tca[TCA_RATE], ovr, extack); @@ -231,12 +242,14 @@ static int basic_change(struct net *net, struct sk_buff *in_skb, return 0; errout: + free_percpu(fnew->pf); tcf_exts_destroy(&fnew->exts); kfree(fnew); return err; } -static void basic_walk(struct tcf_proto *tp, struct tcf_walker *arg) +static void basic_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) { struct basic_head *head = rtnl_dereference(tp->root); struct basic_filter *f; @@ -263,10 +276,12 @@ static void basic_bind_class(void *fh, u32 classid, unsigned long cl) } static int basic_dump(struct net *net, struct tcf_proto *tp, void *fh, - struct sk_buff *skb, struct tcmsg *t) + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) { + struct tc_basic_pcnt gpf = {}; struct basic_filter *f = fh; struct nlattr *nest; + int cpu; if (f == NULL) return skb->len; @@ -281,6 +296,18 @@ static int basic_dump(struct net *net, struct tcf_proto *tp, void *fh, nla_put_u32(skb, TCA_BASIC_CLASSID, f->res.classid)) goto nla_put_failure; + for_each_possible_cpu(cpu) { + struct tc_basic_pcnt *pf = per_cpu_ptr(f->pf, cpu); + + gpf.rcnt += pf->rcnt; + gpf.rhit += pf->rhit; + } + + if (nla_put_64bit(skb, TCA_BASIC_PCNT, + sizeof(struct tc_basic_pcnt), + &gpf, TCA_BASIC_PAD)) + goto nla_put_failure; + if (tcf_exts_dump(skb, &f->exts) < 0 || tcf_em_tree_dump(skb, &f->ematches, TCA_BASIC_EMATCHES) < 0) goto nla_put_failure; diff --git a/net/sched/cls_bpf.c b/net/sched/cls_bpf.c index a95cb240a606..b4ac58039cb1 100644 --- a/net/sched/cls_bpf.c +++ b/net/sched/cls_bpf.c @@ -298,7 +298,7 @@ static void __cls_bpf_delete(struct tcf_proto *tp, struct cls_bpf_prog *prog, } static int cls_bpf_delete(struct tcf_proto *tp, void *arg, bool *last, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { struct cls_bpf_head *head = rtnl_dereference(tp->root); @@ -307,7 +307,7 @@ static int cls_bpf_delete(struct tcf_proto *tp, void *arg, bool *last, return 0; } -static void cls_bpf_destroy(struct tcf_proto *tp, +static void cls_bpf_destroy(struct tcf_proto *tp, bool rtnl_held, struct netlink_ext_ack *extack) { struct cls_bpf_head *head = rtnl_dereference(tp->root); @@ -417,7 +417,8 @@ static int cls_bpf_set_parms(struct net *net, struct tcf_proto *tp, if ((!is_bpf && !is_ebpf) || (is_bpf && is_ebpf)) return -EINVAL; - ret = tcf_exts_validate(net, tp, tb, est, &prog->exts, ovr, extack); + ret = tcf_exts_validate(net, tp, tb, est, &prog->exts, ovr, true, + extack); if (ret < 0) return ret; @@ -455,7 +456,8 @@ static int cls_bpf_set_parms(struct net *net, struct tcf_proto *tp, static int cls_bpf_change(struct net *net, struct sk_buff *in_skb, struct tcf_proto *tp, unsigned long base, u32 handle, struct nlattr **tca, - void **arg, bool ovr, struct netlink_ext_ack *extack) + void **arg, bool ovr, bool rtnl_held, + struct netlink_ext_ack *extack) { struct cls_bpf_head *head = rtnl_dereference(tp->root); struct cls_bpf_prog *oldprog = *arg; @@ -475,7 +477,7 @@ static int cls_bpf_change(struct net *net, struct sk_buff *in_skb, if (!prog) return -ENOBUFS; - ret = tcf_exts_init(&prog->exts, TCA_BPF_ACT, TCA_BPF_POLICE); + ret = tcf_exts_init(&prog->exts, net, TCA_BPF_ACT, TCA_BPF_POLICE); if (ret < 0) goto errout; @@ -575,7 +577,7 @@ static int cls_bpf_dump_ebpf_info(const struct cls_bpf_prog *prog, } static int cls_bpf_dump(struct net *net, struct tcf_proto *tp, void *fh, - struct sk_buff *skb, struct tcmsg *tm) + struct sk_buff *skb, struct tcmsg *tm, bool rtnl_held) { struct cls_bpf_prog *prog = fh; struct nlattr *nest; @@ -635,7 +637,8 @@ static void cls_bpf_bind_class(void *fh, u32 classid, unsigned long cl) prog->res.class = cl; } -static void cls_bpf_walk(struct tcf_proto *tp, struct tcf_walker *arg) +static void cls_bpf_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) { struct cls_bpf_head *head = rtnl_dereference(tp->root); struct cls_bpf_prog *prog; diff --git a/net/sched/cls_cgroup.c b/net/sched/cls_cgroup.c index 3bc01bdde165..4c1567854f95 100644 --- a/net/sched/cls_cgroup.c +++ b/net/sched/cls_cgroup.c @@ -78,7 +78,7 @@ static void cls_cgroup_destroy_work(struct work_struct *work) static int cls_cgroup_change(struct net *net, struct sk_buff *in_skb, struct tcf_proto *tp, unsigned long base, u32 handle, struct nlattr **tca, - void **arg, bool ovr, + void **arg, bool ovr, bool rtnl_held, struct netlink_ext_ack *extack) { struct nlattr *tb[TCA_CGROUP_MAX + 1]; @@ -99,7 +99,7 @@ static int cls_cgroup_change(struct net *net, struct sk_buff *in_skb, if (!new) return -ENOBUFS; - err = tcf_exts_init(&new->exts, TCA_CGROUP_ACT, TCA_CGROUP_POLICE); + err = tcf_exts_init(&new->exts, net, TCA_CGROUP_ACT, TCA_CGROUP_POLICE); if (err < 0) goto errout; new->handle = handle; @@ -110,7 +110,7 @@ static int cls_cgroup_change(struct net *net, struct sk_buff *in_skb, goto errout; err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &new->exts, ovr, - extack); + true, extack); if (err < 0) goto errout; @@ -130,7 +130,7 @@ errout: return err; } -static void cls_cgroup_destroy(struct tcf_proto *tp, +static void cls_cgroup_destroy(struct tcf_proto *tp, bool rtnl_held, struct netlink_ext_ack *extack) { struct cls_cgroup_head *head = rtnl_dereference(tp->root); @@ -145,18 +145,21 @@ static void cls_cgroup_destroy(struct tcf_proto *tp, } static int cls_cgroup_delete(struct tcf_proto *tp, void *arg, bool *last, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { return -EOPNOTSUPP; } -static void cls_cgroup_walk(struct tcf_proto *tp, struct tcf_walker *arg) +static void cls_cgroup_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) { struct cls_cgroup_head *head = rtnl_dereference(tp->root); if (arg->count < arg->skip) goto skip; + if (!head) + return; if (arg->fn(tp, head, arg) < 0) { arg->stop = 1; return; @@ -166,7 +169,7 @@ skip: } static int cls_cgroup_dump(struct net *net, struct tcf_proto *tp, void *fh, - struct sk_buff *skb, struct tcmsg *t) + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) { struct cls_cgroup_head *head = rtnl_dereference(tp->root); struct nlattr *nest; diff --git a/net/sched/cls_flow.c b/net/sched/cls_flow.c index 2bb043cd436b..eece1ee26930 100644 --- a/net/sched/cls_flow.c +++ b/net/sched/cls_flow.c @@ -391,7 +391,8 @@ static void flow_destroy_filter_work(struct work_struct *work) static int flow_change(struct net *net, struct sk_buff *in_skb, struct tcf_proto *tp, unsigned long base, u32 handle, struct nlattr **tca, - void **arg, bool ovr, struct netlink_ext_ack *extack) + void **arg, bool ovr, bool rtnl_held, + struct netlink_ext_ack *extack) { struct flow_head *head = rtnl_dereference(tp->root); struct flow_filter *fold, *fnew; @@ -440,12 +441,12 @@ static int flow_change(struct net *net, struct sk_buff *in_skb, if (err < 0) goto err1; - err = tcf_exts_init(&fnew->exts, TCA_FLOW_ACT, TCA_FLOW_POLICE); + err = tcf_exts_init(&fnew->exts, net, TCA_FLOW_ACT, TCA_FLOW_POLICE); if (err < 0) goto err2; err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &fnew->exts, ovr, - extack); + true, extack); if (err < 0) goto err2; @@ -566,7 +567,7 @@ err1: } static int flow_delete(struct tcf_proto *tp, void *arg, bool *last, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { struct flow_head *head = rtnl_dereference(tp->root); struct flow_filter *f = arg; @@ -590,7 +591,8 @@ static int flow_init(struct tcf_proto *tp) return 0; } -static void flow_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) +static void flow_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) { struct flow_head *head = rtnl_dereference(tp->root); struct flow_filter *f, *next; @@ -617,7 +619,7 @@ static void *flow_get(struct tcf_proto *tp, u32 handle) } static int flow_dump(struct net *net, struct tcf_proto *tp, void *fh, - struct sk_buff *skb, struct tcmsg *t) + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) { struct flow_filter *f = fh; struct nlattr *nest; @@ -677,7 +679,8 @@ nla_put_failure: return -1; } -static void flow_walk(struct tcf_proto *tp, struct tcf_walker *arg) +static void flow_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) { struct flow_head *head = rtnl_dereference(tp->root); struct flow_filter *f; diff --git a/net/sched/cls_flower.c b/net/sched/cls_flower.c index f6aa57fbbbaf..c04247b403ed 100644 --- a/net/sched/cls_flower.c +++ b/net/sched/cls_flower.c @@ -381,16 +381,31 @@ static int fl_hw_replace_filter(struct tcf_proto *tp, bool skip_sw = tc_skip_sw(f->flags); int err; + cls_flower.rule = flow_rule_alloc(tcf_exts_num_actions(&f->exts)); + if (!cls_flower.rule) + return -ENOMEM; + tc_cls_common_offload_init(&cls_flower.common, tp, f->flags, extack); cls_flower.command = TC_CLSFLOWER_REPLACE; cls_flower.cookie = (unsigned long) f; - cls_flower.dissector = &f->mask->dissector; - cls_flower.mask = &f->mask->key; - cls_flower.key = &f->mkey; - cls_flower.exts = &f->exts; + cls_flower.rule->match.dissector = &f->mask->dissector; + cls_flower.rule->match.mask = &f->mask->key; + cls_flower.rule->match.key = &f->mkey; cls_flower.classid = f->res.classid; + err = tc_setup_flow_action(&cls_flower.rule->action, &f->exts); + if (err) { + kfree(cls_flower.rule); + if (skip_sw) { + NL_SET_ERR_MSG_MOD(extack, "Failed to setup flow action"); + return err; + } + return 0; + } + err = tc_setup_cb_call(block, TC_SETUP_CLSFLOWER, &cls_flower, skip_sw); + kfree(cls_flower.rule); + if (err < 0) { fl_hw_destroy_filter(tp, f, NULL); return err; @@ -413,10 +428,13 @@ static void fl_hw_update_stats(struct tcf_proto *tp, struct cls_fl_filter *f) tc_cls_common_offload_init(&cls_flower.common, tp, f->flags, NULL); cls_flower.command = TC_CLSFLOWER_STATS; cls_flower.cookie = (unsigned long) f; - cls_flower.exts = &f->exts; cls_flower.classid = f->res.classid; tc_setup_cb_call(block, TC_SETUP_CLSFLOWER, &cls_flower, false); + + tcf_exts_stats_update(&f->exts, cls_flower.stats.bytes, + cls_flower.stats.pkts, + cls_flower.stats.lastused); } static bool __fl_delete(struct tcf_proto *tp, struct cls_fl_filter *f, @@ -451,7 +469,8 @@ static void fl_destroy_sleepable(struct work_struct *work) module_put(THIS_MODULE); } -static void fl_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) +static void fl_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) { struct cls_fl_head *head = rtnl_dereference(tp->root); struct fl_flow_mask *mask, *next_mask; @@ -1258,7 +1277,8 @@ static int fl_set_parms(struct net *net, struct tcf_proto *tp, { int err; - err = tcf_exts_validate(net, tp, tb, est, &f->exts, ovr, extack); + err = tcf_exts_validate(net, tp, tb, est, &f->exts, ovr, true, + extack); if (err < 0) return err; @@ -1285,7 +1305,8 @@ static int fl_set_parms(struct net *net, struct tcf_proto *tp, static int fl_change(struct net *net, struct sk_buff *in_skb, struct tcf_proto *tp, unsigned long base, u32 handle, struct nlattr **tca, - void **arg, bool ovr, struct netlink_ext_ack *extack) + void **arg, bool ovr, bool rtnl_held, + struct netlink_ext_ack *extack) { struct cls_fl_head *head = rtnl_dereference(tp->root); struct cls_fl_filter *fold = *arg; @@ -1323,55 +1344,55 @@ static int fl_change(struct net *net, struct sk_buff *in_skb, goto errout_tb; } - err = tcf_exts_init(&fnew->exts, TCA_FLOWER_ACT, 0); + err = tcf_exts_init(&fnew->exts, net, TCA_FLOWER_ACT, 0); if (err < 0) goto errout; - if (!handle) { - handle = 1; - err = idr_alloc_u32(&head->handle_idr, fnew, &handle, - INT_MAX, GFP_KERNEL); - } else if (!fold) { - /* user specifies a handle and it doesn't exist */ - err = idr_alloc_u32(&head->handle_idr, fnew, &handle, - handle, GFP_KERNEL); - } - if (err) - goto errout; - fnew->handle = handle; - if (tb[TCA_FLOWER_FLAGS]) { fnew->flags = nla_get_u32(tb[TCA_FLOWER_FLAGS]); if (!tc_flags_valid(fnew->flags)) { err = -EINVAL; - goto errout_idr; + goto errout; } } err = fl_set_parms(net, tp, fnew, mask, base, tb, tca[TCA_RATE], ovr, tp->chain->tmplt_priv, extack); if (err) - goto errout_idr; + goto errout; err = fl_check_assign_mask(head, fnew, fold, mask); if (err) - goto errout_idr; + goto errout; + + if (!handle) { + handle = 1; + err = idr_alloc_u32(&head->handle_idr, fnew, &handle, + INT_MAX, GFP_KERNEL); + } else if (!fold) { + /* user specifies a handle and it doesn't exist */ + err = idr_alloc_u32(&head->handle_idr, fnew, &handle, + handle, GFP_KERNEL); + } + if (err) + goto errout_mask; + fnew->handle = handle; if (!fold && __fl_lookup(fnew->mask, &fnew->mkey)) { err = -EEXIST; - goto errout_mask; + goto errout_idr; } err = rhashtable_insert_fast(&fnew->mask->ht, &fnew->ht_node, fnew->mask->filter_ht_params); if (err) - goto errout_mask; + goto errout_idr; if (!tc_skip_hw(fnew->flags)) { err = fl_hw_replace_filter(tp, fnew, extack); if (err) - goto errout_mask; + goto errout_mask_ht; } if (!tc_in_hw(fnew->flags)) @@ -1401,12 +1422,17 @@ static int fl_change(struct net *net, struct sk_buff *in_skb, kfree(mask); return 0; -errout_mask: - fl_mask_put(head, fnew->mask, false); +errout_mask_ht: + rhashtable_remove_fast(&fnew->mask->ht, &fnew->ht_node, + fnew->mask->filter_ht_params); errout_idr: if (!fold) idr_remove(&head->handle_idr, fnew->handle); + +errout_mask: + fl_mask_put(head, fnew->mask, false); + errout: tcf_exts_destroy(&fnew->exts); kfree(fnew); @@ -1418,7 +1444,7 @@ errout_mask_alloc: } static int fl_delete(struct tcf_proto *tp, void *arg, bool *last, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { struct cls_fl_head *head = rtnl_dereference(tp->root); struct cls_fl_filter *f = arg; @@ -1430,7 +1456,8 @@ static int fl_delete(struct tcf_proto *tp, void *arg, bool *last, return 0; } -static void fl_walk(struct tcf_proto *tp, struct tcf_walker *arg) +static void fl_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) { struct cls_fl_head *head = rtnl_dereference(tp->root); struct cls_fl_filter *f; @@ -1463,18 +1490,36 @@ static int fl_reoffload(struct tcf_proto *tp, bool add, tc_setup_cb_t *cb, if (tc_skip_hw(f->flags)) continue; + cls_flower.rule = + flow_rule_alloc(tcf_exts_num_actions(&f->exts)); + if (!cls_flower.rule) + return -ENOMEM; + tc_cls_common_offload_init(&cls_flower.common, tp, f->flags, extack); cls_flower.command = add ? TC_CLSFLOWER_REPLACE : TC_CLSFLOWER_DESTROY; cls_flower.cookie = (unsigned long)f; - cls_flower.dissector = &mask->dissector; - cls_flower.mask = &mask->key; - cls_flower.key = &f->mkey; - cls_flower.exts = &f->exts; + cls_flower.rule->match.dissector = &mask->dissector; + cls_flower.rule->match.mask = &mask->key; + cls_flower.rule->match.key = &f->mkey; + + err = tc_setup_flow_action(&cls_flower.rule->action, + &f->exts); + if (err) { + kfree(cls_flower.rule); + if (tc_skip_sw(f->flags)) { + NL_SET_ERR_MSG_MOD(extack, "Failed to setup flow action"); + return err; + } + continue; + } + cls_flower.classid = f->res.classid; err = cb(TC_SETUP_CLSFLOWER, &cls_flower, cb_priv); + kfree(cls_flower.rule); + if (err) { if (add && tc_skip_sw(f->flags)) return err; @@ -1489,25 +1534,30 @@ static int fl_reoffload(struct tcf_proto *tp, bool add, tc_setup_cb_t *cb, return 0; } -static void fl_hw_create_tmplt(struct tcf_chain *chain, - struct fl_flow_tmplt *tmplt) +static int fl_hw_create_tmplt(struct tcf_chain *chain, + struct fl_flow_tmplt *tmplt) { struct tc_cls_flower_offload cls_flower = {}; struct tcf_block *block = chain->block; - struct tcf_exts dummy_exts = { 0, }; + + cls_flower.rule = flow_rule_alloc(0); + if (!cls_flower.rule) + return -ENOMEM; cls_flower.common.chain_index = chain->index; cls_flower.command = TC_CLSFLOWER_TMPLT_CREATE; cls_flower.cookie = (unsigned long) tmplt; - cls_flower.dissector = &tmplt->dissector; - cls_flower.mask = &tmplt->mask; - cls_flower.key = &tmplt->dummy_key; - cls_flower.exts = &dummy_exts; + cls_flower.rule->match.dissector = &tmplt->dissector; + cls_flower.rule->match.mask = &tmplt->mask; + cls_flower.rule->match.key = &tmplt->dummy_key; /* We don't care if driver (any of them) fails to handle this * call. It serves just as a hint for it. */ tc_setup_cb_call(block, TC_SETUP_CLSFLOWER, &cls_flower, false); + kfree(cls_flower.rule); + + return 0; } static void fl_hw_destroy_tmplt(struct tcf_chain *chain, @@ -1551,12 +1601,14 @@ static void *fl_tmplt_create(struct net *net, struct tcf_chain *chain, err = fl_set_key(net, tb, &tmplt->dummy_key, &tmplt->mask, extack); if (err) goto errout_tmplt; - kfree(tb); fl_init_dissector(&tmplt->dissector, &tmplt->mask); - fl_hw_create_tmplt(chain, tmplt); + err = fl_hw_create_tmplt(chain, tmplt); + if (err) + goto errout_tmplt; + kfree(tb); return tmplt; errout_tmplt: @@ -2004,7 +2056,7 @@ nla_put_failure: } static int fl_dump(struct net *net, struct tcf_proto *tp, void *fh, - struct sk_buff *skb, struct tcmsg *t) + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) { struct cls_fl_filter *f = fh; struct nlattr *nest; diff --git a/net/sched/cls_fw.c b/net/sched/cls_fw.c index 29eeeaf3ea44..ad036b00427d 100644 --- a/net/sched/cls_fw.c +++ b/net/sched/cls_fw.c @@ -139,7 +139,8 @@ static void fw_delete_filter_work(struct work_struct *work) rtnl_unlock(); } -static void fw_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) +static void fw_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) { struct fw_head *head = rtnl_dereference(tp->root); struct fw_filter *f; @@ -163,7 +164,7 @@ static void fw_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) } static int fw_delete(struct tcf_proto *tp, void *arg, bool *last, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { struct fw_head *head = rtnl_dereference(tp->root); struct fw_filter *f = arg; @@ -217,7 +218,7 @@ static int fw_set_parms(struct net *net, struct tcf_proto *tp, int err; err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &f->exts, ovr, - extack); + true, extack); if (err < 0) return err; @@ -250,7 +251,8 @@ static int fw_set_parms(struct net *net, struct tcf_proto *tp, static int fw_change(struct net *net, struct sk_buff *in_skb, struct tcf_proto *tp, unsigned long base, u32 handle, struct nlattr **tca, void **arg, - bool ovr, struct netlink_ext_ack *extack) + bool ovr, bool rtnl_held, + struct netlink_ext_ack *extack) { struct fw_head *head = rtnl_dereference(tp->root); struct fw_filter *f = *arg; @@ -283,7 +285,8 @@ static int fw_change(struct net *net, struct sk_buff *in_skb, #endif /* CONFIG_NET_CLS_IND */ fnew->tp = f->tp; - err = tcf_exts_init(&fnew->exts, TCA_FW_ACT, TCA_FW_POLICE); + err = tcf_exts_init(&fnew->exts, net, TCA_FW_ACT, + TCA_FW_POLICE); if (err < 0) { kfree(fnew); return err; @@ -332,7 +335,7 @@ static int fw_change(struct net *net, struct sk_buff *in_skb, if (f == NULL) return -ENOBUFS; - err = tcf_exts_init(&f->exts, TCA_FW_ACT, TCA_FW_POLICE); + err = tcf_exts_init(&f->exts, net, TCA_FW_ACT, TCA_FW_POLICE); if (err < 0) goto errout; f->id = handle; @@ -354,7 +357,8 @@ errout: return err; } -static void fw_walk(struct tcf_proto *tp, struct tcf_walker *arg) +static void fw_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) { struct fw_head *head = rtnl_dereference(tp->root); int h; @@ -384,7 +388,7 @@ static void fw_walk(struct tcf_proto *tp, struct tcf_walker *arg) } static int fw_dump(struct net *net, struct tcf_proto *tp, void *fh, - struct sk_buff *skb, struct tcmsg *t) + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) { struct fw_head *head = rtnl_dereference(tp->root); struct fw_filter *f = fh; diff --git a/net/sched/cls_matchall.c b/net/sched/cls_matchall.c index 0e408ee9dcec..459921bd3d87 100644 --- a/net/sched/cls_matchall.c +++ b/net/sched/cls_matchall.c @@ -12,6 +12,7 @@ #include <linux/kernel.h> #include <linux/init.h> #include <linux/module.h> +#include <linux/percpu.h> #include <net/sch_generic.h> #include <net/pkt_cls.h> @@ -22,6 +23,7 @@ struct cls_mall_head { u32 handle; u32 flags; unsigned int in_hw_count; + struct tc_matchall_pcnt __percpu *pf; struct rcu_work rwork; }; @@ -34,6 +36,7 @@ static int mall_classify(struct sk_buff *skb, const struct tcf_proto *tp, return -1; *res = head->res; + __this_cpu_inc(head->pf->rhit); return tcf_exts_exec(skb, &head->exts, res); } @@ -46,6 +49,7 @@ static void __mall_destroy(struct cls_mall_head *head) { tcf_exts_destroy(&head->exts); tcf_exts_put_net(&head->exts); + free_percpu(head->pf); kfree(head); } @@ -105,7 +109,8 @@ static int mall_replace_hw_filter(struct tcf_proto *tp, return 0; } -static void mall_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) +static void mall_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) { struct cls_mall_head *head = rtnl_dereference(tp->root); @@ -141,7 +146,8 @@ static int mall_set_parms(struct net *net, struct tcf_proto *tp, { int err; - err = tcf_exts_validate(net, tp, tb, est, &head->exts, ovr, extack); + err = tcf_exts_validate(net, tp, tb, est, &head->exts, ovr, true, + extack); if (err < 0) return err; @@ -155,7 +161,8 @@ static int mall_set_parms(struct net *net, struct tcf_proto *tp, static int mall_change(struct net *net, struct sk_buff *in_skb, struct tcf_proto *tp, unsigned long base, u32 handle, struct nlattr **tca, - void **arg, bool ovr, struct netlink_ext_ack *extack) + void **arg, bool ovr, bool rtnl_held, + struct netlink_ext_ack *extack) { struct cls_mall_head *head = rtnl_dereference(tp->root); struct nlattr *tb[TCA_MATCHALL_MAX + 1]; @@ -184,7 +191,7 @@ static int mall_change(struct net *net, struct sk_buff *in_skb, if (!new) return -ENOBUFS; - err = tcf_exts_init(&new->exts, TCA_MATCHALL_ACT, 0); + err = tcf_exts_init(&new->exts, net, TCA_MATCHALL_ACT, 0); if (err) goto err_exts_init; @@ -192,6 +199,11 @@ static int mall_change(struct net *net, struct sk_buff *in_skb, handle = 1; new->handle = handle; new->flags = flags; + new->pf = alloc_percpu(struct tc_matchall_pcnt); + if (!new->pf) { + err = -ENOMEM; + goto err_alloc_percpu; + } err = mall_set_parms(net, tp, new, base, tb, tca[TCA_RATE], ovr, extack); @@ -214,6 +226,8 @@ static int mall_change(struct net *net, struct sk_buff *in_skb, err_replace_hw_filter: err_set_parms: + free_percpu(new->pf); +err_alloc_percpu: tcf_exts_destroy(&new->exts); err_exts_init: kfree(new); @@ -221,17 +235,21 @@ err_exts_init: } static int mall_delete(struct tcf_proto *tp, void *arg, bool *last, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { return -EOPNOTSUPP; } -static void mall_walk(struct tcf_proto *tp, struct tcf_walker *arg) +static void mall_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) { struct cls_mall_head *head = rtnl_dereference(tp->root); if (arg->count < arg->skip) goto skip; + + if (!head) + return; if (arg->fn(tp, head, arg) < 0) arg->stop = 1; skip: @@ -268,10 +286,12 @@ static int mall_reoffload(struct tcf_proto *tp, bool add, tc_setup_cb_t *cb, } static int mall_dump(struct net *net, struct tcf_proto *tp, void *fh, - struct sk_buff *skb, struct tcmsg *t) + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) { + struct tc_matchall_pcnt gpf = {}; struct cls_mall_head *head = fh; struct nlattr *nest; + int cpu; if (!head) return skb->len; @@ -289,6 +309,17 @@ static int mall_dump(struct net *net, struct tcf_proto *tp, void *fh, if (head->flags && nla_put_u32(skb, TCA_MATCHALL_FLAGS, head->flags)) goto nla_put_failure; + for_each_possible_cpu(cpu) { + struct tc_matchall_pcnt *pf = per_cpu_ptr(head->pf, cpu); + + gpf.rhit += pf->rhit; + } + + if (nla_put_64bit(skb, TCA_MATCHALL_PCNT, + sizeof(struct tc_matchall_pcnt), + &gpf, TCA_MATCHALL_PAD)) + goto nla_put_failure; + if (tcf_exts_dump(skb, &head->exts)) goto nla_put_failure; diff --git a/net/sched/cls_route.c b/net/sched/cls_route.c index 0404aa5fa7cb..f006af23b64a 100644 --- a/net/sched/cls_route.c +++ b/net/sched/cls_route.c @@ -276,7 +276,8 @@ static void route4_queue_work(struct route4_filter *f) tcf_queue_work(&f->rwork, route4_delete_filter_work); } -static void route4_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) +static void route4_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) { struct route4_head *head = rtnl_dereference(tp->root); int h1, h2; @@ -312,7 +313,7 @@ static void route4_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) } static int route4_delete(struct tcf_proto *tp, void *arg, bool *last, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { struct route4_head *head = rtnl_dereference(tp->root); struct route4_filter *f = arg; @@ -393,7 +394,7 @@ static int route4_set_parms(struct net *net, struct tcf_proto *tp, struct route4_bucket *b; int err; - err = tcf_exts_validate(net, tp, tb, est, &f->exts, ovr, extack); + err = tcf_exts_validate(net, tp, tb, est, &f->exts, ovr, true, extack); if (err < 0) return err; @@ -468,7 +469,7 @@ static int route4_set_parms(struct net *net, struct tcf_proto *tp, static int route4_change(struct net *net, struct sk_buff *in_skb, struct tcf_proto *tp, unsigned long base, u32 handle, struct nlattr **tca, void **arg, bool ovr, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { struct route4_head *head = rtnl_dereference(tp->root); struct route4_filter __rcu **fp; @@ -496,7 +497,7 @@ static int route4_change(struct net *net, struct sk_buff *in_skb, if (!f) goto errout; - err = tcf_exts_init(&f->exts, TCA_ROUTE4_ACT, TCA_ROUTE4_POLICE); + err = tcf_exts_init(&f->exts, net, TCA_ROUTE4_ACT, TCA_ROUTE4_POLICE); if (err < 0) goto errout; @@ -560,15 +561,13 @@ errout: return err; } -static void route4_walk(struct tcf_proto *tp, struct tcf_walker *arg) +static void route4_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) { struct route4_head *head = rtnl_dereference(tp->root); unsigned int h, h1; - if (head == NULL) - arg->stop = 1; - - if (arg->stop) + if (head == NULL || arg->stop) return; for (h = 0; h <= 256; h++) { @@ -597,7 +596,7 @@ static void route4_walk(struct tcf_proto *tp, struct tcf_walker *arg) } static int route4_dump(struct net *net, struct tcf_proto *tp, void *fh, - struct sk_buff *skb, struct tcmsg *t) + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) { struct route4_filter *f = fh; struct nlattr *nest; diff --git a/net/sched/cls_rsvp.h b/net/sched/cls_rsvp.h index e9ccf7daea7d..0719a21d9c41 100644 --- a/net/sched/cls_rsvp.h +++ b/net/sched/cls_rsvp.h @@ -312,7 +312,8 @@ static void rsvp_delete_filter(struct tcf_proto *tp, struct rsvp_filter *f) __rsvp_delete_filter(f); } -static void rsvp_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) +static void rsvp_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) { struct rsvp_head *data = rtnl_dereference(tp->root); int h1, h2; @@ -341,7 +342,7 @@ static void rsvp_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) } static int rsvp_delete(struct tcf_proto *tp, void *arg, bool *last, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { struct rsvp_head *head = rtnl_dereference(tp->root); struct rsvp_filter *nfp, *f = arg; @@ -477,7 +478,8 @@ static int rsvp_change(struct net *net, struct sk_buff *in_skb, struct tcf_proto *tp, unsigned long base, u32 handle, struct nlattr **tca, - void **arg, bool ovr, struct netlink_ext_ack *extack) + void **arg, bool ovr, bool rtnl_held, + struct netlink_ext_ack *extack) { struct rsvp_head *data = rtnl_dereference(tp->root); struct rsvp_filter *f, *nfp; @@ -499,10 +501,11 @@ static int rsvp_change(struct net *net, struct sk_buff *in_skb, if (err < 0) return err; - err = tcf_exts_init(&e, TCA_RSVP_ACT, TCA_RSVP_POLICE); + err = tcf_exts_init(&e, net, TCA_RSVP_ACT, TCA_RSVP_POLICE); if (err < 0) return err; - err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &e, ovr, extack); + err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &e, ovr, true, + extack); if (err < 0) goto errout2; @@ -520,7 +523,8 @@ static int rsvp_change(struct net *net, struct sk_buff *in_skb, goto errout2; } - err = tcf_exts_init(&n->exts, TCA_RSVP_ACT, TCA_RSVP_POLICE); + err = tcf_exts_init(&n->exts, net, TCA_RSVP_ACT, + TCA_RSVP_POLICE); if (err < 0) { kfree(n); goto errout2; @@ -548,7 +552,7 @@ static int rsvp_change(struct net *net, struct sk_buff *in_skb, if (f == NULL) goto errout2; - err = tcf_exts_init(&f->exts, TCA_RSVP_ACT, TCA_RSVP_POLICE); + err = tcf_exts_init(&f->exts, net, TCA_RSVP_ACT, TCA_RSVP_POLICE); if (err < 0) goto errout; h2 = 16; @@ -654,7 +658,8 @@ errout2: return err; } -static void rsvp_walk(struct tcf_proto *tp, struct tcf_walker *arg) +static void rsvp_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) { struct rsvp_head *head = rtnl_dereference(tp->root); unsigned int h, h1; @@ -688,7 +693,7 @@ static void rsvp_walk(struct tcf_proto *tp, struct tcf_walker *arg) } static int rsvp_dump(struct net *net, struct tcf_proto *tp, void *fh, - struct sk_buff *skb, struct tcmsg *t) + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) { struct rsvp_filter *f = fh; struct rsvp_session *s; diff --git a/net/sched/cls_tcindex.c b/net/sched/cls_tcindex.c index 9ccc93f257db..24e0a62a65cc 100644 --- a/net/sched/cls_tcindex.c +++ b/net/sched/cls_tcindex.c @@ -48,7 +48,7 @@ struct tcindex_data { u32 hash; /* hash table size; 0 if undefined */ u32 alloc_hash; /* allocated size */ u32 fall_through; /* 0: only classify if explicit match */ - struct rcu_head rcu; + struct rcu_work rwork; }; static inline int tcindex_filter_is_set(struct tcindex_filter_result *r) @@ -173,7 +173,7 @@ static void tcindex_destroy_fexts_work(struct work_struct *work) } static int tcindex_delete(struct tcf_proto *tp, void *arg, bool *last, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { struct tcindex_data *p = rtnl_dereference(tp->root); struct tcindex_filter_result *r = arg; @@ -221,17 +221,11 @@ found: return 0; } -static int tcindex_destroy_element(struct tcf_proto *tp, - void *arg, struct tcf_walker *walker) -{ - bool last; - - return tcindex_delete(tp, arg, &last, NULL); -} - -static void __tcindex_destroy(struct rcu_head *head) +static void tcindex_destroy_work(struct work_struct *work) { - struct tcindex_data *p = container_of(head, struct tcindex_data, rcu); + struct tcindex_data *p = container_of(to_rcu_work(work), + struct tcindex_data, + rwork); kfree(p->perfect); kfree(p->h); @@ -252,15 +246,19 @@ static const struct nla_policy tcindex_policy[TCA_TCINDEX_MAX + 1] = { [TCA_TCINDEX_CLASSID] = { .type = NLA_U32 }, }; -static int tcindex_filter_result_init(struct tcindex_filter_result *r) +static int tcindex_filter_result_init(struct tcindex_filter_result *r, + struct net *net) { memset(r, 0, sizeof(*r)); - return tcf_exts_init(&r->exts, TCA_TCINDEX_ACT, TCA_TCINDEX_POLICE); + return tcf_exts_init(&r->exts, net, TCA_TCINDEX_ACT, + TCA_TCINDEX_POLICE); } -static void __tcindex_partial_destroy(struct rcu_head *head) +static void tcindex_partial_destroy_work(struct work_struct *work) { - struct tcindex_data *p = container_of(head, struct tcindex_data, rcu); + struct tcindex_data *p = container_of(to_rcu_work(work), + struct tcindex_data, + rwork); kfree(p->perfect); kfree(p); @@ -275,7 +273,7 @@ static void tcindex_free_perfect_hash(struct tcindex_data *cp) kfree(cp->perfect); } -static int tcindex_alloc_perfect_hash(struct tcindex_data *cp) +static int tcindex_alloc_perfect_hash(struct net *net, struct tcindex_data *cp) { int i, err = 0; @@ -285,7 +283,7 @@ static int tcindex_alloc_perfect_hash(struct tcindex_data *cp) return -ENOMEM; for (i = 0; i < cp->hash; i++) { - err = tcf_exts_init(&cp->perfect[i].exts, + err = tcf_exts_init(&cp->perfect[i].exts, net, TCA_TCINDEX_ACT, TCA_TCINDEX_POLICE); if (err < 0) goto errout; @@ -305,16 +303,16 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, struct nlattr *est, bool ovr, struct netlink_ext_ack *extack) { struct tcindex_filter_result new_filter_result, *old_r = r; - struct tcindex_filter_result cr; struct tcindex_data *cp = NULL, *oldp; struct tcindex_filter *f = NULL; /* make gcc behave */ + struct tcf_result cr = {}; int err, balloc = 0; struct tcf_exts e; - err = tcf_exts_init(&e, TCA_TCINDEX_ACT, TCA_TCINDEX_POLICE); + err = tcf_exts_init(&e, net, TCA_TCINDEX_ACT, TCA_TCINDEX_POLICE); if (err < 0) return err; - err = tcf_exts_validate(net, tp, tb, est, &e, ovr, extack); + err = tcf_exts_validate(net, tp, tb, est, &e, ovr, true, extack); if (err < 0) goto errout; @@ -337,7 +335,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, if (p->perfect) { int i; - if (tcindex_alloc_perfect_hash(cp) < 0) + if (tcindex_alloc_perfect_hash(net, cp) < 0) goto errout; for (i = 0; i < cp->hash; i++) cp->perfect[i].res = p->perfect[i].res; @@ -345,14 +343,11 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, } cp->h = p->h; - err = tcindex_filter_result_init(&new_filter_result); - if (err < 0) - goto errout1; - err = tcindex_filter_result_init(&cr); + err = tcindex_filter_result_init(&new_filter_result, net); if (err < 0) goto errout1; if (old_r) - cr.res = r->res; + cr = r->res; if (tb[TCA_TCINDEX_HASH]) cp->hash = nla_get_u32(tb[TCA_TCINDEX_HASH]); @@ -406,7 +401,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, err = -ENOMEM; if (!cp->perfect && !cp->h) { if (valid_perfect_hash(cp)) { - if (tcindex_alloc_perfect_hash(cp) < 0) + if (tcindex_alloc_perfect_hash(net, cp) < 0) goto errout_alloc; balloc = 1; } else { @@ -435,7 +430,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, goto errout_alloc; f->key = handle; f->next = NULL; - err = tcindex_filter_result_init(&f->result); + err = tcindex_filter_result_init(&f->result, net); if (err < 0) { kfree(f); goto errout_alloc; @@ -443,12 +438,12 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, } if (tb[TCA_TCINDEX_CLASSID]) { - cr.res.classid = nla_get_u32(tb[TCA_TCINDEX_CLASSID]); - tcf_bind_filter(tp, &cr.res, base); + cr.classid = nla_get_u32(tb[TCA_TCINDEX_CLASSID]); + tcf_bind_filter(tp, &cr, base); } if (old_r && old_r != r) { - err = tcindex_filter_result_init(old_r); + err = tcindex_filter_result_init(old_r, net); if (err < 0) { kfree(f); goto errout_alloc; @@ -456,7 +451,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, } oldp = p; - r->res = cr.res; + r->res = cr; tcf_exts_change(&r->exts, &e); rcu_assign_pointer(tp->root, cp); @@ -475,10 +470,12 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, ; /* nothing */ rcu_assign_pointer(*fp, f); + } else { + tcf_exts_destroy(&new_filter_result.exts); } if (oldp) - call_rcu(&oldp->rcu, __tcindex_partial_destroy); + tcf_queue_work(&oldp->rwork, tcindex_partial_destroy_work); return 0; errout_alloc: @@ -487,7 +484,6 @@ errout_alloc: else if (balloc == 2) kfree(cp->h); errout1: - tcf_exts_destroy(&cr.exts); tcf_exts_destroy(&new_filter_result.exts); errout: kfree(cp); @@ -499,7 +495,7 @@ static int tcindex_change(struct net *net, struct sk_buff *in_skb, struct tcf_proto *tp, unsigned long base, u32 handle, struct nlattr **tca, void **arg, bool ovr, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { struct nlattr *opt = tca[TCA_OPTIONS]; struct nlattr *tb[TCA_TCINDEX_MAX + 1]; @@ -522,7 +518,8 @@ tcindex_change(struct net *net, struct sk_buff *in_skb, tca[TCA_RATE], ovr, extack); } -static void tcindex_walk(struct tcf_proto *tp, struct tcf_walker *walker) +static void tcindex_walk(struct tcf_proto *tp, struct tcf_walker *walker, + bool rtnl_held) { struct tcindex_data *p = rtnl_dereference(tp->root); struct tcindex_filter *f, *next; @@ -558,24 +555,43 @@ static void tcindex_walk(struct tcf_proto *tp, struct tcf_walker *walker) } } -static void tcindex_destroy(struct tcf_proto *tp, +static void tcindex_destroy(struct tcf_proto *tp, bool rtnl_held, struct netlink_ext_ack *extack) { struct tcindex_data *p = rtnl_dereference(tp->root); - struct tcf_walker walker; + int i; pr_debug("tcindex_destroy(tp %p),p %p\n", tp, p); - walker.count = 0; - walker.skip = 0; - walker.fn = tcindex_destroy_element; - tcindex_walk(tp, &walker); - call_rcu(&p->rcu, __tcindex_destroy); + if (p->perfect) { + for (i = 0; i < p->hash; i++) { + struct tcindex_filter_result *r = p->perfect + i; + + tcf_unbind_filter(tp, &r->res); + if (tcf_exts_get_net(&r->exts)) + tcf_queue_work(&r->rwork, + tcindex_destroy_rexts_work); + else + __tcindex_destroy_rexts(r); + } + } + + for (i = 0; p->h && i < p->hash; i++) { + struct tcindex_filter *f, *next; + bool last; + + for (f = rtnl_dereference(p->h[i]); f; f = next) { + next = rtnl_dereference(f->next); + tcindex_delete(tp, &f->result, &last, rtnl_held, NULL); + } + } + + tcf_queue_work(&p->rwork, tcindex_destroy_work); } static int tcindex_dump(struct net *net, struct tcf_proto *tp, void *fh, - struct sk_buff *skb, struct tcmsg *t) + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) { struct tcindex_data *p = rtnl_dereference(tp->root); struct tcindex_filter_result *r = fh; diff --git a/net/sched/cls_u32.c b/net/sched/cls_u32.c index dcea21004604..48e76a3acf8a 100644 --- a/net/sched/cls_u32.c +++ b/net/sched/cls_u32.c @@ -629,7 +629,8 @@ static int u32_destroy_hnode(struct tcf_proto *tp, struct tc_u_hnode *ht, return -ENOENT; } -static void u32_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) +static void u32_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) { struct tc_u_common *tp_c = tp->data; struct tc_u_hnode *root_ht = rtnl_dereference(tp->root); @@ -663,7 +664,7 @@ static void u32_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) } static int u32_delete(struct tcf_proto *tp, void *arg, bool *last, - struct netlink_ext_ack *extack) + bool rtnl_held, struct netlink_ext_ack *extack) { struct tc_u_hnode *ht = arg; struct tc_u_common *tp_c = tp->data; @@ -726,7 +727,7 @@ static int u32_set_parms(struct net *net, struct tcf_proto *tp, { int err; - err = tcf_exts_validate(net, tp, tb, est, &n->exts, ovr, extack); + err = tcf_exts_validate(net, tp, tb, est, &n->exts, ovr, true, extack); if (err < 0) return err; @@ -803,7 +804,7 @@ static void u32_replace_knode(struct tcf_proto *tp, struct tc_u_common *tp_c, rcu_assign_pointer(*ins, n); } -static struct tc_u_knode *u32_init_knode(struct tcf_proto *tp, +static struct tc_u_knode *u32_init_knode(struct net *net, struct tcf_proto *tp, struct tc_u_knode *n) { struct tc_u_hnode *ht = rtnl_dereference(n->ht_down); @@ -848,7 +849,7 @@ static struct tc_u_knode *u32_init_knode(struct tcf_proto *tp, #endif memcpy(&new->sel, s, sizeof(*s) + s->nkeys*sizeof(struct tc_u32_key)); - if (tcf_exts_init(&new->exts, TCA_U32_ACT, TCA_U32_POLICE)) { + if (tcf_exts_init(&new->exts, net, TCA_U32_ACT, TCA_U32_POLICE)) { kfree(new); return NULL; } @@ -858,7 +859,7 @@ static struct tc_u_knode *u32_init_knode(struct tcf_proto *tp, static int u32_change(struct net *net, struct sk_buff *in_skb, struct tcf_proto *tp, unsigned long base, u32 handle, - struct nlattr **tca, void **arg, bool ovr, + struct nlattr **tca, void **arg, bool ovr, bool rtnl_held, struct netlink_ext_ack *extack) { struct tc_u_common *tp_c = tp->data; @@ -910,7 +911,7 @@ static int u32_change(struct net *net, struct sk_buff *in_skb, return -EINVAL; } - new = u32_init_knode(tp, n); + new = u32_init_knode(net, tp, n); if (!new) return -ENOMEM; @@ -1060,7 +1061,7 @@ static int u32_change(struct net *net, struct sk_buff *in_skb, n->fshift = s->hmask ? ffs(ntohl(s->hmask)) - 1 : 0; n->flags = flags; - err = tcf_exts_init(&n->exts, TCA_U32_ACT, TCA_U32_POLICE); + err = tcf_exts_init(&n->exts, net, TCA_U32_ACT, TCA_U32_POLICE); if (err < 0) goto errout; @@ -1123,7 +1124,8 @@ erridr: return err; } -static void u32_walk(struct tcf_proto *tp, struct tcf_walker *arg) +static void u32_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) { struct tc_u_common *tp_c = tp->data; struct tc_u_hnode *ht; @@ -1281,7 +1283,7 @@ static void u32_bind_class(void *fh, u32 classid, unsigned long cl) } static int u32_dump(struct net *net, struct tcf_proto *tp, void *fh, - struct sk_buff *skb, struct tcmsg *t) + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) { struct tc_u_knode *n = fh; struct tc_u_hnode *ht_up, *ht_down; diff --git a/net/sched/sch_api.c b/net/sched/sch_api.c index 7e4d1ccf4c87..fb8f138b9776 100644 --- a/net/sched/sch_api.c +++ b/net/sched/sch_api.c @@ -526,11 +526,6 @@ static struct qdisc_size_table *qdisc_get_stab(struct nlattr *opt, return stab; } -static void stab_kfree_rcu(struct rcu_head *head) -{ - kfree(container_of(head, struct qdisc_size_table, rcu)); -} - void qdisc_put_stab(struct qdisc_size_table *tab) { if (!tab) @@ -538,7 +533,7 @@ void qdisc_put_stab(struct qdisc_size_table *tab) if (--tab->refcnt == 0) { list_del(&tab->list); - call_rcu(&tab->rcu, stab_kfree_rcu); + kfree_rcu(tab, rcu); } } EXPORT_SYMBOL(qdisc_put_stab); @@ -758,8 +753,7 @@ static u32 qdisc_alloc_handle(struct net_device *dev) return 0; } -void qdisc_tree_reduce_backlog(struct Qdisc *sch, unsigned int n, - unsigned int len) +void qdisc_tree_reduce_backlog(struct Qdisc *sch, int n, int len) { bool qdisc_is_offloaded = sch->flags & TCQ_F_OFFLOADED; const struct Qdisc_class_ops *cops; @@ -1202,9 +1196,11 @@ static struct Qdisc *qdisc_create(struct net_device *dev, } else { if (handle == 0) { handle = qdisc_alloc_handle(dev); - err = -ENOMEM; - if (handle == 0) + if (handle == 0) { + NL_SET_ERR_MSG(extack, "Maximum number of qdisc handles was exceeded"); + err = -ENOSPC; goto err_out3; + } } if (!netif_is_multiqueue(dev)) sch->flags |= TCQ_F_ONETXQUEUE; @@ -1828,6 +1824,7 @@ static int tclass_notify(struct net *net, struct sk_buff *oskb, { struct sk_buff *skb; u32 portid = oskb ? NETLINK_CB(oskb).portid : 0; + int err = 0; skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); if (!skb) @@ -1838,8 +1835,11 @@ static int tclass_notify(struct net *net, struct sk_buff *oskb, return -EINVAL; } - return rtnetlink_send(skb, net, portid, RTNLGRP_TC, - n->nlmsg_flags & NLM_F_ECHO); + err = rtnetlink_send(skb, net, portid, RTNLGRP_TC, + n->nlmsg_flags & NLM_F_ECHO); + if (err > 0) + err = 0; + return err; } static int tclass_del_notify(struct net *net, @@ -1870,8 +1870,11 @@ static int tclass_del_notify(struct net *net, return err; } - return rtnetlink_send(skb, net, portid, RTNLGRP_TC, - n->nlmsg_flags & NLM_F_ECHO); + err = rtnetlink_send(skb, net, portid, RTNLGRP_TC, + n->nlmsg_flags & NLM_F_ECHO); + if (err > 0) + err = 0; + return err; } #ifdef CONFIG_NET_CLS @@ -1910,17 +1913,19 @@ static void tc_bind_tclass(struct Qdisc *q, u32 portid, u32 clid, block = cops->tcf_block(q, cl, NULL); if (!block) return; - list_for_each_entry(chain, &block->chain_list, list) { + for (chain = tcf_get_next_chain(block, NULL); + chain; + chain = tcf_get_next_chain(block, chain)) { struct tcf_proto *tp; - for (tp = rtnl_dereference(chain->filter_chain); - tp; tp = rtnl_dereference(tp->next)) { + for (tp = tcf_get_next_proto(chain, NULL, true); + tp; tp = tcf_get_next_proto(chain, tp, true)) { struct tcf_bind_args arg = {}; arg.w.fn = tcf_node_bind; arg.classid = clid; arg.cl = new_cl; - tp->ops->walk(tp, &arg.w); + tp->ops->walk(tp, &arg.w, true); } } } diff --git a/net/sched/sch_cake.c b/net/sched/sch_cake.c index 73940293700d..1d2a12132abc 100644 --- a/net/sched/sch_cake.c +++ b/net/sched/sch_cake.c @@ -138,8 +138,8 @@ struct cake_flow { struct cake_host { u32 srchost_tag; u32 dsthost_tag; - u16 srchost_refcnt; - u16 dsthost_refcnt; + u16 srchost_bulk_flow_count; + u16 dsthost_bulk_flow_count; }; struct cake_heap_entry { @@ -258,7 +258,8 @@ enum { CAKE_FLAG_AUTORATE_INGRESS = BIT(1), CAKE_FLAG_INGRESS = BIT(2), CAKE_FLAG_WASH = BIT(3), - CAKE_FLAG_SPLIT_GSO = BIT(4) + CAKE_FLAG_SPLIT_GSO = BIT(4), + CAKE_FLAG_FWMARK = BIT(5) }; /* COBALT operates the Codel and BLUE algorithms in parallel, in order to @@ -746,8 +747,10 @@ skip_hash: * queue, accept the collision, update the host tags. */ q->way_collisions++; - q->hosts[q->flows[reduced_hash].srchost].srchost_refcnt--; - q->hosts[q->flows[reduced_hash].dsthost].dsthost_refcnt--; + if (q->flows[outer_hash + k].set == CAKE_SET_BULK) { + q->hosts[q->flows[reduced_hash].srchost].srchost_bulk_flow_count--; + q->hosts[q->flows[reduced_hash].dsthost].dsthost_bulk_flow_count--; + } allocate_src = cake_dsrc(flow_mode); allocate_dst = cake_ddst(flow_mode); found: @@ -767,13 +770,14 @@ found: } for (i = 0; i < CAKE_SET_WAYS; i++, k = (k + 1) % CAKE_SET_WAYS) { - if (!q->hosts[outer_hash + k].srchost_refcnt) + if (!q->hosts[outer_hash + k].srchost_bulk_flow_count) break; } q->hosts[outer_hash + k].srchost_tag = srchost_hash; found_src: srchost_idx = outer_hash + k; - q->hosts[srchost_idx].srchost_refcnt++; + if (q->flows[reduced_hash].set == CAKE_SET_BULK) + q->hosts[srchost_idx].srchost_bulk_flow_count++; q->flows[reduced_hash].srchost = srchost_idx; } @@ -789,13 +793,14 @@ found_src: } for (i = 0; i < CAKE_SET_WAYS; i++, k = (k + 1) % CAKE_SET_WAYS) { - if (!q->hosts[outer_hash + k].dsthost_refcnt) + if (!q->hosts[outer_hash + k].dsthost_bulk_flow_count) break; } q->hosts[outer_hash + k].dsthost_tag = dsthost_hash; found_dst: dsthost_idx = outer_hash + k; - q->hosts[dsthost_idx].dsthost_refcnt++; + if (q->flows[reduced_hash].set == CAKE_SET_BULK) + q->hosts[dsthost_idx].dsthost_bulk_flow_count++; q->flows[reduced_hash].dsthost = dsthost_idx; } } @@ -1508,20 +1513,6 @@ static unsigned int cake_drop(struct Qdisc *sch, struct sk_buff **to_free) return idx + (tin << 16); } -static void cake_wash_diffserv(struct sk_buff *skb) -{ - switch (skb->protocol) { - case htons(ETH_P_IP): - ipv4_change_dsfield(ip_hdr(skb), INET_ECN_MASK, 0); - break; - case htons(ETH_P_IPV6): - ipv6_change_dsfield(ipv6_hdr(skb), INET_ECN_MASK, 0); - break; - default: - break; - } -} - static u8 cake_handle_diffserv(struct sk_buff *skb, u16 wash) { u8 dscp; @@ -1553,25 +1544,32 @@ static struct cake_tin_data *cake_select_tin(struct Qdisc *sch, { struct cake_sched_data *q = qdisc_priv(sch); u32 tin; + u8 dscp; + + /* Tin selection: Default to diffserv-based selection, allow overriding + * using firewall marks or skb->priority. + */ + dscp = cake_handle_diffserv(skb, + q->rate_flags & CAKE_FLAG_WASH); + + if (q->tin_mode == CAKE_DIFFSERV_BESTEFFORT) + tin = 0; - if (TC_H_MAJ(skb->priority) == sch->handle && - TC_H_MIN(skb->priority) > 0 && - TC_H_MIN(skb->priority) <= q->tin_cnt) { + else if (q->rate_flags & CAKE_FLAG_FWMARK && /* use fw mark */ + skb->mark && + skb->mark <= q->tin_cnt) + tin = q->tin_order[skb->mark - 1]; + + else if (TC_H_MAJ(skb->priority) == sch->handle && + TC_H_MIN(skb->priority) > 0 && + TC_H_MIN(skb->priority) <= q->tin_cnt) tin = q->tin_order[TC_H_MIN(skb->priority) - 1]; - if (q->rate_flags & CAKE_FLAG_WASH) - cake_wash_diffserv(skb); - } else if (q->tin_mode != CAKE_DIFFSERV_BESTEFFORT) { - /* extract the Diffserv Precedence field, if it exists */ - /* and clear DSCP bits if washing */ - tin = q->tin_index[cake_handle_diffserv(skb, - q->rate_flags & CAKE_FLAG_WASH)]; + else { + tin = q->tin_index[dscp]; + if (unlikely(tin >= q->tin_cnt)) tin = 0; - } else { - tin = 0; - if (q->rate_flags & CAKE_FLAG_WASH) - cake_wash_diffserv(skb); } return &q->tins[tin]; @@ -1794,20 +1792,30 @@ static s32 cake_enqueue(struct sk_buff *skb, struct Qdisc *sch, b->sparse_flow_count++; if (cake_dsrc(q->flow_mode)) - host_load = max(host_load, srchost->srchost_refcnt); + host_load = max(host_load, srchost->srchost_bulk_flow_count); if (cake_ddst(q->flow_mode)) - host_load = max(host_load, dsthost->dsthost_refcnt); + host_load = max(host_load, dsthost->dsthost_bulk_flow_count); flow->deficit = (b->flow_quantum * quantum_div[host_load]) >> 16; } else if (flow->set == CAKE_SET_SPARSE_WAIT) { + struct cake_host *srchost = &b->hosts[flow->srchost]; + struct cake_host *dsthost = &b->hosts[flow->dsthost]; + /* this flow was empty, accounted as a sparse flow, but actually * in the bulk rotation. */ flow->set = CAKE_SET_BULK; b->sparse_flow_count--; b->bulk_flow_count++; + + if (cake_dsrc(q->flow_mode)) + srchost->srchost_bulk_flow_count++; + + if (cake_ddst(q->flow_mode)) + dsthost->dsthost_bulk_flow_count++; + } if (q->buffer_used > q->buffer_max_used) @@ -1975,23 +1983,8 @@ retry: dsthost = &b->hosts[flow->dsthost]; host_load = 1; - if (cake_dsrc(q->flow_mode)) - host_load = max(host_load, srchost->srchost_refcnt); - - if (cake_ddst(q->flow_mode)) - host_load = max(host_load, dsthost->dsthost_refcnt); - - WARN_ON(host_load > CAKE_QUEUES); - /* flow isolation (DRR++) */ if (flow->deficit <= 0) { - /* The shifted prandom_u32() is a way to apply dithering to - * avoid accumulating roundoff errors - */ - flow->deficit += (b->flow_quantum * quantum_div[host_load] + - (prandom_u32() >> 16)) >> 16; - list_move_tail(&flow->flowchain, &b->old_flows); - /* Keep all flows with deficits out of the sparse and decaying * rotations. No non-empty flow can go into the decaying * rotation, so they can't get deficits @@ -2000,6 +1993,13 @@ retry: if (flow->head) { b->sparse_flow_count--; b->bulk_flow_count++; + + if (cake_dsrc(q->flow_mode)) + srchost->srchost_bulk_flow_count++; + + if (cake_ddst(q->flow_mode)) + dsthost->dsthost_bulk_flow_count++; + flow->set = CAKE_SET_BULK; } else { /* we've moved it to the bulk rotation for @@ -2009,6 +2009,22 @@ retry: flow->set = CAKE_SET_SPARSE_WAIT; } } + + if (cake_dsrc(q->flow_mode)) + host_load = max(host_load, srchost->srchost_bulk_flow_count); + + if (cake_ddst(q->flow_mode)) + host_load = max(host_load, dsthost->dsthost_bulk_flow_count); + + WARN_ON(host_load > CAKE_QUEUES); + + /* The shifted prandom_u32() is a way to apply dithering to + * avoid accumulating roundoff errors + */ + flow->deficit += (b->flow_quantum * quantum_div[host_load] + + (prandom_u32() >> 16)) >> 16; + list_move_tail(&flow->flowchain, &b->old_flows); + goto retry; } @@ -2029,6 +2045,13 @@ retry: &b->decaying_flows); if (flow->set == CAKE_SET_BULK) { b->bulk_flow_count--; + + if (cake_dsrc(q->flow_mode)) + srchost->srchost_bulk_flow_count--; + + if (cake_ddst(q->flow_mode)) + dsthost->dsthost_bulk_flow_count--; + b->decaying_flow_count++; } else if (flow->set == CAKE_SET_SPARSE || flow->set == CAKE_SET_SPARSE_WAIT) { @@ -2042,14 +2065,19 @@ retry: if (flow->set == CAKE_SET_SPARSE || flow->set == CAKE_SET_SPARSE_WAIT) b->sparse_flow_count--; - else if (flow->set == CAKE_SET_BULK) + else if (flow->set == CAKE_SET_BULK) { b->bulk_flow_count--; - else + + if (cake_dsrc(q->flow_mode)) + srchost->srchost_bulk_flow_count--; + + if (cake_ddst(q->flow_mode)) + dsthost->dsthost_bulk_flow_count--; + + } else b->decaying_flow_count--; flow->set = CAKE_SET_NONE; - srchost->srchost_refcnt--; - dsthost->dsthost_refcnt--; } goto begin; } @@ -2590,6 +2618,13 @@ static int cake_change(struct Qdisc *sch, struct nlattr *opt, q->rate_flags &= ~CAKE_FLAG_SPLIT_GSO; } + if (tb[TCA_CAKE_FWMARK]) { + if (!!nla_get_u32(tb[TCA_CAKE_FWMARK])) + q->rate_flags |= CAKE_FLAG_FWMARK; + else + q->rate_flags &= ~CAKE_FLAG_FWMARK; + } + if (q->tins) { sch_tree_lock(sch); cake_reconfigure(sch); @@ -2749,6 +2784,10 @@ static int cake_dump(struct Qdisc *sch, struct sk_buff *skb) !!(q->rate_flags & CAKE_FLAG_SPLIT_GSO))) goto nla_put_failure; + if (nla_put_u32(skb, TCA_CAKE_FWMARK, + !!(q->rate_flags & CAKE_FLAG_FWMARK))) + goto nla_put_failure; + return nla_nest_end(skb, opts); nla_put_failure: diff --git a/net/sched/sch_generic.c b/net/sched/sch_generic.c index 66ba2ce2320f..a117d9260558 100644 --- a/net/sched/sch_generic.c +++ b/net/sched/sch_generic.c @@ -68,7 +68,7 @@ static inline struct sk_buff *__skb_dequeue_bad_txq(struct Qdisc *q) skb = __skb_dequeue(&q->skb_bad_txq); if (qdisc_is_percpu_stats(q)) { qdisc_qstats_cpu_backlog_dec(q, skb); - qdisc_qstats_cpu_qlen_dec(q); + qdisc_qstats_atomic_qlen_dec(q); } else { qdisc_qstats_backlog_dec(q, skb); q->q.qlen--; @@ -108,7 +108,7 @@ static inline void qdisc_enqueue_skb_bad_txq(struct Qdisc *q, if (qdisc_is_percpu_stats(q)) { qdisc_qstats_cpu_backlog_inc(q, skb); - qdisc_qstats_cpu_qlen_inc(q); + qdisc_qstats_atomic_qlen_inc(q); } else { qdisc_qstats_backlog_inc(q, skb); q->q.qlen++; @@ -147,7 +147,7 @@ static inline int dev_requeue_skb_locked(struct sk_buff *skb, struct Qdisc *q) qdisc_qstats_cpu_requeues_inc(q); qdisc_qstats_cpu_backlog_inc(q, skb); - qdisc_qstats_cpu_qlen_inc(q); + qdisc_qstats_atomic_qlen_inc(q); skb = next; } @@ -252,7 +252,7 @@ static struct sk_buff *dequeue_skb(struct Qdisc *q, bool *validate, skb = __skb_dequeue(&q->gso_skb); if (qdisc_is_percpu_stats(q)) { qdisc_qstats_cpu_backlog_dec(q, skb); - qdisc_qstats_cpu_qlen_dec(q); + qdisc_qstats_atomic_qlen_dec(q); } else { qdisc_qstats_backlog_dec(q, skb); q->q.qlen--; @@ -500,7 +500,7 @@ static void dev_watchdog_down(struct net_device *dev) * netif_carrier_on - set carrier * @dev: network device * - * Device has detected that carrier. + * Device has detected acquisition of carrier. */ void netif_carrier_on(struct net_device *dev) { @@ -559,7 +559,7 @@ struct Qdisc_ops noop_qdisc_ops __read_mostly = { }; static struct netdev_queue noop_netdev_queue = { - .qdisc = &noop_qdisc, + RCU_POINTER_INITIALIZER(qdisc, &noop_qdisc), .qdisc_sleeping = &noop_qdisc, }; @@ -645,7 +645,7 @@ static int pfifo_fast_enqueue(struct sk_buff *skb, struct Qdisc *qdisc, if (unlikely(err)) return qdisc_drop_cpu(skb, qdisc, to_free); - qdisc_qstats_cpu_qlen_inc(qdisc); + qdisc_qstats_atomic_qlen_inc(qdisc); /* Note: skb can not be used after skb_array_produce(), * so we better not use qdisc_qstats_cpu_backlog_inc() */ @@ -670,7 +670,7 @@ static struct sk_buff *pfifo_fast_dequeue(struct Qdisc *qdisc) if (likely(skb)) { qdisc_qstats_cpu_backlog_dec(qdisc, skb); qdisc_bstats_cpu_update(qdisc, skb); - qdisc_qstats_cpu_qlen_dec(qdisc); + qdisc_qstats_atomic_qlen_dec(qdisc); } return skb; @@ -714,7 +714,6 @@ static void pfifo_fast_reset(struct Qdisc *qdisc) struct gnet_stats_queue *q = per_cpu_ptr(qdisc->cpu_qstats, i); q->backlog = 0; - q->qlen = 0; } } @@ -1366,7 +1365,11 @@ static void mini_qdisc_rcu_func(struct rcu_head *head) void mini_qdisc_pair_swap(struct mini_Qdisc_pair *miniqp, struct tcf_proto *tp_head) { - struct mini_Qdisc *miniq_old = rtnl_dereference(*miniqp->p_miniq); + /* Protected with chain0->filter_chain_lock. + * Can't access chain directly because tp_head can be NULL. + */ + struct mini_Qdisc *miniq_old = + rcu_dereference_protected(*miniqp->p_miniq, 1); struct mini_Qdisc *miniq; if (!tp_head) { diff --git a/net/sched/sch_netem.c b/net/sched/sch_netem.c index 75046ec72144..cc9d8133afcd 100644 --- a/net/sched/sch_netem.c +++ b/net/sched/sch_netem.c @@ -447,6 +447,7 @@ static int netem_enqueue(struct sk_buff *skb, struct Qdisc *sch, int nb = 0; int count = 1; int rc = NET_XMIT_SUCCESS; + int rc_drop = NET_XMIT_DROP; /* Do not fool qdisc_drop_all() */ skb->prev = NULL; @@ -486,6 +487,7 @@ static int netem_enqueue(struct sk_buff *skb, struct Qdisc *sch, q->duplicate = 0; rootq->enqueue(skb2, rootq, to_free); q->duplicate = dupsave; + rc_drop = NET_XMIT_SUCCESS; } /* @@ -498,7 +500,7 @@ static int netem_enqueue(struct sk_buff *skb, struct Qdisc *sch, if (skb_is_gso(skb)) { segs = netem_segment(skb, sch, to_free); if (!segs) - return NET_XMIT_DROP; + return rc_drop; } else { segs = skb; } @@ -521,8 +523,10 @@ static int netem_enqueue(struct sk_buff *skb, struct Qdisc *sch, 1<<(prandom_u32() % 8); } - if (unlikely(sch->q.qlen >= sch->limit)) - return qdisc_drop_all(skb, sch, to_free); + if (unlikely(sch->q.qlen >= sch->limit)) { + qdisc_drop_all(skb, sch, to_free); + return rc_drop; + } qdisc_qstats_backlog_inc(sch, skb); diff --git a/net/sched/sch_pie.c b/net/sched/sch_pie.c index d1429371592f..1cc0c7b74aa3 100644 --- a/net/sched/sch_pie.c +++ b/net/sched/sch_pie.c @@ -17,9 +17,7 @@ * University of Oslo, Norway. * * References: - * IETF draft submission: http://tools.ietf.org/html/draft-pan-aqm-pie-00 - * IEEE Conference on High Performance Switching and Routing 2013 : - * "PIE: A * Lightweight Control Scheme to Address the Bufferbloat Problem" + * RFC 8033: https://tools.ietf.org/html/rfc8033 */ #include <linux/module.h> @@ -31,9 +29,9 @@ #include <net/pkt_sched.h> #include <net/inet_ecn.h> -#define QUEUE_THRESHOLD 10000 +#define QUEUE_THRESHOLD 16384 #define DQCOUNT_INVALID -1 -#define MAX_PROB 0xffffffff +#define MAX_PROB 0xffffffffffffffff #define PIE_SCALE 8 /* parameters used */ @@ -49,14 +47,16 @@ struct pie_params { /* variables used */ struct pie_vars { - u32 prob; /* probability but scaled by u32 limit. */ + u64 prob; /* probability but scaled by u64 limit. */ psched_time_t burst_time; psched_time_t qdelay; psched_time_t qdelay_old; u64 dq_count; /* measured in bytes */ psched_time_t dq_tstamp; /* drain rate */ + u64 accu_prob; /* accumulated drop probability */ u32 avg_dq_rate; /* bytes per pschedtime tick,scaled */ u32 qlen_old; /* in bytes */ + u8 accu_prob_overflows; /* overflows of accu_prob */ }; /* statistics gathering */ @@ -81,9 +81,9 @@ static void pie_params_init(struct pie_params *params) { params->alpha = 2; params->beta = 20; - params->tupdate = usecs_to_jiffies(30 * USEC_PER_MSEC); /* 30 ms */ + params->tupdate = usecs_to_jiffies(15 * USEC_PER_MSEC); /* 15 ms */ params->limit = 1000; /* default of 1000 packets */ - params->target = PSCHED_NS2TICKS(20 * NSEC_PER_MSEC); /* 20 ms */ + params->target = PSCHED_NS2TICKS(15 * NSEC_PER_MSEC); /* 15 ms */ params->ecn = false; params->bytemode = false; } @@ -91,16 +91,18 @@ static void pie_params_init(struct pie_params *params) static void pie_vars_init(struct pie_vars *vars) { vars->dq_count = DQCOUNT_INVALID; + vars->accu_prob = 0; vars->avg_dq_rate = 0; - /* default of 100 ms in pschedtime */ - vars->burst_time = PSCHED_NS2TICKS(100 * NSEC_PER_MSEC); + /* default of 150 ms in pschedtime */ + vars->burst_time = PSCHED_NS2TICKS(150 * NSEC_PER_MSEC); + vars->accu_prob_overflows = 0; } static bool drop_early(struct Qdisc *sch, u32 packet_size) { struct pie_sched_data *q = qdisc_priv(sch); - u32 rnd; - u32 local_prob = q->vars.prob; + u64 rnd; + u64 local_prob = q->vars.prob; u32 mtu = psched_mtu(qdisc_dev(sch)); /* If there is still burst allowance left skip random early drop */ @@ -124,13 +126,33 @@ static bool drop_early(struct Qdisc *sch, u32 packet_size) * probablity. Smaller packets will have lower drop prob in this case */ if (q->params.bytemode && packet_size <= mtu) - local_prob = (local_prob / mtu) * packet_size; + local_prob = (u64)packet_size * div_u64(local_prob, mtu); else local_prob = q->vars.prob; - rnd = prandom_u32(); - if (rnd < local_prob) + if (local_prob == 0) { + q->vars.accu_prob = 0; + q->vars.accu_prob_overflows = 0; + } + + if (local_prob > MAX_PROB - q->vars.accu_prob) + q->vars.accu_prob_overflows++; + + q->vars.accu_prob += local_prob; + + if (q->vars.accu_prob_overflows == 0 && + q->vars.accu_prob < (MAX_PROB / 100) * 85) + return false; + if (q->vars.accu_prob_overflows == 8 && + q->vars.accu_prob >= MAX_PROB / 2) + return true; + + prandom_bytes(&rnd, 8); + if (rnd < local_prob) { + q->vars.accu_prob = 0; + q->vars.accu_prob_overflows = 0; return true; + } return false; } @@ -168,6 +190,8 @@ static int pie_qdisc_enqueue(struct sk_buff *skb, struct Qdisc *sch, out: q->stats.dropped++; + q->vars.accu_prob = 0; + q->vars.accu_prob_overflows = 0; return qdisc_drop(skb, sch, to_free); } @@ -317,9 +341,10 @@ static void calculate_probability(struct Qdisc *sch) u32 qlen = sch->qstats.backlog; /* queue size in bytes */ psched_time_t qdelay = 0; /* in pschedtime */ psched_time_t qdelay_old = q->vars.qdelay; /* in pschedtime */ - s32 delta = 0; /* determines the change in probability */ - u32 oldprob; - u32 alpha, beta; + s64 delta = 0; /* determines the change in probability */ + u64 oldprob; + u64 alpha, beta; + u32 power; bool update_prob = true; q->vars.qdelay_old = q->vars.qdelay; @@ -339,38 +364,36 @@ static void calculate_probability(struct Qdisc *sch) * value for alpha as 0.125. In this implementation, we use values 0-32 * passed from user space to represent this. Also, alpha and beta have * unit of HZ and need to be scaled before they can used to update - * probability. alpha/beta are updated locally below by 1) scaling them - * appropriately 2) scaling down by 16 to come to 0-2 range. - * Please see paper for details. - * - * We scale alpha and beta differently depending on whether we are in - * light, medium or high dropping mode. + * probability. alpha/beta are updated locally below by scaling down + * by 16 to come to 0-2 range. */ - if (q->vars.prob < MAX_PROB / 100) { - alpha = - (q->params.alpha * (MAX_PROB / PSCHED_TICKS_PER_SEC)) >> 7; - beta = - (q->params.beta * (MAX_PROB / PSCHED_TICKS_PER_SEC)) >> 7; - } else if (q->vars.prob < MAX_PROB / 10) { - alpha = - (q->params.alpha * (MAX_PROB / PSCHED_TICKS_PER_SEC)) >> 5; - beta = - (q->params.beta * (MAX_PROB / PSCHED_TICKS_PER_SEC)) >> 5; - } else { - alpha = - (q->params.alpha * (MAX_PROB / PSCHED_TICKS_PER_SEC)) >> 4; - beta = - (q->params.beta * (MAX_PROB / PSCHED_TICKS_PER_SEC)) >> 4; + alpha = ((u64)q->params.alpha * (MAX_PROB / PSCHED_TICKS_PER_SEC)) >> 4; + beta = ((u64)q->params.beta * (MAX_PROB / PSCHED_TICKS_PER_SEC)) >> 4; + + /* We scale alpha and beta differently depending on how heavy the + * congestion is. Please see RFC 8033 for details. + */ + if (q->vars.prob < MAX_PROB / 10) { + alpha >>= 1; + beta >>= 1; + + power = 100; + while (q->vars.prob < div_u64(MAX_PROB, power) && + power <= 1000000) { + alpha >>= 2; + beta >>= 2; + power *= 10; + } } /* alpha and beta should be between 0 and 32, in multiples of 1/16 */ - delta += alpha * ((qdelay - q->params.target)); - delta += beta * ((qdelay - qdelay_old)); + delta += alpha * (u64)(qdelay - q->params.target); + delta += beta * (u64)(qdelay - qdelay_old); oldprob = q->vars.prob; /* to ensure we increase probability in steps of no more than 2% */ - if (delta > (s32)(MAX_PROB / (100 / 2)) && + if (delta > (s64)(MAX_PROB / (100 / 2)) && q->vars.prob >= MAX_PROB / 10) delta = (MAX_PROB / 100) * 2; @@ -406,7 +429,8 @@ static void calculate_probability(struct Qdisc *sch) */ if (qdelay == 0 && qdelay_old == 0 && update_prob) - q->vars.prob = (q->vars.prob * 98) / 100; + /* Reduce drop probability to 98.4% */ + q->vars.prob -= q->vars.prob / 64u; q->vars.qdelay = qdelay; q->vars.qlen_old = qlen; diff --git a/net/sctp/associola.c b/net/sctp/associola.c index 201c888604e4..d2c7d0d2abc1 100644 --- a/net/sctp/associola.c +++ b/net/sctp/associola.c @@ -101,7 +101,7 @@ static struct sctp_association *sctp_association_init( * socket values. */ asoc->max_retrans = sp->assocparams.sasoc_asocmaxrxt; - asoc->pf_retrans = net->sctp.pf_retrans; + asoc->pf_retrans = sp->pf_retrans; asoc->rto_initial = msecs_to_jiffies(sp->rtoinfo.srto_initial); asoc->rto_max = msecs_to_jiffies(sp->rtoinfo.srto_max); @@ -1651,8 +1651,11 @@ int sctp_assoc_set_id(struct sctp_association *asoc, gfp_t gfp) if (preload) idr_preload(gfp); spin_lock_bh(&sctp_assocs_id_lock); - /* 0 is not a valid assoc_id, must be >= 1 */ - ret = idr_alloc_cyclic(&sctp_assocs_id, asoc, 1, 0, GFP_NOWAIT); + /* 0, 1, 2 are used as SCTP_FUTURE_ASSOC, SCTP_CURRENT_ASSOC and + * SCTP_ALL_ASSOC, so an available id must be > SCTP_ALL_ASSOC. + */ + ret = idr_alloc_cyclic(&sctp_assocs_id, asoc, SCTP_ALL_ASSOC + 1, 0, + GFP_NOWAIT); spin_unlock_bh(&sctp_assocs_id_lock); if (preload) idr_preload_end(); diff --git a/net/sctp/auth.c b/net/sctp/auth.c index 5b537613946f..39d72e58b8e5 100644 --- a/net/sctp/auth.c +++ b/net/sctp/auth.c @@ -471,12 +471,6 @@ int sctp_auth_init_hmacs(struct sctp_endpoint *ep, gfp_t gfp) struct crypto_shash *tfm = NULL; __u16 id; - /* If AUTH extension is disabled, we are done */ - if (!ep->auth_enable) { - ep->auth_hmacs = NULL; - return 0; - } - /* If the transforms are already allocated, we are done */ if (ep->auth_hmacs) return 0; diff --git a/net/sctp/chunk.c b/net/sctp/chunk.c index 64bef313d436..5cb7c1ff97e9 100644 --- a/net/sctp/chunk.c +++ b/net/sctp/chunk.c @@ -192,7 +192,7 @@ struct sctp_datamsg *sctp_datamsg_from_user(struct sctp_association *asoc, if (unlikely(!max_data)) { max_data = sctp_min_frag_point(sctp_sk(asoc->base.sk), sctp_datachk_len(&asoc->stream)); - pr_warn_ratelimited("%s: asoc:%p frag_point is zero, forcing max_data to default minimum (%Zu)", + pr_warn_ratelimited("%s: asoc:%p frag_point is zero, forcing max_data to default minimum (%zu)", __func__, asoc, max_data); } diff --git a/net/sctp/diag.c b/net/sctp/diag.c index 078f01a8d582..435847d98b51 100644 --- a/net/sctp/diag.c +++ b/net/sctp/diag.c @@ -256,6 +256,7 @@ static size_t inet_assoc_attr_size(struct sctp_association *asoc) + nla_total_size(1) /* INET_DIAG_TOS */ + nla_total_size(1) /* INET_DIAG_TCLASS */ + nla_total_size(4) /* INET_DIAG_MARK */ + + nla_total_size(4) /* INET_DIAG_CLASS_ID */ + nla_total_size(addrlen * asoc->peer.transport_count) + nla_total_size(addrlen * addrcnt) + nla_total_size(sizeof(struct inet_diag_meminfo)) diff --git a/net/sctp/endpointola.c b/net/sctp/endpointola.c index 40c7eb941bc9..0448b68fce74 100644 --- a/net/sctp/endpointola.c +++ b/net/sctp/endpointola.c @@ -107,6 +107,13 @@ static struct sctp_endpoint *sctp_endpoint_init(struct sctp_endpoint *ep, auth_chunks->param_hdr.length = htons(sizeof(struct sctp_paramhdr) + 2); } + + /* Allocate and initialize transorms arrays for supported + * HMACs. + */ + err = sctp_auth_init_hmacs(ep, gfp); + if (err) + goto nomem; } /* Initialize the base structure. */ @@ -150,15 +157,10 @@ static struct sctp_endpoint *sctp_endpoint_init(struct sctp_endpoint *ep, INIT_LIST_HEAD(&ep->endpoint_shared_keys); null_key = sctp_auth_shkey_create(0, gfp); if (!null_key) - goto nomem; + goto nomem_shkey; list_add(&null_key->key_list, &ep->endpoint_shared_keys); - /* Allocate and initialize transorms arrays for supported HMACs. */ - err = sctp_auth_init_hmacs(ep, gfp); - if (err) - goto nomem_hmacs; - /* Add the null key to the endpoint shared keys list and * set the hmcas and chunks pointers. */ @@ -169,8 +171,8 @@ static struct sctp_endpoint *sctp_endpoint_init(struct sctp_endpoint *ep, return ep; -nomem_hmacs: - sctp_auth_destroy_keys(&ep->endpoint_shared_keys); +nomem_shkey: + sctp_auth_destroy_hmacs(ep->auth_hmacs); nomem: /* Free all allocations */ kfree(auth_hmacs); diff --git a/net/sctp/ipv6.c b/net/sctp/ipv6.c index ed8e006dae85..6200cd2b4b99 100644 --- a/net/sctp/ipv6.c +++ b/net/sctp/ipv6.c @@ -280,7 +280,8 @@ static void sctp_v6_get_dst(struct sctp_transport *t, union sctp_addr *saddr, if (saddr) { fl6->saddr = saddr->v6.sin6_addr; - fl6->fl6_sport = saddr->v6.sin6_port; + if (!fl6->fl6_sport) + fl6->fl6_sport = saddr->v6.sin6_port; pr_debug("src=%pI6 - ", &fl6->saddr); } diff --git a/net/sctp/offload.c b/net/sctp/offload.c index 123e9f2dc226..edfcf16e704c 100644 --- a/net/sctp/offload.c +++ b/net/sctp/offload.c @@ -36,6 +36,7 @@ static __le32 sctp_gso_make_checksum(struct sk_buff *skb) { skb->ip_summed = CHECKSUM_NONE; skb->csum_not_inet = 0; + gso_reset_checksum(skb, ~0); return sctp_compute_cksum(skb, skb_transport_offset(skb)); } diff --git a/net/sctp/outqueue.c b/net/sctp/outqueue.c index c37e1c2dec9d..fd33281999b5 100644 --- a/net/sctp/outqueue.c +++ b/net/sctp/outqueue.c @@ -212,7 +212,7 @@ void sctp_outq_init(struct sctp_association *asoc, struct sctp_outq *q) INIT_LIST_HEAD(&q->retransmit); INIT_LIST_HEAD(&q->sacked); INIT_LIST_HEAD(&q->abandoned); - sctp_sched_set_sched(asoc, SCTP_SS_DEFAULT); + sctp_sched_set_sched(asoc, sctp_sk(asoc->base.sk)->default_ss); } /* Free the outqueue structure and any related pending chunks. diff --git a/net/sctp/protocol.c b/net/sctp/protocol.c index 4e0eeb113ef5..6abc8b274270 100644 --- a/net/sctp/protocol.c +++ b/net/sctp/protocol.c @@ -440,7 +440,8 @@ static void sctp_v4_get_dst(struct sctp_transport *t, union sctp_addr *saddr, } if (saddr) { fl4->saddr = saddr->v4.sin_addr.s_addr; - fl4->fl4_sport = saddr->v4.sin_port; + if (!fl4->fl4_sport) + fl4->fl4_sport = saddr->v4.sin_port; } pr_debug("%s: dst:%pI4, src:%pI4 - ", __func__, &fl4->daddr, diff --git a/net/sctp/sm_make_chunk.c b/net/sctp/sm_make_chunk.c index f4ac6c592e13..d05c57664e36 100644 --- a/net/sctp/sm_make_chunk.c +++ b/net/sctp/sm_make_chunk.c @@ -495,7 +495,10 @@ struct sctp_chunk *sctp_make_init_ack(const struct sctp_association *asoc, * * [INIT ACK back to where the INIT came from.] */ - retval->transport = chunk->transport; + if (chunk->transport) + retval->transport = + sctp_assoc_lookup_paddr(asoc, + &chunk->transport->ipaddr); retval->subh.init_hdr = sctp_addto_chunk(retval, sizeof(initack), &initack); @@ -642,8 +645,10 @@ struct sctp_chunk *sctp_make_cookie_ack(const struct sctp_association *asoc, * * [COOKIE ACK back to where the COOKIE ECHO came from.] */ - if (retval && chunk) - retval->transport = chunk->transport; + if (retval && chunk && chunk->transport) + retval->transport = + sctp_assoc_lookup_paddr(asoc, + &chunk->transport->ipaddr); return retval; } diff --git a/net/sctp/socket.c b/net/sctp/socket.c index f93c3cf9e567..6140471efd4b 100644 --- a/net/sctp/socket.c +++ b/net/sctp/socket.c @@ -102,9 +102,9 @@ static int sctp_send_asconf(struct sctp_association *asoc, struct sctp_chunk *chunk); static int sctp_do_bind(struct sock *, union sctp_addr *, int); static int sctp_autobind(struct sock *sk); -static void sctp_sock_migrate(struct sock *oldsk, struct sock *newsk, - struct sctp_association *assoc, - enum sctp_socket_type type); +static int sctp_sock_migrate(struct sock *oldsk, struct sock *newsk, + struct sctp_association *assoc, + enum sctp_socket_type type); static unsigned long sctp_memory_pressure; static atomic_long_t sctp_memory_allocated; @@ -248,7 +248,7 @@ struct sctp_association *sctp_id2assoc(struct sock *sk, sctp_assoc_t id) } /* Otherwise this is a UDP-style socket. */ - if (!id || (id == (sctp_assoc_t)-1)) + if (id <= SCTP_ALL_ASSOC) return NULL; spin_lock_bh(&sctp_assocs_id_lock); @@ -1866,6 +1866,7 @@ static int sctp_sendmsg_check_sflags(struct sctp_association *asoc, pr_debug("%s: aborting association:%p\n", __func__, asoc); sctp_primitive_ABORT(net, asoc, chunk); + iov_iter_revert(&msg->msg_iter, msg_len); return 0; } @@ -2027,7 +2028,7 @@ static int sctp_sendmsg(struct sock *sk, struct msghdr *msg, size_t msg_len) struct sctp_endpoint *ep = sctp_sk(sk)->ep; struct sctp_transport *transport = NULL; struct sctp_sndrcvinfo _sinfo, *sinfo; - struct sctp_association *asoc; + struct sctp_association *asoc, *tmp; struct sctp_cmsgs cmsgs; union sctp_addr *daddr; bool new = false; @@ -2053,7 +2054,7 @@ static int sctp_sendmsg(struct sock *sk, struct msghdr *msg, size_t msg_len) /* SCTP_SENDALL process */ if ((sflags & SCTP_SENDALL) && sctp_style(sk, UDP)) { - list_for_each_entry(asoc, &ep->asocs, asocs) { + list_for_each_entry_safe(asoc, tmp, &ep->asocs, asocs) { err = sctp_sendmsg_check_sflags(asoc, sflags, msg, msg_len); if (err == 0) @@ -2750,12 +2751,13 @@ static int sctp_setsockopt_peer_addr_params(struct sock *sk, return -EINVAL; } - /* Get association, if assoc_id != 0 and the socket is a one - * to many style socket, and an association was not found, then - * the id was invalid. + /* Get association, if assoc_id != SCTP_FUTURE_ASSOC and the + * socket is a one to many style socket, and an association + * was not found, then the id was invalid. */ asoc = sctp_id2assoc(sk, params.spp_assoc_id); - if (!asoc && params.spp_assoc_id && sctp_style(sk, UDP)) + if (!asoc && params.spp_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; /* Heartbeat demand can only be sent on a transport or @@ -2797,6 +2799,43 @@ static inline __u32 sctp_spp_sackdelay_disable(__u32 param_flags) return (param_flags & ~SPP_SACKDELAY) | SPP_SACKDELAY_DISABLE; } +static void sctp_apply_asoc_delayed_ack(struct sctp_sack_info *params, + struct sctp_association *asoc) +{ + struct sctp_transport *trans; + + if (params->sack_delay) { + asoc->sackdelay = msecs_to_jiffies(params->sack_delay); + asoc->param_flags = + sctp_spp_sackdelay_enable(asoc->param_flags); + } + if (params->sack_freq == 1) { + asoc->param_flags = + sctp_spp_sackdelay_disable(asoc->param_flags); + } else if (params->sack_freq > 1) { + asoc->sackfreq = params->sack_freq; + asoc->param_flags = + sctp_spp_sackdelay_enable(asoc->param_flags); + } + + list_for_each_entry(trans, &asoc->peer.transport_addr_list, + transports) { + if (params->sack_delay) { + trans->sackdelay = msecs_to_jiffies(params->sack_delay); + trans->param_flags = + sctp_spp_sackdelay_enable(trans->param_flags); + } + if (params->sack_freq == 1) { + trans->param_flags = + sctp_spp_sackdelay_disable(trans->param_flags); + } else if (params->sack_freq > 1) { + trans->sackfreq = params->sack_freq; + trans->param_flags = + sctp_spp_sackdelay_enable(trans->param_flags); + } + } +} + /* * 7.1.23. Get or set delayed ack timer (SCTP_DELAYED_SACK) * @@ -2836,10 +2875,9 @@ static inline __u32 sctp_spp_sackdelay_disable(__u32 param_flags) static int sctp_setsockopt_delayed_ack(struct sock *sk, char __user *optval, unsigned int optlen) { - struct sctp_sack_info params; - struct sctp_transport *trans = NULL; - struct sctp_association *asoc = NULL; - struct sctp_sock *sp = sctp_sk(sk); + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + struct sctp_sack_info params; if (optlen == sizeof(struct sctp_sack_info)) { if (copy_from_user(¶ms, optval, optlen)) @@ -2867,67 +2905,42 @@ static int sctp_setsockopt_delayed_ack(struct sock *sk, if (params.sack_delay > 500) return -EINVAL; - /* Get association, if sack_assoc_id != 0 and the socket is a one - * to many style socket, and an association was not found, then - * the id was invalid. + /* Get association, if sack_assoc_id != SCTP_FUTURE_ASSOC and the + * socket is a one to many style socket, and an association + * was not found, then the id was invalid. */ asoc = sctp_id2assoc(sk, params.sack_assoc_id); - if (!asoc && params.sack_assoc_id && sctp_style(sk, UDP)) + if (!asoc && params.sack_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; - if (params.sack_delay) { - if (asoc) { - asoc->sackdelay = - msecs_to_jiffies(params.sack_delay); - asoc->param_flags = - sctp_spp_sackdelay_enable(asoc->param_flags); - } else { + if (asoc) { + sctp_apply_asoc_delayed_ack(¶ms, asoc); + + return 0; + } + + if (params.sack_assoc_id == SCTP_FUTURE_ASSOC || + params.sack_assoc_id == SCTP_ALL_ASSOC) { + if (params.sack_delay) { sp->sackdelay = params.sack_delay; sp->param_flags = sctp_spp_sackdelay_enable(sp->param_flags); } - } - - if (params.sack_freq == 1) { - if (asoc) { - asoc->param_flags = - sctp_spp_sackdelay_disable(asoc->param_flags); - } else { + if (params.sack_freq == 1) { sp->param_flags = sctp_spp_sackdelay_disable(sp->param_flags); - } - } else if (params.sack_freq > 1) { - if (asoc) { - asoc->sackfreq = params.sack_freq; - asoc->param_flags = - sctp_spp_sackdelay_enable(asoc->param_flags); - } else { + } else if (params.sack_freq > 1) { sp->sackfreq = params.sack_freq; sp->param_flags = sctp_spp_sackdelay_enable(sp->param_flags); } } - /* If change is for association, also apply to each transport. */ - if (asoc) { - list_for_each_entry(trans, &asoc->peer.transport_addr_list, - transports) { - if (params.sack_delay) { - trans->sackdelay = - msecs_to_jiffies(params.sack_delay); - trans->param_flags = - sctp_spp_sackdelay_enable(trans->param_flags); - } - if (params.sack_freq == 1) { - trans->param_flags = - sctp_spp_sackdelay_disable(trans->param_flags); - } else if (params.sack_freq > 1) { - trans->sackfreq = params.sack_freq; - trans->param_flags = - sctp_spp_sackdelay_enable(trans->param_flags); - } - } - } + if (params.sack_assoc_id == SCTP_CURRENT_ASSOC || + params.sack_assoc_id == SCTP_ALL_ASSOC) + list_for_each_entry(asoc, &sp->ep->asocs, asocs) + sctp_apply_asoc_delayed_ack(¶ms, asoc); return 0; } @@ -2997,15 +3010,22 @@ static int sctp_setsockopt_default_send_param(struct sock *sk, return -EINVAL; asoc = sctp_id2assoc(sk, info.sinfo_assoc_id); - if (!asoc && info.sinfo_assoc_id && sctp_style(sk, UDP)) + if (!asoc && info.sinfo_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; + if (asoc) { asoc->default_stream = info.sinfo_stream; asoc->default_flags = info.sinfo_flags; asoc->default_ppid = info.sinfo_ppid; asoc->default_context = info.sinfo_context; asoc->default_timetolive = info.sinfo_timetolive; - } else { + + return 0; + } + + if (info.sinfo_assoc_id == SCTP_FUTURE_ASSOC || + info.sinfo_assoc_id == SCTP_ALL_ASSOC) { sp->default_stream = info.sinfo_stream; sp->default_flags = info.sinfo_flags; sp->default_ppid = info.sinfo_ppid; @@ -3013,6 +3033,17 @@ static int sctp_setsockopt_default_send_param(struct sock *sk, sp->default_timetolive = info.sinfo_timetolive; } + if (info.sinfo_assoc_id == SCTP_CURRENT_ASSOC || + info.sinfo_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &sp->ep->asocs, asocs) { + asoc->default_stream = info.sinfo_stream; + asoc->default_flags = info.sinfo_flags; + asoc->default_ppid = info.sinfo_ppid; + asoc->default_context = info.sinfo_context; + asoc->default_timetolive = info.sinfo_timetolive; + } + } + return 0; } @@ -3037,20 +3068,37 @@ static int sctp_setsockopt_default_sndinfo(struct sock *sk, return -EINVAL; asoc = sctp_id2assoc(sk, info.snd_assoc_id); - if (!asoc && info.snd_assoc_id && sctp_style(sk, UDP)) + if (!asoc && info.snd_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; + if (asoc) { asoc->default_stream = info.snd_sid; asoc->default_flags = info.snd_flags; asoc->default_ppid = info.snd_ppid; asoc->default_context = info.snd_context; - } else { + + return 0; + } + + if (info.snd_assoc_id == SCTP_FUTURE_ASSOC || + info.snd_assoc_id == SCTP_ALL_ASSOC) { sp->default_stream = info.snd_sid; sp->default_flags = info.snd_flags; sp->default_ppid = info.snd_ppid; sp->default_context = info.snd_context; } + if (info.snd_assoc_id == SCTP_CURRENT_ASSOC || + info.snd_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &sp->ep->asocs, asocs) { + asoc->default_stream = info.snd_sid; + asoc->default_flags = info.snd_flags; + asoc->default_ppid = info.snd_ppid; + asoc->default_context = info.snd_context; + } + } + return 0; } @@ -3144,7 +3192,8 @@ static int sctp_setsockopt_rtoinfo(struct sock *sk, char __user *optval, unsigne asoc = sctp_id2assoc(sk, rtoinfo.srto_assoc_id); /* Set the values to the specific association */ - if (!asoc && rtoinfo.srto_assoc_id && sctp_style(sk, UDP)) + if (!asoc && rtoinfo.srto_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; rto_max = rtoinfo.srto_max; @@ -3206,7 +3255,8 @@ static int sctp_setsockopt_associnfo(struct sock *sk, char __user *optval, unsig asoc = sctp_id2assoc(sk, assocparams.sasoc_assoc_id); - if (!asoc && assocparams.sasoc_assoc_id && sctp_style(sk, UDP)) + if (!asoc && assocparams.sasoc_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; /* Set the values to the specific association */ @@ -3319,7 +3369,7 @@ static int sctp_setsockopt_maxseg(struct sock *sk, char __user *optval, unsigned current->comm, task_pid_nr(current)); if (copy_from_user(&val, optval, optlen)) return -EFAULT; - params.assoc_id = 0; + params.assoc_id = SCTP_FUTURE_ASSOC; } else if (optlen == sizeof(struct sctp_assoc_value)) { if (copy_from_user(¶ms, optval, optlen)) return -EFAULT; @@ -3329,6 +3379,9 @@ static int sctp_setsockopt_maxseg(struct sock *sk, char __user *optval, unsigned } asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; if (val) { int min_len, max_len; @@ -3346,8 +3399,6 @@ static int sctp_setsockopt_maxseg(struct sock *sk, char __user *optval, unsigned asoc->user_frag = val; sctp_assoc_update_frag_point(asoc); } else { - if (params.assoc_id && sctp_style(sk, UDP)) - return -EINVAL; sp->user_frag = val; } @@ -3460,8 +3511,8 @@ static int sctp_setsockopt_adaptation_layer(struct sock *sk, char __user *optval static int sctp_setsockopt_context(struct sock *sk, char __user *optval, unsigned int optlen) { + struct sctp_sock *sp = sctp_sk(sk); struct sctp_assoc_value params; - struct sctp_sock *sp; struct sctp_association *asoc; if (optlen != sizeof(struct sctp_assoc_value)) @@ -3469,17 +3520,26 @@ static int sctp_setsockopt_context(struct sock *sk, char __user *optval, if (copy_from_user(¶ms, optval, optlen)) return -EFAULT; - sp = sctp_sk(sk); + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; - if (params.assoc_id != 0) { - asoc = sctp_id2assoc(sk, params.assoc_id); - if (!asoc) - return -EINVAL; + if (asoc) { asoc->default_rcv_context = params.assoc_value; - } else { - sp->default_rcv_context = params.assoc_value; + + return 0; } + if (params.assoc_id == SCTP_FUTURE_ASSOC || + params.assoc_id == SCTP_ALL_ASSOC) + sp->default_rcv_context = params.assoc_value; + + if (params.assoc_id == SCTP_CURRENT_ASSOC || + params.assoc_id == SCTP_ALL_ASSOC) + list_for_each_entry(asoc, &sp->ep->asocs, asocs) + asoc->default_rcv_context = params.assoc_value; + return 0; } @@ -3580,11 +3640,9 @@ static int sctp_setsockopt_maxburst(struct sock *sk, char __user *optval, unsigned int optlen) { + struct sctp_sock *sp = sctp_sk(sk); struct sctp_assoc_value params; - struct sctp_sock *sp; struct sctp_association *asoc; - int val; - int assoc_id = 0; if (optlen == sizeof(int)) { pr_warn_ratelimited(DEPRECATED @@ -3592,25 +3650,34 @@ static int sctp_setsockopt_maxburst(struct sock *sk, "Use of int in max_burst socket option deprecated.\n" "Use struct sctp_assoc_value instead\n", current->comm, task_pid_nr(current)); - if (copy_from_user(&val, optval, optlen)) + if (copy_from_user(¶ms.assoc_value, optval, optlen)) return -EFAULT; + params.assoc_id = SCTP_FUTURE_ASSOC; } else if (optlen == sizeof(struct sctp_assoc_value)) { if (copy_from_user(¶ms, optval, optlen)) return -EFAULT; - val = params.assoc_value; - assoc_id = params.assoc_id; } else return -EINVAL; - sp = sctp_sk(sk); + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; - if (assoc_id != 0) { - asoc = sctp_id2assoc(sk, assoc_id); - if (!asoc) - return -EINVAL; - asoc->max_burst = val; - } else - sp->max_burst = val; + if (asoc) { + asoc->max_burst = params.assoc_value; + + return 0; + } + + if (params.assoc_id == SCTP_FUTURE_ASSOC || + params.assoc_id == SCTP_ALL_ASSOC) + sp->max_burst = params.assoc_value; + + if (params.assoc_id == SCTP_CURRENT_ASSOC || + params.assoc_id == SCTP_ALL_ASSOC) + list_for_each_entry(asoc, &sp->ep->asocs, asocs) + asoc->max_burst = params.assoc_value; return 0; } @@ -3702,7 +3769,7 @@ static int sctp_setsockopt_auth_key(struct sock *sk, struct sctp_endpoint *ep = sctp_sk(sk)->ep; struct sctp_authkey *authkey; struct sctp_association *asoc; - int ret; + int ret = -EINVAL; if (!ep->auth_enable) return -EACCES; @@ -3712,25 +3779,44 @@ static int sctp_setsockopt_auth_key(struct sock *sk, /* authkey->sca_keylength is u16, so optlen can't be bigger than * this. */ - optlen = min_t(unsigned int, optlen, USHRT_MAX + - sizeof(struct sctp_authkey)); + optlen = min_t(unsigned int, optlen, USHRT_MAX + sizeof(*authkey)); authkey = memdup_user(optval, optlen); if (IS_ERR(authkey)) return PTR_ERR(authkey); - if (authkey->sca_keylength > optlen - sizeof(struct sctp_authkey)) { - ret = -EINVAL; + if (authkey->sca_keylength > optlen - sizeof(*authkey)) goto out; - } asoc = sctp_id2assoc(sk, authkey->sca_assoc_id); - if (!asoc && authkey->sca_assoc_id && sctp_style(sk, UDP)) { - ret = -EINVAL; + if (!asoc && authkey->sca_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + goto out; + + if (asoc) { + ret = sctp_auth_set_key(ep, asoc, authkey); goto out; } - ret = sctp_auth_set_key(ep, asoc, authkey); + if (authkey->sca_assoc_id == SCTP_FUTURE_ASSOC || + authkey->sca_assoc_id == SCTP_ALL_ASSOC) { + ret = sctp_auth_set_key(ep, asoc, authkey); + if (ret) + goto out; + } + + ret = 0; + + if (authkey->sca_assoc_id == SCTP_CURRENT_ASSOC || + authkey->sca_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &ep->asocs, asocs) { + int res = sctp_auth_set_key(ep, asoc, authkey); + + if (res && !ret) + ret = res; + } + } + out: kzfree(authkey); return ret; @@ -3747,8 +3833,9 @@ static int sctp_setsockopt_active_key(struct sock *sk, unsigned int optlen) { struct sctp_endpoint *ep = sctp_sk(sk)->ep; - struct sctp_authkeyid val; struct sctp_association *asoc; + struct sctp_authkeyid val; + int ret = 0; if (!ep->auth_enable) return -EACCES; @@ -3759,10 +3846,32 @@ static int sctp_setsockopt_active_key(struct sock *sk, return -EFAULT; asoc = sctp_id2assoc(sk, val.scact_assoc_id); - if (!asoc && val.scact_assoc_id && sctp_style(sk, UDP)) + if (!asoc && val.scact_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; - return sctp_auth_set_active_key(ep, asoc, val.scact_keynumber); + if (asoc) + return sctp_auth_set_active_key(ep, asoc, val.scact_keynumber); + + if (val.scact_assoc_id == SCTP_FUTURE_ASSOC || + val.scact_assoc_id == SCTP_ALL_ASSOC) { + ret = sctp_auth_set_active_key(ep, asoc, val.scact_keynumber); + if (ret) + return ret; + } + + if (val.scact_assoc_id == SCTP_CURRENT_ASSOC || + val.scact_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &ep->asocs, asocs) { + int res = sctp_auth_set_active_key(ep, asoc, + val.scact_keynumber); + + if (res && !ret) + ret = res; + } + } + + return ret; } /* @@ -3775,8 +3884,9 @@ static int sctp_setsockopt_del_key(struct sock *sk, unsigned int optlen) { struct sctp_endpoint *ep = sctp_sk(sk)->ep; - struct sctp_authkeyid val; struct sctp_association *asoc; + struct sctp_authkeyid val; + int ret = 0; if (!ep->auth_enable) return -EACCES; @@ -3787,11 +3897,32 @@ static int sctp_setsockopt_del_key(struct sock *sk, return -EFAULT; asoc = sctp_id2assoc(sk, val.scact_assoc_id); - if (!asoc && val.scact_assoc_id && sctp_style(sk, UDP)) + if (!asoc && val.scact_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; - return sctp_auth_del_key_id(ep, asoc, val.scact_keynumber); + if (asoc) + return sctp_auth_del_key_id(ep, asoc, val.scact_keynumber); + + if (val.scact_assoc_id == SCTP_FUTURE_ASSOC || + val.scact_assoc_id == SCTP_ALL_ASSOC) { + ret = sctp_auth_del_key_id(ep, asoc, val.scact_keynumber); + if (ret) + return ret; + } + + if (val.scact_assoc_id == SCTP_CURRENT_ASSOC || + val.scact_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &ep->asocs, asocs) { + int res = sctp_auth_del_key_id(ep, asoc, + val.scact_keynumber); + + if (res && !ret) + ret = res; + } + } + return ret; } /* @@ -3803,8 +3934,9 @@ static int sctp_setsockopt_deactivate_key(struct sock *sk, char __user *optval, unsigned int optlen) { struct sctp_endpoint *ep = sctp_sk(sk)->ep; - struct sctp_authkeyid val; struct sctp_association *asoc; + struct sctp_authkeyid val; + int ret = 0; if (!ep->auth_enable) return -EACCES; @@ -3815,10 +3947,32 @@ static int sctp_setsockopt_deactivate_key(struct sock *sk, char __user *optval, return -EFAULT; asoc = sctp_id2assoc(sk, val.scact_assoc_id); - if (!asoc && val.scact_assoc_id && sctp_style(sk, UDP)) + if (!asoc && val.scact_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; - return sctp_auth_deact_key_id(ep, asoc, val.scact_keynumber); + if (asoc) + return sctp_auth_deact_key_id(ep, asoc, val.scact_keynumber); + + if (val.scact_assoc_id == SCTP_FUTURE_ASSOC || + val.scact_assoc_id == SCTP_ALL_ASSOC) { + ret = sctp_auth_deact_key_id(ep, asoc, val.scact_keynumber); + if (ret) + return ret; + } + + if (val.scact_assoc_id == SCTP_CURRENT_ASSOC || + val.scact_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &ep->asocs, asocs) { + int res = sctp_auth_deact_key_id(ep, asoc, + val.scact_keynumber); + + if (res && !ret) + ret = res; + } + } + + return ret; } /* @@ -3884,11 +4038,25 @@ static int sctp_setsockopt_paddr_thresholds(struct sock *sk, sizeof(struct sctp_paddrthlds))) return -EFAULT; - - if (sctp_is_any(sk, (const union sctp_addr *)&val.spt_address)) { - asoc = sctp_id2assoc(sk, val.spt_assoc_id); - if (!asoc) + if (!sctp_is_any(sk, (const union sctp_addr *)&val.spt_address)) { + trans = sctp_addr_id2transport(sk, &val.spt_address, + val.spt_assoc_id); + if (!trans) return -ENOENT; + + if (val.spt_pathmaxrxt) + trans->pathmaxrxt = val.spt_pathmaxrxt; + trans->pf_retrans = val.spt_pathpfthld; + + return 0; + } + + asoc = sctp_id2assoc(sk, val.spt_assoc_id); + if (!asoc && val.spt_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { list_for_each_entry(trans, &asoc->peer.transport_addr_list, transports) { if (val.spt_pathmaxrxt) @@ -3900,14 +4068,11 @@ static int sctp_setsockopt_paddr_thresholds(struct sock *sk, asoc->pathmaxrxt = val.spt_pathmaxrxt; asoc->pf_retrans = val.spt_pathpfthld; } else { - trans = sctp_addr_id2transport(sk, &val.spt_address, - val.spt_assoc_id); - if (!trans) - return -ENOENT; + struct sctp_sock *sp = sctp_sk(sk); if (val.spt_pathmaxrxt) - trans->pathmaxrxt = val.spt_pathmaxrxt; - trans->pf_retrans = val.spt_pathpfthld; + sp->pathmaxrxt = val.spt_pathmaxrxt; + sp->pf_retrans = val.spt_pathpfthld; } return 0; @@ -3950,6 +4115,7 @@ static int sctp_setsockopt_pr_supported(struct sock *sk, unsigned int optlen) { struct sctp_assoc_value params; + struct sctp_association *asoc; if (optlen != sizeof(params)) return -EINVAL; @@ -3957,6 +4123,11 @@ static int sctp_setsockopt_pr_supported(struct sock *sk, if (copy_from_user(¶ms, optval, optlen)) return -EFAULT; + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + sctp_sk(sk)->ep->prsctp_enable = !!params.assoc_value; return 0; @@ -3966,6 +4137,7 @@ static int sctp_setsockopt_default_prinfo(struct sock *sk, char __user *optval, unsigned int optlen) { + struct sctp_sock *sp = sctp_sk(sk); struct sctp_default_prinfo info; struct sctp_association *asoc; int retval = -EINVAL; @@ -3985,19 +4157,31 @@ static int sctp_setsockopt_default_prinfo(struct sock *sk, info.pr_value = 0; asoc = sctp_id2assoc(sk, info.pr_assoc_id); + if (!asoc && info.pr_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + goto out; + + retval = 0; + if (asoc) { SCTP_PR_SET_POLICY(asoc->default_flags, info.pr_policy); asoc->default_timetolive = info.pr_value; - } else if (!info.pr_assoc_id) { - struct sctp_sock *sp = sctp_sk(sk); + goto out; + } + if (info.pr_assoc_id == SCTP_FUTURE_ASSOC || + info.pr_assoc_id == SCTP_ALL_ASSOC) { SCTP_PR_SET_POLICY(sp->default_flags, info.pr_policy); sp->default_timetolive = info.pr_value; - } else { - goto out; } - retval = 0; + if (info.pr_assoc_id == SCTP_CURRENT_ASSOC || + info.pr_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &sp->ep->asocs, asocs) { + SCTP_PR_SET_POLICY(asoc->default_flags, info.pr_policy); + asoc->default_timetolive = info.pr_value; + } + } out: return retval; @@ -4020,15 +4204,14 @@ static int sctp_setsockopt_reconfig_supported(struct sock *sk, } asoc = sctp_id2assoc(sk, params.assoc_id); - if (asoc) { - asoc->reconf_enable = !!params.assoc_value; - } else if (!params.assoc_id) { - struct sctp_sock *sp = sctp_sk(sk); - - sp->ep->reconf_enable = !!params.assoc_value; - } else { + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) goto out; - } + + if (asoc) + asoc->reconf_enable = !!params.assoc_value; + else + sctp_sk(sk)->ep->reconf_enable = !!params.assoc_value; retval = 0; @@ -4040,6 +4223,7 @@ static int sctp_setsockopt_enable_strreset(struct sock *sk, char __user *optval, unsigned int optlen) { + struct sctp_endpoint *ep = sctp_sk(sk)->ep; struct sctp_assoc_value params; struct sctp_association *asoc; int retval = -EINVAL; @@ -4056,17 +4240,25 @@ static int sctp_setsockopt_enable_strreset(struct sock *sk, goto out; asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + goto out; + + retval = 0; + if (asoc) { asoc->strreset_enable = params.assoc_value; - } else if (!params.assoc_id) { - struct sctp_sock *sp = sctp_sk(sk); - - sp->ep->strreset_enable = params.assoc_value; - } else { goto out; } - retval = 0; + if (params.assoc_id == SCTP_FUTURE_ASSOC || + params.assoc_id == SCTP_ALL_ASSOC) + ep->strreset_enable = params.assoc_value; + + if (params.assoc_id == SCTP_CURRENT_ASSOC || + params.assoc_id == SCTP_ALL_ASSOC) + list_for_each_entry(asoc, &ep->asocs, asocs) + asoc->strreset_enable = params.assoc_value; out: return retval; @@ -4161,29 +4353,44 @@ static int sctp_setsockopt_scheduler(struct sock *sk, char __user *optval, unsigned int optlen) { + struct sctp_sock *sp = sctp_sk(sk); struct sctp_association *asoc; struct sctp_assoc_value params; - int retval = -EINVAL; + int retval = 0; if (optlen < sizeof(params)) - goto out; + return -EINVAL; optlen = sizeof(params); - if (copy_from_user(¶ms, optval, optlen)) { - retval = -EFAULT; - goto out; - } + if (copy_from_user(¶ms, optval, optlen)) + return -EFAULT; if (params.assoc_value > SCTP_SS_MAX) - goto out; + return -EINVAL; asoc = sctp_id2assoc(sk, params.assoc_id); - if (!asoc) - goto out; + if (!asoc && params.assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; - retval = sctp_sched_set_sched(asoc, params.assoc_value); + if (asoc) + return sctp_sched_set_sched(asoc, params.assoc_value); + + if (params.assoc_id == SCTP_FUTURE_ASSOC || + params.assoc_id == SCTP_ALL_ASSOC) + sp->default_ss = params.assoc_value; + + if (params.assoc_id == SCTP_CURRENT_ASSOC || + params.assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &sp->ep->asocs, asocs) { + int ret = sctp_sched_set_sched(asoc, + params.assoc_value); + + if (ret && !retval) + retval = ret; + } + } -out: return retval; } @@ -4191,8 +4398,8 @@ static int sctp_setsockopt_scheduler_value(struct sock *sk, char __user *optval, unsigned int optlen) { - struct sctp_association *asoc; struct sctp_stream_value params; + struct sctp_association *asoc; int retval = -EINVAL; if (optlen < sizeof(params)) @@ -4205,11 +4412,24 @@ static int sctp_setsockopt_scheduler_value(struct sock *sk, } asoc = sctp_id2assoc(sk, params.assoc_id); - if (!asoc) + if (!asoc && params.assoc_id != SCTP_CURRENT_ASSOC && + sctp_style(sk, UDP)) + goto out; + + if (asoc) { + retval = sctp_sched_set_value(asoc, params.stream_id, + params.stream_value, GFP_KERNEL); goto out; + } - retval = sctp_sched_set_value(asoc, params.stream_id, - params.stream_value, GFP_KERNEL); + retval = 0; + + list_for_each_entry(asoc, &sctp_sk(sk)->ep->asocs, asocs) { + int ret = sctp_sched_set_value(asoc, params.stream_id, + params.stream_value, GFP_KERNEL); + if (ret && !retval) /* try to return the 1st error. */ + retval = ret; + } out: return retval; @@ -4220,8 +4440,8 @@ static int sctp_setsockopt_interleaving_supported(struct sock *sk, unsigned int optlen) { struct sctp_sock *sp = sctp_sk(sk); - struct net *net = sock_net(sk); struct sctp_assoc_value params; + struct sctp_association *asoc; int retval = -EINVAL; if (optlen < sizeof(params)) @@ -4233,10 +4453,12 @@ static int sctp_setsockopt_interleaving_supported(struct sock *sk, goto out; } - if (params.assoc_id) + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) goto out; - if (!net->sctp.intl_enable || !sp->frag_interleave) { + if (!sock_net(sk)->sctp.intl_enable || !sp->frag_interleave) { retval = -EPERM; goto out; } @@ -4271,54 +4493,69 @@ static int sctp_setsockopt_reuse_port(struct sock *sk, char __user *optval, return 0; } +static int sctp_assoc_ulpevent_type_set(struct sctp_event *param, + struct sctp_association *asoc) +{ + struct sctp_ulpevent *event; + + sctp_ulpevent_type_set(&asoc->subscribe, param->se_type, param->se_on); + + if (param->se_type == SCTP_SENDER_DRY_EVENT && param->se_on) { + if (sctp_outq_is_empty(&asoc->outqueue)) { + event = sctp_ulpevent_make_sender_dry_event(asoc, + GFP_USER | __GFP_NOWARN); + if (!event) + return -ENOMEM; + + asoc->stream.si->enqueue_event(&asoc->ulpq, event); + } + } + + return 0; +} + static int sctp_setsockopt_event(struct sock *sk, char __user *optval, unsigned int optlen) { + struct sctp_sock *sp = sctp_sk(sk); struct sctp_association *asoc; - struct sctp_ulpevent *event; struct sctp_event param; int retval = 0; - if (optlen < sizeof(param)) { - retval = -EINVAL; - goto out; - } + if (optlen < sizeof(param)) + return -EINVAL; optlen = sizeof(param); - if (copy_from_user(¶m, optval, optlen)) { - retval = -EFAULT; - goto out; - } + if (copy_from_user(¶m, optval, optlen)) + return -EFAULT; if (param.se_type < SCTP_SN_TYPE_BASE || - param.se_type > SCTP_SN_TYPE_MAX) { - retval = -EINVAL; - goto out; - } + param.se_type > SCTP_SN_TYPE_MAX) + return -EINVAL; asoc = sctp_id2assoc(sk, param.se_assoc_id); - if (!asoc) { - sctp_ulpevent_type_set(&sctp_sk(sk)->subscribe, - param.se_type, param.se_on); - goto out; - } + if (!asoc && param.se_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; - sctp_ulpevent_type_set(&asoc->subscribe, param.se_type, param.se_on); + if (asoc) + return sctp_assoc_ulpevent_type_set(¶m, asoc); - if (param.se_type == SCTP_SENDER_DRY_EVENT && param.se_on) { - if (sctp_outq_is_empty(&asoc->outqueue)) { - event = sctp_ulpevent_make_sender_dry_event(asoc, - GFP_USER | __GFP_NOWARN); - if (!event) { - retval = -ENOMEM; - goto out; - } + if (param.se_assoc_id == SCTP_FUTURE_ASSOC || + param.se_assoc_id == SCTP_ALL_ASSOC) + sctp_ulpevent_type_set(&sp->subscribe, + param.se_type, param.se_on); - asoc->stream.si->enqueue_event(&asoc->ulpq, event); + if (param.se_assoc_id == SCTP_CURRENT_ASSOC || + param.se_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &sp->ep->asocs, asocs) { + int ret = sctp_assoc_ulpevent_type_set(¶m, asoc); + + if (ret && !retval) + retval = ret; } } -out: return retval; } @@ -4654,7 +4891,11 @@ static struct sock *sctp_accept(struct sock *sk, int flags, int *err, bool kern) /* Populate the fields of the newsk from the oldsk and migrate the * asoc to the newsk. */ - sctp_sock_migrate(sk, newsk, asoc, SCTP_SOCKET_TCP); + error = sctp_sock_migrate(sk, newsk, asoc, SCTP_SOCKET_TCP); + if (error) { + sk_common_release(newsk); + newsk = NULL; + } out: release_sock(sk); @@ -4777,12 +5018,14 @@ static int sctp_init_sock(struct sock *sk) */ sp->hbinterval = net->sctp.hb_interval; sp->pathmaxrxt = net->sctp.max_retrans_path; + sp->pf_retrans = net->sctp.pf_retrans; sp->pathmtu = 0; /* allow default discovery */ sp->sackdelay = net->sctp.sack_timeout; sp->sackfreq = 2; sp->param_flags = SPP_HB_ENABLE | SPP_PMTUD_ENABLE | SPP_SACKDELAY_ENABLE; + sp->default_ss = SCTP_SS_DEFAULT; /* If enabled no SCTP message fragmentation will be performed. * Configure through SCTP_DISABLE_FRAGMENTS socket option. @@ -5400,7 +5643,12 @@ int sctp_do_peeloff(struct sock *sk, sctp_assoc_t id, struct socket **sockp) /* Populate the fields of the newsk from the oldsk and migrate the * asoc to the newsk. */ - sctp_sock_migrate(sk, sock->sk, asoc, SCTP_SOCKET_UDP_HIGH_BANDWIDTH); + err = sctp_sock_migrate(sk, sock->sk, asoc, + SCTP_SOCKET_UDP_HIGH_BANDWIDTH); + if (err) { + sock_release(sock); + sock = NULL; + } *sockp = sock; @@ -5676,12 +5924,13 @@ static int sctp_getsockopt_peer_addr_params(struct sock *sk, int len, } } - /* Get association, if assoc_id != 0 and the socket is a one - * to many style socket, and an association was not found, then - * the id was invalid. + /* Get association, if assoc_id != SCTP_FUTURE_ASSOC and the + * socket is a one to many style socket, and an association + * was not found, then the id was invalid. */ asoc = sctp_id2assoc(sk, params.spp_assoc_id); - if (!asoc && params.spp_assoc_id && sctp_style(sk, UDP)) { + if (!asoc && params.spp_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { pr_debug("%s: failed no association\n", __func__); return -EINVAL; } @@ -5810,19 +6059,19 @@ static int sctp_getsockopt_delayed_ack(struct sock *sk, int len, } else return -EINVAL; - /* Get association, if sack_assoc_id != 0 and the socket is a one - * to many style socket, and an association was not found, then - * the id was invalid. + /* Get association, if sack_assoc_id != SCTP_FUTURE_ASSOC and the + * socket is a one to many style socket, and an association + * was not found, then the id was invalid. */ asoc = sctp_id2assoc(sk, params.sack_assoc_id); - if (!asoc && params.sack_assoc_id && sctp_style(sk, UDP)) + if (!asoc && params.sack_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; if (asoc) { /* Fetch association values. */ if (asoc->param_flags & SPP_SACKDELAY_ENABLE) { - params.sack_delay = jiffies_to_msecs( - asoc->sackdelay); + params.sack_delay = jiffies_to_msecs(asoc->sackdelay); params.sack_freq = asoc->sackfreq; } else { @@ -6175,8 +6424,10 @@ static int sctp_getsockopt_default_send_param(struct sock *sk, return -EFAULT; asoc = sctp_id2assoc(sk, info.sinfo_assoc_id); - if (!asoc && info.sinfo_assoc_id && sctp_style(sk, UDP)) + if (!asoc && info.sinfo_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; + if (asoc) { info.sinfo_stream = asoc->default_stream; info.sinfo_flags = asoc->default_flags; @@ -6219,8 +6470,10 @@ static int sctp_getsockopt_default_sndinfo(struct sock *sk, int len, return -EFAULT; asoc = sctp_id2assoc(sk, info.snd_assoc_id); - if (!asoc && info.snd_assoc_id && sctp_style(sk, UDP)) + if (!asoc && info.snd_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; + if (asoc) { info.snd_sid = asoc->default_stream; info.snd_flags = asoc->default_flags; @@ -6296,7 +6549,8 @@ static int sctp_getsockopt_rtoinfo(struct sock *sk, int len, asoc = sctp_id2assoc(sk, rtoinfo.srto_assoc_id); - if (!asoc && rtoinfo.srto_assoc_id && sctp_style(sk, UDP)) + if (!asoc && rtoinfo.srto_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; /* Values corresponding to the specific association. */ @@ -6353,7 +6607,8 @@ static int sctp_getsockopt_associnfo(struct sock *sk, int len, asoc = sctp_id2assoc(sk, assocparams.sasoc_assoc_id); - if (!asoc && assocparams.sasoc_assoc_id && sctp_style(sk, UDP)) + if (!asoc && assocparams.sasoc_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; /* Values correspoinding to the specific association */ @@ -6428,7 +6683,6 @@ static int sctp_getsockopt_context(struct sock *sk, int len, char __user *optval, int __user *optlen) { struct sctp_assoc_value params; - struct sctp_sock *sp; struct sctp_association *asoc; if (len < sizeof(struct sctp_assoc_value)) @@ -6439,16 +6693,13 @@ static int sctp_getsockopt_context(struct sock *sk, int len, if (copy_from_user(¶ms, optval, len)) return -EFAULT; - sp = sctp_sk(sk); + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; - if (params.assoc_id != 0) { - asoc = sctp_id2assoc(sk, params.assoc_id); - if (!asoc) - return -EINVAL; - params.assoc_value = asoc->default_rcv_context; - } else { - params.assoc_value = sp->default_rcv_context; - } + params.assoc_value = asoc ? asoc->default_rcv_context + : sctp_sk(sk)->default_rcv_context; if (put_user(len, optlen)) return -EFAULT; @@ -6497,7 +6748,7 @@ static int sctp_getsockopt_maxseg(struct sock *sk, int len, "Use of int in maxseg socket option.\n" "Use struct sctp_assoc_value instead\n", current->comm, task_pid_nr(current)); - params.assoc_id = 0; + params.assoc_id = SCTP_FUTURE_ASSOC; } else if (len >= sizeof(struct sctp_assoc_value)) { len = sizeof(struct sctp_assoc_value); if (copy_from_user(¶ms, optval, len)) @@ -6506,7 +6757,8 @@ static int sctp_getsockopt_maxseg(struct sock *sk, int len, return -EINVAL; asoc = sctp_id2assoc(sk, params.assoc_id); - if (!asoc && params.assoc_id && sctp_style(sk, UDP)) + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; if (asoc) @@ -6583,7 +6835,6 @@ static int sctp_getsockopt_maxburst(struct sock *sk, int len, int __user *optlen) { struct sctp_assoc_value params; - struct sctp_sock *sp; struct sctp_association *asoc; if (len == sizeof(int)) { @@ -6592,7 +6843,7 @@ static int sctp_getsockopt_maxburst(struct sock *sk, int len, "Use of int in max_burst socket option.\n" "Use struct sctp_assoc_value instead\n", current->comm, task_pid_nr(current)); - params.assoc_id = 0; + params.assoc_id = SCTP_FUTURE_ASSOC; } else if (len >= sizeof(struct sctp_assoc_value)) { len = sizeof(struct sctp_assoc_value); if (copy_from_user(¶ms, optval, len)) @@ -6600,15 +6851,12 @@ static int sctp_getsockopt_maxburst(struct sock *sk, int len, } else return -EINVAL; - sp = sctp_sk(sk); + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; - if (params.assoc_id != 0) { - asoc = sctp_id2assoc(sk, params.assoc_id); - if (!asoc) - return -EINVAL; - params.assoc_value = asoc->max_burst; - } else - params.assoc_value = sp->max_burst; + params.assoc_value = asoc ? asoc->max_burst : sctp_sk(sk)->max_burst; if (len == sizeof(int)) { if (copy_to_user(optval, ¶ms.assoc_value, len)) @@ -6759,14 +7007,12 @@ static int sctp_getsockopt_local_auth_chunks(struct sock *sk, int len, to = p->gauth_chunks; asoc = sctp_id2assoc(sk, val.gauth_assoc_id); - if (!asoc && val.gauth_assoc_id && sctp_style(sk, UDP)) + if (!asoc && val.gauth_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) return -EINVAL; - if (asoc) - ch = (struct sctp_chunks_param *)asoc->c.auth_chunks; - else - ch = ep->auth_chunk_list; - + ch = asoc ? (struct sctp_chunks_param *)asoc->c.auth_chunks + : ep->auth_chunk_list; if (!ch) goto num; @@ -6911,14 +7157,7 @@ static int sctp_getsockopt_paddr_thresholds(struct sock *sk, if (copy_from_user(&val, (struct sctp_paddrthlds __user *)optval, len)) return -EFAULT; - if (sctp_is_any(sk, (const union sctp_addr *)&val.spt_address)) { - asoc = sctp_id2assoc(sk, val.spt_assoc_id); - if (!asoc) - return -ENOENT; - - val.spt_pathpfthld = asoc->pf_retrans; - val.spt_pathmaxrxt = asoc->pathmaxrxt; - } else { + if (!sctp_is_any(sk, (const union sctp_addr *)&val.spt_address)) { trans = sctp_addr_id2transport(sk, &val.spt_address, val.spt_assoc_id); if (!trans) @@ -6926,6 +7165,23 @@ static int sctp_getsockopt_paddr_thresholds(struct sock *sk, val.spt_pathmaxrxt = trans->pathmaxrxt; val.spt_pathpfthld = trans->pf_retrans; + + return 0; + } + + asoc = sctp_id2assoc(sk, val.spt_assoc_id); + if (!asoc && val.spt_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { + val.spt_pathpfthld = asoc->pf_retrans; + val.spt_pathmaxrxt = asoc->pathmaxrxt; + } else { + struct sctp_sock *sp = sctp_sk(sk); + + val.spt_pathpfthld = sp->pf_retrans; + val.spt_pathmaxrxt = sp->pathmaxrxt; } if (put_user(len, optlen) || copy_to_user(optval, &val, len)) @@ -7056,17 +7312,15 @@ static int sctp_getsockopt_pr_supported(struct sock *sk, int len, goto out; asoc = sctp_id2assoc(sk, params.assoc_id); - if (asoc) { - params.assoc_value = asoc->prsctp_enable; - } else if (!params.assoc_id) { - struct sctp_sock *sp = sctp_sk(sk); - - params.assoc_value = sp->ep->prsctp_enable; - } else { + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { retval = -EINVAL; goto out; } + params.assoc_value = asoc ? asoc->prsctp_enable + : sctp_sk(sk)->ep->prsctp_enable; + if (put_user(len, optlen)) goto out; @@ -7097,17 +7351,20 @@ static int sctp_getsockopt_default_prinfo(struct sock *sk, int len, goto out; asoc = sctp_id2assoc(sk, info.pr_assoc_id); + if (!asoc && info.pr_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + retval = -EINVAL; + goto out; + } + if (asoc) { info.pr_policy = SCTP_PR_POLICY(asoc->default_flags); info.pr_value = asoc->default_timetolive; - } else if (!info.pr_assoc_id) { + } else { struct sctp_sock *sp = sctp_sk(sk); info.pr_policy = SCTP_PR_POLICY(sp->default_flags); info.pr_value = sp->default_timetolive; - } else { - retval = -EINVAL; - goto out; } if (put_user(len, optlen)) @@ -7263,17 +7520,15 @@ static int sctp_getsockopt_reconfig_supported(struct sock *sk, int len, goto out; asoc = sctp_id2assoc(sk, params.assoc_id); - if (asoc) { - params.assoc_value = asoc->reconf_enable; - } else if (!params.assoc_id) { - struct sctp_sock *sp = sctp_sk(sk); - - params.assoc_value = sp->ep->reconf_enable; - } else { + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { retval = -EINVAL; goto out; } + params.assoc_value = asoc ? asoc->reconf_enable + : sctp_sk(sk)->ep->reconf_enable; + if (put_user(len, optlen)) goto out; @@ -7304,17 +7559,15 @@ static int sctp_getsockopt_enable_strreset(struct sock *sk, int len, goto out; asoc = sctp_id2assoc(sk, params.assoc_id); - if (asoc) { - params.assoc_value = asoc->strreset_enable; - } else if (!params.assoc_id) { - struct sctp_sock *sp = sctp_sk(sk); - - params.assoc_value = sp->ep->strreset_enable; - } else { + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { retval = -EINVAL; goto out; } + params.assoc_value = asoc ? asoc->strreset_enable + : sctp_sk(sk)->ep->strreset_enable; + if (put_user(len, optlen)) goto out; @@ -7345,12 +7598,14 @@ static int sctp_getsockopt_scheduler(struct sock *sk, int len, goto out; asoc = sctp_id2assoc(sk, params.assoc_id); - if (!asoc) { + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { retval = -EINVAL; goto out; } - params.assoc_value = sctp_sched_get_sched(asoc); + params.assoc_value = asoc ? sctp_sched_get_sched(asoc) + : sctp_sk(sk)->default_ss; if (put_user(len, optlen)) goto out; @@ -7424,17 +7679,15 @@ static int sctp_getsockopt_interleaving_supported(struct sock *sk, int len, goto out; asoc = sctp_id2assoc(sk, params.assoc_id); - if (asoc) { - params.assoc_value = asoc->intl_enable; - } else if (!params.assoc_id) { - struct sctp_sock *sp = sctp_sk(sk); - - params.assoc_value = sp->strm_interleave; - } else { + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { retval = -EINVAL; goto out; } + params.assoc_value = asoc ? asoc->intl_enable + : sctp_sk(sk)->strm_interleave; + if (put_user(len, optlen)) goto out; @@ -7486,6 +7739,10 @@ static int sctp_getsockopt_event(struct sock *sk, int len, char __user *optval, return -EINVAL; asoc = sctp_id2assoc(sk, param.se_assoc_id); + if (!asoc && param.se_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + subscribe = asoc ? asoc->subscribe : sctp_sk(sk)->subscribe; param.se_on = sctp_ulpevent_type_enabled(subscribe, param.se_type); @@ -8923,9 +9180,9 @@ static inline void sctp_copy_descendant(struct sock *sk_to, /* Populate the fields of the newsk from the oldsk and migrate the assoc * and its messages to the newsk. */ -static void sctp_sock_migrate(struct sock *oldsk, struct sock *newsk, - struct sctp_association *assoc, - enum sctp_socket_type type) +static int sctp_sock_migrate(struct sock *oldsk, struct sock *newsk, + struct sctp_association *assoc, + enum sctp_socket_type type) { struct sctp_sock *oldsp = sctp_sk(oldsk); struct sctp_sock *newsp = sctp_sk(newsk); @@ -8934,6 +9191,7 @@ static void sctp_sock_migrate(struct sock *oldsk, struct sock *newsk, struct sk_buff *skb, *tmp; struct sctp_ulpevent *event; struct sctp_bind_hashbucket *head; + int err; /* Migrate socket buffer sizes and all the socket level options to the * new socket. @@ -8962,8 +9220,20 @@ static void sctp_sock_migrate(struct sock *oldsk, struct sock *newsk, /* Copy the bind_addr list from the original endpoint to the new * endpoint so that we can handle restarts properly */ - sctp_bind_addr_dup(&newsp->ep->base.bind_addr, - &oldsp->ep->base.bind_addr, GFP_KERNEL); + err = sctp_bind_addr_dup(&newsp->ep->base.bind_addr, + &oldsp->ep->base.bind_addr, GFP_KERNEL); + if (err) + return err; + + /* New ep's auth_hmacs should be set if old ep's is set, in case + * that net->sctp.auth_enable has been changed to 0 by users and + * new ep's auth_hmacs couldn't be set in sctp_endpoint_init(). + */ + if (oldsp->ep->auth_hmacs) { + err = sctp_auth_init_hmacs(newsp->ep, GFP_KERNEL); + if (err) + return err; + } /* Move any messages in the old socket's receive queue that are for the * peeled off association to the new socket's receive queue. @@ -9048,6 +9318,8 @@ static void sctp_sock_migrate(struct sock *oldsk, struct sock *newsk, } release_sock(newsk); + + return 0; } diff --git a/net/sctp/stream.c b/net/sctp/stream.c index 3892e7630f3a..b6bb68adac6e 100644 --- a/net/sctp/stream.c +++ b/net/sctp/stream.c @@ -37,53 +37,6 @@ #include <net/sctp/sm.h> #include <net/sctp/stream_sched.h> -static struct flex_array *fa_alloc(size_t elem_size, size_t elem_count, - gfp_t gfp) -{ - struct flex_array *result; - int err; - - result = flex_array_alloc(elem_size, elem_count, gfp); - if (result) { - err = flex_array_prealloc(result, 0, elem_count, gfp); - if (err) { - flex_array_free(result); - result = NULL; - } - } - - return result; -} - -static void fa_free(struct flex_array *fa) -{ - if (fa) - flex_array_free(fa); -} - -static void fa_copy(struct flex_array *fa, struct flex_array *from, - size_t index, size_t count) -{ - void *elem; - - while (count--) { - elem = flex_array_get(from, index); - flex_array_put(fa, index, elem, 0); - index++; - } -} - -static void fa_zero(struct flex_array *fa, size_t index, size_t count) -{ - void *elem; - - while (count--) { - elem = flex_array_get(fa, index); - memset(elem, 0, fa->element_size); - index++; - } -} - /* Migrates chunks from stream queues to new stream queues if needed, * but not across associations. Also, removes those chunks to streams * higher than the new max. @@ -131,53 +84,41 @@ static void sctp_stream_outq_migrate(struct sctp_stream *stream, } } - for (i = outcnt; i < stream->outcnt; i++) + for (i = outcnt; i < stream->outcnt; i++) { kfree(SCTP_SO(stream, i)->ext); + SCTP_SO(stream, i)->ext = NULL; + } } static int sctp_stream_alloc_out(struct sctp_stream *stream, __u16 outcnt, gfp_t gfp) { - struct flex_array *out; - size_t elem_size = sizeof(struct sctp_stream_out); - - out = fa_alloc(elem_size, outcnt, gfp); - if (!out) - return -ENOMEM; - - if (stream->out) { - fa_copy(out, stream->out, 0, min(outcnt, stream->outcnt)); - fa_free(stream->out); - } + int ret; - if (outcnt > stream->outcnt) - fa_zero(out, stream->outcnt, (outcnt - stream->outcnt)); + if (outcnt <= stream->outcnt) + return 0; - stream->out = out; + ret = genradix_prealloc(&stream->out, outcnt, gfp); + if (ret) + return ret; + stream->outcnt = outcnt; return 0; } static int sctp_stream_alloc_in(struct sctp_stream *stream, __u16 incnt, gfp_t gfp) { - struct flex_array *in; - size_t elem_size = sizeof(struct sctp_stream_in); - - in = fa_alloc(elem_size, incnt, gfp); - if (!in) - return -ENOMEM; - - if (stream->in) { - fa_copy(in, stream->in, 0, min(incnt, stream->incnt)); - fa_free(stream->in); - } + int ret; - if (incnt > stream->incnt) - fa_zero(in, stream->incnt, (incnt - stream->incnt)); + if (incnt <= stream->incnt) + return 0; - stream->in = in; + ret = genradix_prealloc(&stream->in, incnt, gfp); + if (ret) + return ret; + stream->incnt = incnt; return 0; } @@ -204,12 +145,9 @@ int sctp_stream_init(struct sctp_stream *stream, __u16 outcnt, __u16 incnt, if (ret) goto out; - stream->outcnt = outcnt; for (i = 0; i < stream->outcnt; i++) SCTP_SO(stream, i)->state = SCTP_STREAM_OPEN; - sched->init(stream); - in: sctp_stream_interleave_init(stream); if (!incnt) @@ -218,14 +156,11 @@ in: ret = sctp_stream_alloc_in(stream, incnt, gfp); if (ret) { sched->free(stream); - fa_free(stream->out); - stream->out = NULL; + genradix_free(&stream->out); stream->outcnt = 0; goto out; } - stream->incnt = incnt; - out: return ret; } @@ -250,8 +185,8 @@ void sctp_stream_free(struct sctp_stream *stream) sched->free(stream); for (i = 0; i < stream->outcnt; i++) kfree(SCTP_SO(stream, i)->ext); - fa_free(stream->out); - fa_free(stream->in); + genradix_free(&stream->out); + genradix_free(&stream->in); } void sctp_stream_clear(struct sctp_stream *stream) @@ -282,8 +217,8 @@ void sctp_stream_update(struct sctp_stream *stream, struct sctp_stream *new) sched->sched_all(stream); - new->out = NULL; - new->in = NULL; + new->out.tree.root = NULL; + new->in.tree.root = NULL; new->outcnt = 0; new->incnt = 0; } @@ -535,8 +470,6 @@ int sctp_send_add_streams(struct sctp_association *asoc, goto out; } - stream->outcnt = outcnt; - asoc->strreset_outstanding = !!out + !!in; out: @@ -585,9 +518,9 @@ struct sctp_chunk *sctp_process_strreset_outreq( struct sctp_strreset_outreq *outreq = param.v; struct sctp_stream *stream = &asoc->stream; __u32 result = SCTP_STRRESET_DENIED; - __u16 i, nums, flags = 0; __be16 *str_p = NULL; __u32 request_seq; + __u16 i, nums; request_seq = ntohl(outreq->request_seq); @@ -615,6 +548,15 @@ struct sctp_chunk *sctp_process_strreset_outreq( if (!(asoc->strreset_enable & SCTP_ENABLE_RESET_STREAM_REQ)) goto out; + nums = (ntohs(param.p->length) - sizeof(*outreq)) / sizeof(__u16); + str_p = outreq->list_of_streams; + for (i = 0; i < nums; i++) { + if (ntohs(str_p[i]) >= stream->incnt) { + result = SCTP_STRRESET_ERR_WRONG_SSN; + goto out; + } + } + if (asoc->strreset_chunk) { if (!sctp_chunk_lookup_strreset_param( asoc, outreq->response_seq, @@ -637,32 +579,19 @@ struct sctp_chunk *sctp_process_strreset_outreq( sctp_chunk_put(asoc->strreset_chunk); asoc->strreset_chunk = NULL; } - - flags = SCTP_STREAM_RESET_INCOMING_SSN; } - nums = (ntohs(param.p->length) - sizeof(*outreq)) / sizeof(__u16); - if (nums) { - str_p = outreq->list_of_streams; - for (i = 0; i < nums; i++) { - if (ntohs(str_p[i]) >= stream->incnt) { - result = SCTP_STRRESET_ERR_WRONG_SSN; - goto out; - } - } - + if (nums) for (i = 0; i < nums; i++) SCTP_SI(stream, ntohs(str_p[i]))->mid = 0; - } else { + else for (i = 0; i < stream->incnt; i++) SCTP_SI(stream, i)->mid = 0; - } result = SCTP_STRRESET_PERFORMED; *evp = sctp_ulpevent_make_stream_reset_event(asoc, - flags | SCTP_STREAM_RESET_OUTGOING_SSN, nums, str_p, - GFP_ATOMIC); + SCTP_STREAM_RESET_INCOMING_SSN, nums, str_p, GFP_ATOMIC); out: sctp_update_strreset_result(asoc, result); @@ -738,9 +667,6 @@ struct sctp_chunk *sctp_process_strreset_inreq( result = SCTP_STRRESET_PERFORMED; - *evp = sctp_ulpevent_make_stream_reset_event(asoc, - SCTP_STREAM_RESET_INCOMING_SSN, nums, str_p, GFP_ATOMIC); - out: sctp_update_strreset_result(asoc, result); err: @@ -873,6 +799,14 @@ struct sctp_chunk *sctp_process_strreset_addstrm_out( if (!(asoc->strreset_enable & SCTP_ENABLE_CHANGE_ASSOC_REQ)) goto out; + in = ntohs(addstrm->number_of_streams); + incnt = stream->incnt + in; + if (!in || incnt > SCTP_MAX_STREAM) + goto out; + + if (sctp_stream_alloc_in(stream, incnt, GFP_ATOMIC)) + goto out; + if (asoc->strreset_chunk) { if (!sctp_chunk_lookup_strreset_param( asoc, 0, SCTP_PARAM_RESET_ADD_IN_STREAMS)) { @@ -896,14 +830,6 @@ struct sctp_chunk *sctp_process_strreset_addstrm_out( } } - in = ntohs(addstrm->number_of_streams); - incnt = stream->incnt + in; - if (!in || incnt > SCTP_MAX_STREAM) - goto out; - - if (sctp_stream_alloc_in(stream, incnt, GFP_ATOMIC)) - goto out; - stream->incnt = incnt; result = SCTP_STRRESET_PERFORMED; @@ -973,9 +899,6 @@ struct sctp_chunk *sctp_process_strreset_addstrm_in( result = SCTP_STRRESET_PERFORMED; - *evp = sctp_ulpevent_make_stream_change_event(asoc, - 0, 0, ntohs(addstrm->number_of_streams), GFP_ATOMIC); - out: sctp_update_strreset_result(asoc, result); err: @@ -1036,10 +959,10 @@ struct sctp_chunk *sctp_process_strreset_resp( sout->mid_uo = 0; } } - - flags = SCTP_STREAM_RESET_OUTGOING_SSN; } + flags |= SCTP_STREAM_RESET_OUTGOING_SSN; + for (i = 0; i < stream->outcnt; i++) SCTP_SO(stream, i)->state = SCTP_STREAM_OPEN; @@ -1058,6 +981,8 @@ struct sctp_chunk *sctp_process_strreset_resp( nums = (ntohs(inreq->param_hdr.length) - sizeof(*inreq)) / sizeof(__u16); + flags |= SCTP_STREAM_RESET_INCOMING_SSN; + *evp = sctp_ulpevent_make_stream_reset_event(asoc, flags, nums, str_p, GFP_ATOMIC); } else if (req->type == SCTP_PARAM_RESET_TSN_REQUEST) { diff --git a/net/sctp/stream_interleave.c b/net/sctp/stream_interleave.c index a6bf21579466..102c6fefe38c 100644 --- a/net/sctp/stream_interleave.c +++ b/net/sctp/stream_interleave.c @@ -101,7 +101,7 @@ static void sctp_chunk_assign_mid(struct sctp_chunk *chunk) static bool sctp_validate_data(struct sctp_chunk *chunk) { - const struct sctp_stream *stream; + struct sctp_stream *stream; __u16 sid, ssn; if (chunk->chunk_hdr->type != SCTP_CID_DATA) diff --git a/net/sctp/transport.c b/net/sctp/transport.c index 033696e6f74f..ad158d311ffa 100644 --- a/net/sctp/transport.c +++ b/net/sctp/transport.c @@ -207,7 +207,8 @@ void sctp_transport_reset_hb_timer(struct sctp_transport *transport) /* When a data chunk is sent, reset the heartbeat interval. */ expires = jiffies + sctp_transport_timeout(transport); - if (time_before(transport->hb_timer.expires, expires) && + if ((time_before(transport->hb_timer.expires, expires) || + !timer_pending(&transport->hb_timer)) && !mod_timer(&transport->hb_timer, expires + prandom_u32_max(transport->rto))) sctp_transport_hold(transport); diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c index c4e56602e0c6..77ef53596d18 100644 --- a/net/smc/af_smc.c +++ b/net/smc/af_smc.c @@ -30,6 +30,10 @@ #include <net/smc.h> #include <asm/ioctls.h> +#include <net/net_namespace.h> +#include <net/netns/generic.h> +#include "smc_netns.h" + #include "smc.h" #include "smc_clc.h" #include "smc_llc.h" @@ -42,8 +46,11 @@ #include "smc_rx.h" #include "smc_close.h" -static DEFINE_MUTEX(smc_create_lgr_pending); /* serialize link group - * creation +static DEFINE_MUTEX(smc_server_lgr_pending); /* serialize link group + * creation on server + */ +static DEFINE_MUTEX(smc_client_lgr_pending); /* serialize link group + * creation on client */ static void smc_tcp_listen_work(struct work_struct *); @@ -145,32 +152,33 @@ static int smc_release(struct socket *sock) rc = smc_close_active(smc); sock_set_flag(sk, SOCK_DEAD); sk->sk_shutdown |= SHUTDOWN_MASK; - } - - sk->sk_prot->unhash(sk); - - if (smc->clcsock) { - if (smc->use_fallback && sk->sk_state == SMC_LISTEN) { + } else { + if (sk->sk_state != SMC_LISTEN && sk->sk_state != SMC_INIT) + sock_put(sk); /* passive closing */ + if (sk->sk_state == SMC_LISTEN) { /* wake up clcsock accept */ rc = kernel_sock_shutdown(smc->clcsock, SHUT_RDWR); } - mutex_lock(&smc->clcsock_release_lock); - sock_release(smc->clcsock); - smc->clcsock = NULL; - mutex_unlock(&smc->clcsock_release_lock); - } - if (smc->use_fallback) { - if (sk->sk_state != SMC_LISTEN && sk->sk_state != SMC_INIT) - sock_put(sk); /* passive closing */ sk->sk_state = SMC_CLOSED; sk->sk_state_change(sk); } + sk->sk_prot->unhash(sk); + + if (sk->sk_state == SMC_CLOSED) { + if (smc->clcsock) { + mutex_lock(&smc->clcsock_release_lock); + sock_release(smc->clcsock); + smc->clcsock = NULL; + mutex_unlock(&smc->clcsock_release_lock); + } + if (!smc->use_fallback) + smc_conn_free(&smc->conn); + } + /* detach socket */ sock_orphan(sk); sock->sk = NULL; - if (!smc->use_fallback && sk->sk_state == SMC_CLOSED) - smc_conn_free(&smc->conn); release_sock(sk); sock_put(sk); /* final sock_put */ @@ -291,7 +299,8 @@ static void smc_copy_sock_settings(struct sock *nsk, struct sock *osk, (1UL << SOCK_RXQ_OVFL) | \ (1UL << SOCK_WIFI_STATUS) | \ (1UL << SOCK_NOFCS) | \ - (1UL << SOCK_FILTER_LOCKED)) + (1UL << SOCK_FILTER_LOCKED) | \ + (1UL << SOCK_TSTAMP_NEW)) /* copy only relevant settings and flags of SOL_SOCKET level from smc to * clc socket (since smc is not called for these options from net/core) */ @@ -475,7 +484,12 @@ static int smc_connect_abort(struct smc_sock *smc, int reason_code, { if (local_contact == SMC_FIRST_CONTACT) smc_lgr_forget(smc->conn.lgr); - mutex_unlock(&smc_create_lgr_pending); + if (smc->conn.lgr->is_smcd) + /* there is only one lgr role for SMC-D; use server lock */ + mutex_unlock(&smc_server_lgr_pending); + else + mutex_unlock(&smc_client_lgr_pending); + smc_conn_free(&smc->conn); return reason_code; } @@ -560,7 +574,7 @@ static int smc_connect_rdma(struct smc_sock *smc, struct smc_link *link; int reason_code = 0; - mutex_lock(&smc_create_lgr_pending); + mutex_lock(&smc_client_lgr_pending); local_contact = smc_conn_create(smc, false, aclc->hdr.flag, ibdev, ibport, ntoh24(aclc->qpn), &aclc->lcl, NULL, 0); @@ -571,7 +585,8 @@ static int smc_connect_rdma(struct smc_sock *smc, reason_code = SMC_CLC_DECL_SYNCERR; /* synchr. error */ else reason_code = SMC_CLC_DECL_INTERR; /* other error */ - return smc_connect_abort(smc, reason_code, 0); + mutex_unlock(&smc_client_lgr_pending); + return reason_code; } link = &smc->conn.lgr->lnk[SMC_SINGLE_LINK]; @@ -615,7 +630,7 @@ static int smc_connect_rdma(struct smc_sock *smc, return smc_connect_abort(smc, reason_code, local_contact); } - mutex_unlock(&smc_create_lgr_pending); + mutex_unlock(&smc_client_lgr_pending); smc_copy_sock_settings_to_clc(smc); if (smc->sk.sk_state == SMC_INIT) @@ -632,11 +647,14 @@ static int smc_connect_ism(struct smc_sock *smc, int local_contact = SMC_FIRST_CONTACT; int rc = 0; - mutex_lock(&smc_create_lgr_pending); + /* there is only one lgr role for SMC-D; use server lock */ + mutex_lock(&smc_server_lgr_pending); local_contact = smc_conn_create(smc, true, aclc->hdr.flag, NULL, 0, 0, NULL, ismdev, aclc->gid); - if (local_contact < 0) - return smc_connect_abort(smc, SMC_CLC_DECL_MEM, 0); + if (local_contact < 0) { + mutex_unlock(&smc_server_lgr_pending); + return SMC_CLC_DECL_MEM; + } /* Create send and receive buffers */ if (smc_buf_create(smc, true)) @@ -650,7 +668,7 @@ static int smc_connect_ism(struct smc_sock *smc, rc = smc_clc_send_confirm(smc); if (rc) return smc_connect_abort(smc, rc, local_contact); - mutex_unlock(&smc_create_lgr_pending); + mutex_unlock(&smc_server_lgr_pending); smc_copy_sock_settings_to_clc(smc); if (smc->sk.sk_state == SMC_INIT) @@ -1249,7 +1267,7 @@ static void smc_listen_work(struct work_struct *work) return; } - mutex_lock(&smc_create_lgr_pending); + mutex_lock(&smc_server_lgr_pending); smc_close_init(new_smc); smc_rx_init(new_smc); smc_tx_init(new_smc); @@ -1271,7 +1289,7 @@ static void smc_listen_work(struct work_struct *work) &local_contact) || smc_listen_rdma_reg(new_smc, local_contact))) { /* SMC not supported, decline */ - mutex_unlock(&smc_create_lgr_pending); + mutex_unlock(&smc_server_lgr_pending); smc_listen_decline(new_smc, SMC_CLC_DECL_MODEUNSUPP, local_contact); return; @@ -1280,29 +1298,33 @@ static void smc_listen_work(struct work_struct *work) /* send SMC Accept CLC message */ rc = smc_clc_send_accept(new_smc, local_contact); if (rc) { - mutex_unlock(&smc_create_lgr_pending); + mutex_unlock(&smc_server_lgr_pending); smc_listen_decline(new_smc, rc, local_contact); return; } + /* SMC-D does not need this lock any more */ + if (ism_supported) + mutex_unlock(&smc_server_lgr_pending); + /* receive SMC Confirm CLC message */ reason_code = smc_clc_wait_msg(new_smc, &cclc, sizeof(cclc), SMC_CLC_CONFIRM, CLC_WAIT_TIME); if (reason_code) { - mutex_unlock(&smc_create_lgr_pending); + if (!ism_supported) + mutex_unlock(&smc_server_lgr_pending); smc_listen_decline(new_smc, reason_code, local_contact); return; } /* finish worker */ if (!ism_supported) { - if (smc_listen_rdma_finish(new_smc, &cclc, local_contact)) { - mutex_unlock(&smc_create_lgr_pending); + rc = smc_listen_rdma_finish(new_smc, &cclc, local_contact); + mutex_unlock(&smc_server_lgr_pending); + if (rc) return; - } } smc_conn_save_peer_info(new_smc, &cclc); - mutex_unlock(&smc_create_lgr_pending); smc_listen_out_connected(new_smc); } @@ -1505,6 +1527,11 @@ static int smc_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, smc = smc_sk(sk); lock_sock(sk); + if (sk->sk_state == SMC_CLOSED && (sk->sk_shutdown & RCV_SHUTDOWN)) { + /* socket was connected before, no more data to read */ + rc = 0; + goto out; + } if ((sk->sk_state == SMC_INIT) || (sk->sk_state == SMC_LISTEN) || (sk->sk_state == SMC_CLOSED)) @@ -1840,7 +1867,11 @@ static ssize_t smc_splice_read(struct socket *sock, loff_t *ppos, smc = smc_sk(sk); lock_sock(sk); - + if (sk->sk_state == SMC_CLOSED && (sk->sk_shutdown & RCV_SHUTDOWN)) { + /* socket was connected before, no more data to read */ + rc = 0; + goto out; + } if (sk->sk_state == SMC_INIT || sk->sk_state == SMC_LISTEN || sk->sk_state == SMC_CLOSED) @@ -1939,10 +1970,33 @@ static const struct net_proto_family smc_sock_family_ops = { .create = smc_create, }; +unsigned int smc_net_id; + +static __net_init int smc_net_init(struct net *net) +{ + return smc_pnet_net_init(net); +} + +static void __net_exit smc_net_exit(struct net *net) +{ + smc_pnet_net_exit(net); +} + +static struct pernet_operations smc_net_ops = { + .init = smc_net_init, + .exit = smc_net_exit, + .id = &smc_net_id, + .size = sizeof(struct smc_net), +}; + static int __init smc_init(void) { int rc; + rc = register_pernet_subsys(&smc_net_ops); + if (rc) + return rc; + rc = smc_pnet_init(); if (rc) return rc; @@ -2008,6 +2062,7 @@ static void __exit smc_exit(void) proto_unregister(&smc_proto6); proto_unregister(&smc_proto); smc_pnet_exit(); + unregister_pernet_subsys(&smc_net_ops); } module_init(smc_init); diff --git a/net/smc/smc.h b/net/smc/smc.h index 5721416d0605..adbdf195eb08 100644 --- a/net/smc/smc.h +++ b/net/smc/smc.h @@ -113,9 +113,9 @@ struct smc_host_cdc_msg { /* Connection Data Control message */ } __aligned(8); enum smc_urg_state { - SMC_URG_VALID, /* data present */ - SMC_URG_NOTYET, /* data pending */ - SMC_URG_READ /* data was already read */ + SMC_URG_VALID = 1, /* data present */ + SMC_URG_NOTYET = 2, /* data pending */ + SMC_URG_READ = 3, /* data was already read */ }; struct smc_connection { diff --git a/net/smc/smc_cdc.c b/net/smc/smc_cdc.c index db83332ac1c8..d0b0f4c865b4 100644 --- a/net/smc/smc_cdc.c +++ b/net/smc/smc_cdc.c @@ -21,13 +21,6 @@ /********************************** send *************************************/ -struct smc_cdc_tx_pend { - struct smc_connection *conn; /* socket connection */ - union smc_host_cursor cursor; /* tx sndbuf cursor sent */ - union smc_host_cursor p_cursor; /* rx RMBE cursor produced */ - u16 ctrl_seq; /* conn. tx sequence # */ -}; - /* handler for send/transmission completion of a CDC msg */ static void smc_cdc_tx_handler(struct smc_wr_tx_pend_priv *pnd_snd, struct smc_link *link, @@ -61,12 +54,14 @@ static void smc_cdc_tx_handler(struct smc_wr_tx_pend_priv *pnd_snd, int smc_cdc_get_free_slot(struct smc_connection *conn, struct smc_wr_buf **wr_buf, + struct smc_rdma_wr **wr_rdma_buf, struct smc_cdc_tx_pend **pend) { struct smc_link *link = &conn->lgr->lnk[SMC_SINGLE_LINK]; int rc; rc = smc_wr_tx_get_free_slot(link, smc_cdc_tx_handler, wr_buf, + wr_rdma_buf, (struct smc_wr_tx_pend_priv **)pend); if (!conn->alert_token_local) /* abnormal termination */ @@ -96,6 +91,7 @@ int smc_cdc_msg_send(struct smc_connection *conn, struct smc_wr_buf *wr_buf, struct smc_cdc_tx_pend *pend) { + union smc_host_cursor cfed; struct smc_link *link; int rc; @@ -105,12 +101,12 @@ int smc_cdc_msg_send(struct smc_connection *conn, conn->tx_cdc_seq++; conn->local_tx_ctrl.seqno = conn->tx_cdc_seq; - smc_host_msg_to_cdc((struct smc_cdc_msg *)wr_buf, - &conn->local_tx_ctrl, conn); + smc_host_msg_to_cdc((struct smc_cdc_msg *)wr_buf, conn, &cfed); rc = smc_wr_tx_send(link, (struct smc_wr_tx_pend_priv *)pend); - if (!rc) - smc_curs_copy(&conn->rx_curs_confirmed, - &conn->local_tx_ctrl.cons, conn); + if (!rc) { + smc_curs_copy(&conn->rx_curs_confirmed, &cfed, conn); + conn->local_rx_ctrl.prod_flags.cons_curs_upd_req = 0; + } return rc; } @@ -121,11 +117,14 @@ static int smcr_cdc_get_slot_and_msg_send(struct smc_connection *conn) struct smc_wr_buf *wr_buf; int rc; - rc = smc_cdc_get_free_slot(conn, &wr_buf, &pend); + rc = smc_cdc_get_free_slot(conn, &wr_buf, NULL, &pend); if (rc) return rc; - return smc_cdc_msg_send(conn, wr_buf, pend); + spin_lock_bh(&conn->send_lock); + rc = smc_cdc_msg_send(conn, wr_buf, pend); + spin_unlock_bh(&conn->send_lock); + return rc; } int smc_cdc_get_slot_and_msg_send(struct smc_connection *conn) @@ -195,6 +194,7 @@ int smcd_cdc_msg_send(struct smc_connection *conn) if (rc) return rc; smc_curs_copy(&conn->rx_curs_confirmed, &curs, conn); + conn->local_rx_ctrl.prod_flags.cons_curs_upd_req = 0; /* Calculate transmitted data and increment free send buffer space */ diff = smc_curs_diff(conn->sndbuf_desc->len, &conn->tx_curs_fin, &conn->tx_curs_sent); @@ -271,26 +271,18 @@ static void smc_cdc_msg_recv_action(struct smc_sock *smc, smp_mb__after_atomic(); smc->sk.sk_data_ready(&smc->sk); } else { - if (conn->local_rx_ctrl.prod_flags.write_blocked || - conn->local_rx_ctrl.prod_flags.cons_curs_upd_req || - conn->local_rx_ctrl.prod_flags.urg_data_pending) { - if (conn->local_rx_ctrl.prod_flags.urg_data_pending) - conn->urg_state = SMC_URG_NOTYET; - /* force immediate tx of current consumer cursor, but - * under send_lock to guarantee arrival in seqno-order - */ - if (smc->sk.sk_state != SMC_INIT) - smc_tx_sndbuf_nonempty(conn); - } + if (conn->local_rx_ctrl.prod_flags.write_blocked) + smc->sk.sk_data_ready(&smc->sk); + if (conn->local_rx_ctrl.prod_flags.urg_data_pending) + conn->urg_state = SMC_URG_NOTYET; } - /* piggy backed tx info */ /* trigger sndbuf consumer: RDMA write into peer RMBE and CDC */ - if (diff_cons && smc_tx_prepared_sends(conn)) { + if ((diff_cons && smc_tx_prepared_sends(conn)) || + conn->local_rx_ctrl.prod_flags.cons_curs_upd_req || + conn->local_rx_ctrl.prod_flags.urg_data_pending) smc_tx_sndbuf_nonempty(conn); - /* trigger socket release if connection closed */ - smc_close_wake_tx_prepared(smc); - } + if (diff_cons && conn->urg_tx_pend && atomic_read(&conn->peer_rmbe_space) == conn->peer_rmbe_size) { /* urg data confirmed by peer, indicate we're ready for more */ diff --git a/net/smc/smc_cdc.h b/net/smc/smc_cdc.h index b5bfe38c7f9b..861dc24c588c 100644 --- a/net/smc/smc_cdc.h +++ b/net/smc/smc_cdc.h @@ -160,7 +160,9 @@ static inline void smcd_curs_copy(union smcd_cdc_cursor *tgt, #endif } -/* calculate cursor difference between old and new, where old <= new */ +/* calculate cursor difference between old and new, where old <= new and + * difference cannot exceed size + */ static inline int smc_curs_diff(unsigned int size, union smc_host_cursor *old, union smc_host_cursor *new) @@ -185,28 +187,51 @@ static inline int smc_curs_comp(unsigned int size, return smc_curs_diff(size, old, new); } +/* calculate cursor difference between old and new, where old <= new and + * difference may exceed size + */ +static inline int smc_curs_diff_large(unsigned int size, + union smc_host_cursor *old, + union smc_host_cursor *new) +{ + if (old->wrap < new->wrap) + return min_t(int, + (size - old->count) + new->count + + (new->wrap - old->wrap - 1) * size, + size); + + if (old->wrap > new->wrap) /* wrap has switched from 0xffff to 0x0000 */ + return min_t(int, + (size - old->count) + new->count + + (new->wrap + 0xffff - old->wrap) * size, + size); + + return max_t(int, 0, (new->count - old->count)); +} + static inline void smc_host_cursor_to_cdc(union smc_cdc_cursor *peer, union smc_host_cursor *local, + union smc_host_cursor *save, struct smc_connection *conn) { - union smc_host_cursor temp; - - smc_curs_copy(&temp, local, conn); - peer->count = htonl(temp.count); - peer->wrap = htons(temp.wrap); + smc_curs_copy(save, local, conn); + peer->count = htonl(save->count); + peer->wrap = htons(save->wrap); /* peer->reserved = htons(0); must be ensured by caller */ } static inline void smc_host_msg_to_cdc(struct smc_cdc_msg *peer, - struct smc_host_cdc_msg *local, - struct smc_connection *conn) + struct smc_connection *conn, + union smc_host_cursor *save) { + struct smc_host_cdc_msg *local = &conn->local_tx_ctrl; + peer->common.type = local->common.type; peer->len = local->len; peer->seqno = htons(local->seqno); peer->token = htonl(local->token); - smc_host_cursor_to_cdc(&peer->prod, &local->prod, conn); - smc_host_cursor_to_cdc(&peer->cons, &local->cons, conn); + smc_host_cursor_to_cdc(&peer->prod, &local->prod, save, conn); + smc_host_cursor_to_cdc(&peer->cons, &local->cons, save, conn); peer->prod_flags = local->prod_flags; peer->conn_state_flags = local->conn_state_flags; } @@ -245,17 +270,18 @@ static inline void smcr_cdc_msg_to_host(struct smc_host_cdc_msg *local, } static inline void smcd_cdc_msg_to_host(struct smc_host_cdc_msg *local, - struct smcd_cdc_msg *peer) + struct smcd_cdc_msg *peer, + struct smc_connection *conn) { union smc_host_cursor temp; temp.wrap = peer->prod.wrap; temp.count = peer->prod.count; - atomic64_set(&local->prod.acurs, atomic64_read(&temp.acurs)); + smc_curs_copy(&local->prod, &temp, conn); temp.wrap = peer->cons.wrap; temp.count = peer->cons.count; - atomic64_set(&local->cons.acurs, atomic64_read(&temp.acurs)); + smc_curs_copy(&local->cons, &temp, conn); local->prod_flags = peer->cons.prod_flags; local->conn_state_flags = peer->cons.conn_state_flags; } @@ -265,15 +291,21 @@ static inline void smc_cdc_msg_to_host(struct smc_host_cdc_msg *local, struct smc_connection *conn) { if (conn->lgr->is_smcd) - smcd_cdc_msg_to_host(local, (struct smcd_cdc_msg *)peer); + smcd_cdc_msg_to_host(local, (struct smcd_cdc_msg *)peer, conn); else smcr_cdc_msg_to_host(local, peer, conn); } -struct smc_cdc_tx_pend; +struct smc_cdc_tx_pend { + struct smc_connection *conn; /* socket connection */ + union smc_host_cursor cursor; /* tx sndbuf cursor sent */ + union smc_host_cursor p_cursor; /* rx RMBE cursor produced */ + u16 ctrl_seq; /* conn. tx sequence # */ +}; int smc_cdc_get_free_slot(struct smc_connection *conn, struct smc_wr_buf **wr_buf, + struct smc_rdma_wr **wr_rdma_buf, struct smc_cdc_tx_pend **pend); void smc_cdc_tx_dismiss_slots(struct smc_connection *conn); int smc_cdc_msg_send(struct smc_connection *conn, struct smc_wr_buf *wr_buf, diff --git a/net/smc/smc_clc.c b/net/smc/smc_clc.c index 776e9dfc915d..d53fd588d1f5 100644 --- a/net/smc/smc_clc.c +++ b/net/smc/smc_clc.c @@ -378,7 +378,7 @@ int smc_clc_send_decline(struct smc_sock *smc, u32 peer_diag_info) vec.iov_len = sizeof(struct smc_clc_msg_decline); len = kernel_sendmsg(smc->clcsock, &msg, &vec, 1, sizeof(struct smc_clc_msg_decline)); - if (len < sizeof(struct smc_clc_msg_decline)) + if (len < 0 || len < sizeof(struct smc_clc_msg_decline)) len = -EPROTO; return len > 0 ? 0 : len; } diff --git a/net/smc/smc_close.c b/net/smc/smc_close.c index ea2b87f29469..2ad37e998509 100644 --- a/net/smc/smc_close.c +++ b/net/smc/smc_close.c @@ -345,14 +345,7 @@ static void smc_close_passive_work(struct work_struct *work) switch (sk->sk_state) { case SMC_INIT: - if (atomic_read(&conn->bytes_to_rcv) || - (rxflags->peer_done_writing && - !smc_cdc_rxed_any_close(conn))) { - sk->sk_state = SMC_APPCLOSEWAIT1; - } else { - sk->sk_state = SMC_CLOSED; - sock_put(sk); /* passive closing */ - } + sk->sk_state = SMC_APPCLOSEWAIT1; break; case SMC_ACTIVE: sk->sk_state = SMC_APPCLOSEWAIT1; @@ -405,8 +398,13 @@ wakeup: if (old_state != sk->sk_state) { sk->sk_state_change(sk); if ((sk->sk_state == SMC_CLOSED) && - (sock_flag(sk, SOCK_DEAD) || !sk->sk_socket)) + (sock_flag(sk, SOCK_DEAD) || !sk->sk_socket)) { smc_conn_free(conn); + if (smc->clcsock) { + sock_release(smc->clcsock); + smc->clcsock = NULL; + } + } } release_sock(sk); sock_put(sk); /* sock_hold done by schedulers of close_work */ diff --git a/net/smc/smc_core.c b/net/smc/smc_core.c index 35c1cdc93e1c..53a17cfa61af 100644 --- a/net/smc/smc_core.c +++ b/net/smc/smc_core.c @@ -118,7 +118,6 @@ static void __smc_lgr_unregister_conn(struct smc_connection *conn) rb_erase(&conn->alert_node, &lgr->conns_all); lgr->conns_num--; conn->alert_token_local = 0; - conn->lgr = NULL; sock_put(&smc->sk); /* sock_hold in smc_lgr_register_conn() */ } @@ -128,6 +127,8 @@ static void smc_lgr_unregister_conn(struct smc_connection *conn) { struct smc_link_group *lgr = conn->lgr; + if (!lgr) + return; write_lock_bh(&lgr->conns_lock); if (conn->alert_token_local) { __smc_lgr_unregister_conn(conn); @@ -159,8 +160,6 @@ static void smc_lgr_free_work(struct work_struct *work) bool conns; spin_lock_bh(&smc_lgr_list.lock); - if (list_empty(&lgr->list)) - goto free; read_lock_bh(&lgr->conns_lock); conns = RB_EMPTY_ROOT(&lgr->conns_all); read_unlock_bh(&lgr->conns_lock); @@ -168,8 +167,8 @@ static void smc_lgr_free_work(struct work_struct *work) spin_unlock_bh(&smc_lgr_list.lock); return; } - list_del_init(&lgr->list); /* remove from smc_lgr_list */ -free: + if (!list_empty(&lgr->list)) + list_del_init(&lgr->list); /* remove from smc_lgr_list */ spin_unlock_bh(&smc_lgr_list.lock); if (!lgr->is_smcd && !lgr->terminating) { @@ -300,13 +299,13 @@ static void smc_buf_unuse(struct smc_connection *conn, conn->sndbuf_desc->used = 0; if (conn->rmb_desc) { if (!conn->rmb_desc->regerr) { - conn->rmb_desc->used = 0; if (!lgr->is_smcd) { /* unregister rmb with peer */ smc_llc_do_delete_rkey( &lgr->lnk[SMC_SINGLE_LINK], conn->rmb_desc); } + conn->rmb_desc->used = 0; } else { /* buf registration failed, reuse not possible */ write_lock_bh(&lgr->rmbs_lock); @@ -331,8 +330,9 @@ void smc_conn_free(struct smc_connection *conn) } else { smc_cdc_tx_dismiss_slots(conn); } - smc_lgr_unregister_conn(conn); /* unsets conn->lgr */ + smc_lgr_unregister_conn(conn); smc_buf_unuse(conn, lgr); /* allow buffer reuse */ + conn->lgr = NULL; if (!lgr->conns_num) smc_lgr_schedule_free_work(lgr); @@ -462,6 +462,7 @@ static void __smc_lgr_terminate(struct smc_link_group *lgr) sock_hold(&smc->sk); /* sock_put in close work */ conn->local_tx_ctrl.conn_state_flags.peer_conn_abort = 1; __smc_lgr_unregister_conn(conn); + conn->lgr = NULL; write_unlock_bh(&lgr->conns_lock); if (!schedule_work(&conn->close_work)) sock_put(&smc->sk); @@ -628,6 +629,8 @@ int smc_conn_create(struct smc_sock *smc, bool is_smcd, int srv_first_contact, local_contact = SMC_REUSE_CONTACT; conn->lgr = lgr; smc_lgr_register_conn(conn); /* add smc conn to lgr */ + if (delayed_work_pending(&lgr->free_work)) + cancel_delayed_work(&lgr->free_work); write_unlock_bh(&lgr->conns_lock); break; } diff --git a/net/smc/smc_core.h b/net/smc/smc_core.h index b00287989a3d..8806d2afa6ed 100644 --- a/net/smc/smc_core.h +++ b/net/smc/smc_core.h @@ -52,6 +52,24 @@ enum smc_wr_reg_state { FAILED /* ib_wr_reg_mr response: failure */ }; +struct smc_rdma_sge { /* sges for RDMA writes */ + struct ib_sge wr_tx_rdma_sge[SMC_IB_MAX_SEND_SGE]; +}; + +#define SMC_MAX_RDMA_WRITES 2 /* max. # of RDMA writes per + * message send + */ + +struct smc_rdma_sges { /* sges per message send */ + struct smc_rdma_sge tx_rdma_sge[SMC_MAX_RDMA_WRITES]; +}; + +struct smc_rdma_wr { /* work requests per message + * send + */ + struct ib_rdma_wr wr_tx_rdma[SMC_MAX_RDMA_WRITES]; +}; + struct smc_link { struct smc_ib_device *smcibdev; /* ib-device */ u8 ibport; /* port - values 1 | 2 */ @@ -64,6 +82,8 @@ struct smc_link { struct smc_wr_buf *wr_tx_bufs; /* WR send payload buffers */ struct ib_send_wr *wr_tx_ibs; /* WR send meta data */ struct ib_sge *wr_tx_sges; /* WR send gather meta data */ + struct smc_rdma_sges *wr_tx_rdma_sges;/*RDMA WRITE gather meta data*/ + struct smc_rdma_wr *wr_tx_rdmas; /* WR RDMA WRITE */ struct smc_wr_tx_pend *wr_tx_pends; /* WR send waiting for CQE */ /* above four vectors have wr_tx_cnt elements and use the same index */ dma_addr_t wr_tx_dma_addr; /* DMA address of wr_tx_bufs */ diff --git a/net/smc/smc_diag.c b/net/smc/smc_diag.c index dbf64a93d68a..371b4cf31fcd 100644 --- a/net/smc/smc_diag.c +++ b/net/smc/smc_diag.c @@ -38,6 +38,7 @@ static void smc_diag_msg_common_fill(struct smc_diag_msg *r, struct sock *sk) { struct smc_sock *smc = smc_sk(sk); + r->diag_family = sk->sk_family; if (!smc->clcsock) return; r->id.idiag_sport = htons(smc->clcsock->sk->sk_num); @@ -45,14 +46,12 @@ static void smc_diag_msg_common_fill(struct smc_diag_msg *r, struct sock *sk) r->id.idiag_if = smc->clcsock->sk->sk_bound_dev_if; sock_diag_save_cookie(sk, r->id.idiag_cookie); if (sk->sk_protocol == SMCPROTO_SMC) { - r->diag_family = PF_INET; memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src)); memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst)); r->id.idiag_src[0] = smc->clcsock->sk->sk_rcv_saddr; r->id.idiag_dst[0] = smc->clcsock->sk->sk_daddr; #if IS_ENABLED(CONFIG_IPV6) } else if (sk->sk_protocol == SMCPROTO_SMC6) { - r->diag_family = PF_INET6; memcpy(&r->id.idiag_src, &smc->clcsock->sk->sk_v6_rcv_saddr, sizeof(smc->clcsock->sk->sk_v6_rcv_saddr)); memcpy(&r->id.idiag_dst, &smc->clcsock->sk->sk_v6_daddr, diff --git a/net/smc/smc_ib.c b/net/smc/smc_ib.c index e519ef29c0ff..53f429c04843 100644 --- a/net/smc/smc_ib.c +++ b/net/smc/smc_ib.c @@ -257,12 +257,20 @@ static void smc_ib_global_event_handler(struct ib_event_handler *handler, smcibdev = container_of(handler, struct smc_ib_device, event_handler); switch (ibevent->event) { - case IB_EVENT_PORT_ERR: case IB_EVENT_DEVICE_FATAL: + /* terminate all ports on device */ + for (port_idx = 0; port_idx < SMC_MAX_PORTS; port_idx++) + set_bit(port_idx, &smcibdev->port_event_mask); + schedule_work(&smcibdev->port_event_work); + break; + case IB_EVENT_PORT_ERR: case IB_EVENT_PORT_ACTIVE: + case IB_EVENT_GID_CHANGE: port_idx = ibevent->element.port_num - 1; - set_bit(port_idx, &smcibdev->port_event_mask); - schedule_work(&smcibdev->port_event_work); + if (port_idx < SMC_MAX_PORTS) { + set_bit(port_idx, &smcibdev->port_event_mask); + schedule_work(&smcibdev->port_event_work); + } break; default: break; @@ -289,18 +297,18 @@ int smc_ib_create_protection_domain(struct smc_link *lnk) static void smc_ib_qp_event_handler(struct ib_event *ibevent, void *priv) { - struct smc_ib_device *smcibdev = - (struct smc_ib_device *)ibevent->device; + struct smc_link *lnk = (struct smc_link *)priv; + struct smc_ib_device *smcibdev = lnk->smcibdev; u8 port_idx; switch (ibevent->event) { - case IB_EVENT_DEVICE_FATAL: - case IB_EVENT_GID_CHANGE: - case IB_EVENT_PORT_ERR: + case IB_EVENT_QP_FATAL: case IB_EVENT_QP_ACCESS_ERR: - port_idx = ibevent->element.port_num - 1; - set_bit(port_idx, &smcibdev->port_event_mask); - schedule_work(&smcibdev->port_event_work); + port_idx = ibevent->element.qp->port - 1; + if (port_idx < SMC_MAX_PORTS) { + set_bit(port_idx, &smcibdev->port_event_mask); + schedule_work(&smcibdev->port_event_work); + } break; default: break; @@ -556,7 +564,6 @@ static void smc_ib_remove_dev(struct ib_device *ibdev, void *client_data) spin_lock(&smc_ib_devices.lock); list_del_init(&smcibdev->list); /* remove from smc_ib_devices */ spin_unlock(&smc_ib_devices.lock); - smc_pnet_remove_by_ibdev(smcibdev); smc_ib_cleanup_per_ibdev(smcibdev); ib_unregister_event_handler(&smcibdev->event_handler); kfree(smcibdev); diff --git a/net/smc/smc_ib.h b/net/smc/smc_ib.h index bac7fd65a4c0..da60ab9e8d70 100644 --- a/net/smc/smc_ib.h +++ b/net/smc/smc_ib.h @@ -42,6 +42,8 @@ struct smc_ib_device { /* ib-device infos for smc */ /* mac address per port*/ u8 pnetid[SMC_MAX_PORTS][SMC_MAX_PNETID_LEN]; /* pnetid per port */ + bool pnetid_by_user[SMC_MAX_PORTS]; + /* pnetid defined by user? */ u8 initialized : 1; /* ib dev CQ, evthdl done */ struct work_struct port_event_work; unsigned long port_event_mask; diff --git a/net/smc/smc_llc.c b/net/smc/smc_llc.c index a6d3623d06f4..4fd60c522802 100644 --- a/net/smc/smc_llc.c +++ b/net/smc/smc_llc.c @@ -166,7 +166,8 @@ static int smc_llc_add_pending_send(struct smc_link *link, { int rc; - rc = smc_wr_tx_get_free_slot(link, smc_llc_tx_handler, wr_buf, pend); + rc = smc_wr_tx_get_free_slot(link, smc_llc_tx_handler, wr_buf, NULL, + pend); if (rc < 0) return rc; BUILD_BUG_ON_MSG( diff --git a/net/smc/smc_netns.h b/net/smc/smc_netns.h new file mode 100644 index 000000000000..e7a8fc4ae02f --- /dev/null +++ b/net/smc/smc_netns.h @@ -0,0 +1,20 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Shared Memory Communications + * + * Network namespace definitions. + * + * Copyright IBM Corp. 2018 + */ + +#ifndef SMC_NETNS_H +#define SMC_NETNS_H + +#include "smc_pnet.h" + +extern unsigned int smc_net_id; + +/* per-network namespace private data */ +struct smc_net { + struct smc_pnettable pnettable; +}; +#endif diff --git a/net/smc/smc_pnet.c b/net/smc/smc_pnet.c index 7cb3e4f07c10..8d2f6296279c 100644 --- a/net/smc/smc_pnet.c +++ b/net/smc/smc_pnet.c @@ -20,14 +20,21 @@ #include <rdma/ib_verbs.h> +#include <net/netns/generic.h> +#include "smc_netns.h" + #include "smc_pnet.h" #include "smc_ib.h" #include "smc_ism.h" +#define SMC_ASCII_BLANK 32 + +static struct net_device *pnet_find_base_ndev(struct net_device *ndev); + static struct nla_policy smc_pnet_policy[SMC_PNETID_MAX + 1] = { [SMC_PNETID_NAME] = { .type = NLA_NUL_STRING, - .len = SMC_MAX_PNETID_LEN - 1 + .len = SMC_MAX_PNETID_LEN }, [SMC_PNETID_ETHNAME] = { .type = NLA_NUL_STRING, @@ -43,118 +50,127 @@ static struct nla_policy smc_pnet_policy[SMC_PNETID_MAX + 1] = { static struct genl_family smc_pnet_nl_family; /** - * struct smc_pnettable - SMC PNET table anchor - * @lock: Lock for list action - * @pnetlist: List of PNETIDs - */ -static struct smc_pnettable { - rwlock_t lock; - struct list_head pnetlist; -} smc_pnettable = { - .pnetlist = LIST_HEAD_INIT(smc_pnettable.pnetlist), - .lock = __RW_LOCK_UNLOCKED(smc_pnettable.lock) -}; - -/** - * struct smc_pnetentry - pnet identifier name entry + * struct smc_user_pnetentry - pnet identifier name entry for/from user * @list: List node. * @pnet_name: Pnet identifier name * @ndev: pointer to network device. * @smcibdev: Pointer to IB device. + * @ib_port: Port of IB device. + * @smcd_dev: Pointer to smcd device. */ -struct smc_pnetentry { +struct smc_user_pnetentry { struct list_head list; char pnet_name[SMC_MAX_PNETID_LEN + 1]; struct net_device *ndev; struct smc_ib_device *smcibdev; u8 ib_port; + struct smcd_dev *smcd_dev; }; -/* Check if two RDMA device entries are identical. Use device name and port - * number for comparison. - */ -static bool smc_pnet_same_ibname(struct smc_pnetentry *pnetelem, char *ibname, - u8 ibport) -{ - return pnetelem->ib_port == ibport && - !strncmp(pnetelem->smcibdev->ibdev->name, ibname, - sizeof(pnetelem->smcibdev->ibdev->name)); -} +/* pnet entry stored in pnet table */ +struct smc_pnetentry { + struct list_head list; + char pnet_name[SMC_MAX_PNETID_LEN + 1]; + struct net_device *ndev; +}; -/* Find a pnetid in the pnet table. - */ -static struct smc_pnetentry *smc_pnet_find_pnetid(char *pnet_name) +/* Check if two given pnetids match */ +static bool smc_pnet_match(u8 *pnetid1, u8 *pnetid2) { - struct smc_pnetentry *pnetelem, *found_pnetelem = NULL; + int i; - read_lock(&smc_pnettable.lock); - list_for_each_entry(pnetelem, &smc_pnettable.pnetlist, list) { - if (!strncmp(pnetelem->pnet_name, pnet_name, - sizeof(pnetelem->pnet_name))) { - found_pnetelem = pnetelem; + for (i = 0; i < SMC_MAX_PNETID_LEN; i++) { + if ((pnetid1[i] == 0 || pnetid1[i] == SMC_ASCII_BLANK) && + (pnetid2[i] == 0 || pnetid2[i] == SMC_ASCII_BLANK)) break; - } + if (pnetid1[i] != pnetid2[i]) + return false; } - read_unlock(&smc_pnettable.lock); - return found_pnetelem; + return true; } /* Remove a pnetid from the pnet table. */ -static int smc_pnet_remove_by_pnetid(char *pnet_name) +static int smc_pnet_remove_by_pnetid(struct net *net, char *pnet_name) { struct smc_pnetentry *pnetelem, *tmp_pe; + struct smc_pnettable *pnettable; + struct smc_ib_device *ibdev; + struct smcd_dev *smcd_dev; + struct smc_net *sn; int rc = -ENOENT; + int ibport; + + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; - write_lock(&smc_pnettable.lock); - list_for_each_entry_safe(pnetelem, tmp_pe, &smc_pnettable.pnetlist, + /* remove netdevices */ + write_lock(&pnettable->lock); + list_for_each_entry_safe(pnetelem, tmp_pe, &pnettable->pnetlist, list) { - if (!strncmp(pnetelem->pnet_name, pnet_name, - sizeof(pnetelem->pnet_name))) { + if (!pnet_name || + smc_pnet_match(pnetelem->pnet_name, pnet_name)) { list_del(&pnetelem->list); dev_put(pnetelem->ndev); kfree(pnetelem); rc = 0; - break; } } - write_unlock(&smc_pnettable.lock); - return rc; -} + write_unlock(&pnettable->lock); -/* Remove a pnet entry mentioning a given network device from the pnet table. - */ -static int smc_pnet_remove_by_ndev(struct net_device *ndev) -{ - struct smc_pnetentry *pnetelem, *tmp_pe; - int rc = -ENOENT; + /* if this is not the initial namespace, stop here */ + if (net != &init_net) + return rc; - write_lock(&smc_pnettable.lock); - list_for_each_entry_safe(pnetelem, tmp_pe, &smc_pnettable.pnetlist, - list) { - if (pnetelem->ndev == ndev) { - list_del(&pnetelem->list); - dev_put(pnetelem->ndev); - kfree(pnetelem); + /* remove ib devices */ + spin_lock(&smc_ib_devices.lock); + list_for_each_entry(ibdev, &smc_ib_devices.list, list) { + for (ibport = 0; ibport < SMC_MAX_PORTS; ibport++) { + if (ibdev->pnetid_by_user[ibport] && + (!pnet_name || + smc_pnet_match(pnet_name, + ibdev->pnetid[ibport]))) { + memset(ibdev->pnetid[ibport], 0, + SMC_MAX_PNETID_LEN); + ibdev->pnetid_by_user[ibport] = false; + rc = 0; + } + } + } + spin_unlock(&smc_ib_devices.lock); + /* remove smcd devices */ + spin_lock(&smcd_dev_list.lock); + list_for_each_entry(smcd_dev, &smcd_dev_list.list, list) { + if (smcd_dev->pnetid_by_user && + (!pnet_name || + smc_pnet_match(pnet_name, smcd_dev->pnetid))) { + memset(smcd_dev->pnetid, 0, SMC_MAX_PNETID_LEN); + smcd_dev->pnetid_by_user = false; rc = 0; - break; } } - write_unlock(&smc_pnettable.lock); + spin_unlock(&smcd_dev_list.lock); return rc; } -/* Remove a pnet entry mentioning a given ib device from the pnet table. +/* Remove a pnet entry mentioning a given network device from the pnet table. */ -int smc_pnet_remove_by_ibdev(struct smc_ib_device *ibdev) +static int smc_pnet_remove_by_ndev(struct net_device *ndev) { struct smc_pnetentry *pnetelem, *tmp_pe; + struct smc_pnettable *pnettable; + struct net *net = dev_net(ndev); + struct smc_net *sn; int rc = -ENOENT; - write_lock(&smc_pnettable.lock); - list_for_each_entry_safe(pnetelem, tmp_pe, &smc_pnettable.pnetlist, - list) { - if (pnetelem->smcibdev == ibdev) { + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + + write_lock(&pnettable->lock); + list_for_each_entry_safe(pnetelem, tmp_pe, &pnettable->pnetlist, list) { + if (pnetelem->ndev == ndev) { list_del(&pnetelem->list); dev_put(pnetelem->ndev); kfree(pnetelem); @@ -162,35 +178,84 @@ int smc_pnet_remove_by_ibdev(struct smc_ib_device *ibdev) break; } } - write_unlock(&smc_pnettable.lock); + write_unlock(&pnettable->lock); return rc; } /* Append a pnetid to the end of the pnet table if not already on this list. */ -static int smc_pnet_enter(struct smc_pnetentry *new_pnetelem) +static int smc_pnet_enter(struct smc_pnettable *pnettable, + struct smc_user_pnetentry *new_pnetelem) { + u8 pnet_null[SMC_MAX_PNETID_LEN] = {0}; + u8 ndev_pnetid[SMC_MAX_PNETID_LEN]; + struct smc_pnetentry *tmp_pnetelem; struct smc_pnetentry *pnetelem; - int rc = -EEXIST; - - write_lock(&smc_pnettable.lock); - list_for_each_entry(pnetelem, &smc_pnettable.pnetlist, list) { - if (!strncmp(pnetelem->pnet_name, new_pnetelem->pnet_name, - sizeof(new_pnetelem->pnet_name)) || - !strncmp(pnetelem->ndev->name, new_pnetelem->ndev->name, - sizeof(new_pnetelem->ndev->name)) || - smc_pnet_same_ibname(pnetelem, - new_pnetelem->smcibdev->ibdev->name, - new_pnetelem->ib_port)) { - dev_put(pnetelem->ndev); - goto found; + bool new_smcddev = false; + struct net_device *ndev; + bool new_netdev = true; + bool new_ibdev = false; + + if (new_pnetelem->smcibdev) { + struct smc_ib_device *ib_dev = new_pnetelem->smcibdev; + int ib_port = new_pnetelem->ib_port; + + spin_lock(&smc_ib_devices.lock); + if (smc_pnet_match(ib_dev->pnetid[ib_port - 1], pnet_null)) { + memcpy(ib_dev->pnetid[ib_port - 1], + new_pnetelem->pnet_name, SMC_MAX_PNETID_LEN); + ib_dev->pnetid_by_user[ib_port - 1] = true; + new_ibdev = true; } + spin_unlock(&smc_ib_devices.lock); } - list_add_tail(&new_pnetelem->list, &smc_pnettable.pnetlist); - rc = 0; -found: - write_unlock(&smc_pnettable.lock); - return rc; + if (new_pnetelem->smcd_dev) { + struct smcd_dev *smcd_dev = new_pnetelem->smcd_dev; + + spin_lock(&smcd_dev_list.lock); + if (smc_pnet_match(smcd_dev->pnetid, pnet_null)) { + memcpy(smcd_dev->pnetid, new_pnetelem->pnet_name, + SMC_MAX_PNETID_LEN); + smcd_dev->pnetid_by_user = true; + new_smcddev = true; + } + spin_unlock(&smcd_dev_list.lock); + } + + if (!new_pnetelem->ndev) + return (new_ibdev || new_smcddev) ? 0 : -EEXIST; + + /* check if (base) netdev already has a pnetid. If there is one, we do + * not want to add a pnet table entry + */ + ndev = pnet_find_base_ndev(new_pnetelem->ndev); + if (!smc_pnetid_by_dev_port(ndev->dev.parent, ndev->dev_port, + ndev_pnetid)) + return (new_ibdev || new_smcddev) ? 0 : -EEXIST; + + /* add a new netdev entry to the pnet table if there isn't one */ + tmp_pnetelem = kzalloc(sizeof(*pnetelem), GFP_KERNEL); + if (!tmp_pnetelem) + return -ENOMEM; + memcpy(tmp_pnetelem->pnet_name, new_pnetelem->pnet_name, + SMC_MAX_PNETID_LEN); + tmp_pnetelem->ndev = new_pnetelem->ndev; + + write_lock(&pnettable->lock); + list_for_each_entry(pnetelem, &pnettable->pnetlist, list) { + if (pnetelem->ndev == new_pnetelem->ndev) + new_netdev = false; + } + if (new_netdev) { + dev_hold(tmp_pnetelem->ndev); + list_add_tail(&tmp_pnetelem->list, &pnettable->pnetlist); + write_unlock(&pnettable->lock); + } else { + write_unlock(&pnettable->lock); + kfree(tmp_pnetelem); + } + + return (new_netdev || new_ibdev || new_smcddev) ? 0 : -EEXIST; } /* The limit for pnetid is 16 characters. @@ -228,7 +293,9 @@ static struct smc_ib_device *smc_pnet_find_ib(char *ib_name) spin_lock(&smc_ib_devices.lock); list_for_each_entry(ibdev, &smc_ib_devices.list, list) { if (!strncmp(ibdev->ibdev->name, ib_name, - sizeof(ibdev->ibdev->name))) { + sizeof(ibdev->ibdev->name)) || + !strncmp(dev_name(ibdev->ibdev->dev.parent), ib_name, + IB_DEVICE_NAME_MAX - 1)) { goto out; } } @@ -238,10 +305,28 @@ out: return ibdev; } +/* Find an smcd device by a given name. The device might not exist. */ +static struct smcd_dev *smc_pnet_find_smcd(char *smcd_name) +{ + struct smcd_dev *smcd_dev; + + spin_lock(&smcd_dev_list.lock); + list_for_each_entry(smcd_dev, &smcd_dev_list.list, list) { + if (!strncmp(dev_name(&smcd_dev->dev), smcd_name, + IB_DEVICE_NAME_MAX - 1)) + goto out; + } + smcd_dev = NULL; +out: + spin_unlock(&smcd_dev_list.lock); + return smcd_dev; +} + /* Parse the supplied netlink attributes and fill a pnetentry structure. * For ethernet and infiniband device names verify that the devices exist. */ -static int smc_pnet_fill_entry(struct net *net, struct smc_pnetentry *pnetelem, +static int smc_pnet_fill_entry(struct net *net, + struct smc_user_pnetentry *pnetelem, struct nlattr *tb[]) { char *string, *ibname; @@ -258,30 +343,34 @@ static int smc_pnet_fill_entry(struct net *net, struct smc_pnetentry *pnetelem, goto error; rc = -EINVAL; - if (!tb[SMC_PNETID_ETHNAME]) - goto error; - rc = -ENOENT; - string = (char *)nla_data(tb[SMC_PNETID_ETHNAME]); - pnetelem->ndev = dev_get_by_name(net, string); - if (!pnetelem->ndev) - goto error; + if (tb[SMC_PNETID_ETHNAME]) { + string = (char *)nla_data(tb[SMC_PNETID_ETHNAME]); + pnetelem->ndev = dev_get_by_name(net, string); + if (!pnetelem->ndev) + goto error; + } - rc = -EINVAL; - if (!tb[SMC_PNETID_IBNAME]) - goto error; - rc = -ENOENT; - ibname = (char *)nla_data(tb[SMC_PNETID_IBNAME]); - ibname = strim(ibname); - pnetelem->smcibdev = smc_pnet_find_ib(ibname); - if (!pnetelem->smcibdev) - goto error; + /* if this is not the initial namespace, stop here */ + if (net != &init_net) + return 0; rc = -EINVAL; - if (!tb[SMC_PNETID_IBPORT]) - goto error; - pnetelem->ib_port = nla_get_u8(tb[SMC_PNETID_IBPORT]); - if (pnetelem->ib_port < 1 || pnetelem->ib_port > SMC_MAX_PORTS) - goto error; + if (tb[SMC_PNETID_IBNAME]) { + ibname = (char *)nla_data(tb[SMC_PNETID_IBNAME]); + ibname = strim(ibname); + pnetelem->smcibdev = smc_pnet_find_ib(ibname); + pnetelem->smcd_dev = smc_pnet_find_smcd(ibname); + if (!pnetelem->smcibdev && !pnetelem->smcd_dev) + goto error; + if (pnetelem->smcibdev) { + if (!tb[SMC_PNETID_IBPORT]) + goto error; + pnetelem->ib_port = nla_get_u8(tb[SMC_PNETID_IBPORT]); + if (pnetelem->ib_port < 1 || + pnetelem->ib_port > SMC_MAX_PORTS) + goto error; + } + } return 0; @@ -292,79 +381,65 @@ error: } /* Convert an smc_pnetentry to a netlink attribute sequence */ -static int smc_pnet_set_nla(struct sk_buff *msg, struct smc_pnetentry *pnetelem) +static int smc_pnet_set_nla(struct sk_buff *msg, + struct smc_user_pnetentry *pnetelem) { - if (nla_put_string(msg, SMC_PNETID_NAME, pnetelem->pnet_name) || - nla_put_string(msg, SMC_PNETID_ETHNAME, pnetelem->ndev->name) || - nla_put_string(msg, SMC_PNETID_IBNAME, - pnetelem->smcibdev->ibdev->name) || - nla_put_u8(msg, SMC_PNETID_IBPORT, pnetelem->ib_port)) + if (nla_put_string(msg, SMC_PNETID_NAME, pnetelem->pnet_name)) return -1; - return 0; -} - -/* Retrieve one PNETID entry */ -static int smc_pnet_get(struct sk_buff *skb, struct genl_info *info) -{ - struct smc_pnetentry *pnetelem; - struct sk_buff *msg; - void *hdr; - int rc; - - if (!info->attrs[SMC_PNETID_NAME]) - return -EINVAL; - pnetelem = smc_pnet_find_pnetid( - (char *)nla_data(info->attrs[SMC_PNETID_NAME])); - if (!pnetelem) - return -ENOENT; - msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); - if (!msg) - return -ENOMEM; - - hdr = genlmsg_put(msg, info->snd_portid, info->snd_seq, - &smc_pnet_nl_family, 0, SMC_PNETID_GET); - if (!hdr) { - rc = -EMSGSIZE; - goto err_out; + if (pnetelem->ndev) { + if (nla_put_string(msg, SMC_PNETID_ETHNAME, + pnetelem->ndev->name)) + return -1; + } else { + if (nla_put_string(msg, SMC_PNETID_ETHNAME, "n/a")) + return -1; } - - if (smc_pnet_set_nla(msg, pnetelem)) { - rc = -ENOBUFS; - goto err_out; + if (pnetelem->smcibdev) { + if (nla_put_string(msg, SMC_PNETID_IBNAME, + dev_name(pnetelem->smcibdev->ibdev->dev.parent)) || + nla_put_u8(msg, SMC_PNETID_IBPORT, pnetelem->ib_port)) + return -1; + } else if (pnetelem->smcd_dev) { + if (nla_put_string(msg, SMC_PNETID_IBNAME, + dev_name(&pnetelem->smcd_dev->dev)) || + nla_put_u8(msg, SMC_PNETID_IBPORT, 1)) + return -1; + } else { + if (nla_put_string(msg, SMC_PNETID_IBNAME, "n/a") || + nla_put_u8(msg, SMC_PNETID_IBPORT, 0xff)) + return -1; } - genlmsg_end(msg, hdr); - return genlmsg_reply(msg, info); - -err_out: - nlmsg_free(msg); - return rc; + return 0; } static int smc_pnet_add(struct sk_buff *skb, struct genl_info *info) { struct net *net = genl_info_net(info); - struct smc_pnetentry *pnetelem; + struct smc_user_pnetentry pnetelem; + struct smc_pnettable *pnettable; + struct smc_net *sn; int rc; - pnetelem = kzalloc(sizeof(*pnetelem), GFP_KERNEL); - if (!pnetelem) - return -ENOMEM; - rc = smc_pnet_fill_entry(net, pnetelem, info->attrs); + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + + rc = smc_pnet_fill_entry(net, &pnetelem, info->attrs); if (!rc) - rc = smc_pnet_enter(pnetelem); - if (rc) { - kfree(pnetelem); - return rc; - } + rc = smc_pnet_enter(pnettable, &pnetelem); + if (pnetelem.ndev) + dev_put(pnetelem.ndev); return rc; } static int smc_pnet_del(struct sk_buff *skb, struct genl_info *info) { + struct net *net = genl_info_net(info); + if (!info->attrs[SMC_PNETID_NAME]) return -EINVAL; - return smc_pnet_remove_by_pnetid( + return smc_pnet_remove_by_pnetid(net, (char *)nla_data(info->attrs[SMC_PNETID_NAME])); } @@ -376,7 +451,7 @@ static int smc_pnet_dump_start(struct netlink_callback *cb) static int smc_pnet_dumpinfo(struct sk_buff *skb, u32 portid, u32 seq, u32 flags, - struct smc_pnetentry *pnetelem) + struct smc_user_pnetentry *pnetelem) { void *hdr; @@ -392,42 +467,143 @@ static int smc_pnet_dumpinfo(struct sk_buff *skb, return 0; } -static int smc_pnet_dump(struct sk_buff *skb, struct netlink_callback *cb) +static int _smc_pnet_dump(struct net *net, struct sk_buff *skb, u32 portid, + u32 seq, u8 *pnetid, int start_idx) { + struct smc_user_pnetentry tmp_entry; + struct smc_pnettable *pnettable; struct smc_pnetentry *pnetelem; + struct smc_ib_device *ibdev; + struct smcd_dev *smcd_dev; + struct smc_net *sn; int idx = 0; + int ibport; + + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; - read_lock(&smc_pnettable.lock); - list_for_each_entry(pnetelem, &smc_pnettable.pnetlist, list) { - if (idx++ < cb->args[0]) + /* dump netdevices */ + read_lock(&pnettable->lock); + list_for_each_entry(pnetelem, &pnettable->pnetlist, list) { + if (pnetid && !smc_pnet_match(pnetelem->pnet_name, pnetid)) + continue; + if (idx++ < start_idx) continue; - if (smc_pnet_dumpinfo(skb, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, NLM_F_MULTI, - pnetelem)) { + memset(&tmp_entry, 0, sizeof(tmp_entry)); + memcpy(&tmp_entry.pnet_name, pnetelem->pnet_name, + SMC_MAX_PNETID_LEN); + tmp_entry.ndev = pnetelem->ndev; + if (smc_pnet_dumpinfo(skb, portid, seq, NLM_F_MULTI, + &tmp_entry)) { --idx; break; } } + read_unlock(&pnettable->lock); + + /* if this is not the initial namespace, stop here */ + if (net != &init_net) + return idx; + + /* dump ib devices */ + spin_lock(&smc_ib_devices.lock); + list_for_each_entry(ibdev, &smc_ib_devices.list, list) { + for (ibport = 0; ibport < SMC_MAX_PORTS; ibport++) { + if (ibdev->pnetid_by_user[ibport]) { + if (pnetid && + !smc_pnet_match(ibdev->pnetid[ibport], + pnetid)) + continue; + if (idx++ < start_idx) + continue; + memset(&tmp_entry, 0, sizeof(tmp_entry)); + memcpy(&tmp_entry.pnet_name, + ibdev->pnetid[ibport], + SMC_MAX_PNETID_LEN); + tmp_entry.smcibdev = ibdev; + tmp_entry.ib_port = ibport + 1; + if (smc_pnet_dumpinfo(skb, portid, seq, + NLM_F_MULTI, + &tmp_entry)) { + --idx; + break; + } + } + } + } + spin_unlock(&smc_ib_devices.lock); + + /* dump smcd devices */ + spin_lock(&smcd_dev_list.lock); + list_for_each_entry(smcd_dev, &smcd_dev_list.list, list) { + if (smcd_dev->pnetid_by_user) { + if (pnetid && !smc_pnet_match(smcd_dev->pnetid, pnetid)) + continue; + if (idx++ < start_idx) + continue; + memset(&tmp_entry, 0, sizeof(tmp_entry)); + memcpy(&tmp_entry.pnet_name, smcd_dev->pnetid, + SMC_MAX_PNETID_LEN); + tmp_entry.smcd_dev = smcd_dev; + if (smc_pnet_dumpinfo(skb, portid, seq, NLM_F_MULTI, + &tmp_entry)) { + --idx; + break; + } + } + } + spin_unlock(&smcd_dev_list.lock); + + return idx; +} + +static int smc_pnet_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + int idx; + + idx = _smc_pnet_dump(net, skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NULL, cb->args[0]); + cb->args[0] = idx; - read_unlock(&smc_pnettable.lock); return skb->len; } +/* Retrieve one PNETID entry */ +static int smc_pnet_get(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct sk_buff *msg; + void *hdr; + + if (!info->attrs[SMC_PNETID_NAME]) + return -EINVAL; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + _smc_pnet_dump(net, msg, info->snd_portid, info->snd_seq, + nla_data(info->attrs[SMC_PNETID_NAME]), 0); + + /* finish multi part message and send it */ + hdr = nlmsg_put(msg, info->snd_portid, info->snd_seq, NLMSG_DONE, 0, + NLM_F_MULTI); + if (!hdr) { + nlmsg_free(msg); + return -EMSGSIZE; + } + return genlmsg_reply(msg, info); +} + /* Remove and delete all pnetids from pnet table. */ static int smc_pnet_flush(struct sk_buff *skb, struct genl_info *info) { - struct smc_pnetentry *pnetelem, *tmp_pe; + struct net *net = genl_info_net(info); - write_lock(&smc_pnettable.lock); - list_for_each_entry_safe(pnetelem, tmp_pe, &smc_pnettable.pnetlist, - list) { - list_del(&pnetelem->list); - dev_put(pnetelem->ndev); - kfree(pnetelem); - } - write_unlock(&smc_pnettable.lock); - return 0; + return smc_pnet_remove_by_pnetid(net, NULL); } /* SMC_PNETID generic netlink operation definition */ @@ -491,6 +667,18 @@ static struct notifier_block smc_netdev_notifier = { .notifier_call = smc_pnet_netdev_event }; +/* init network namespace */ +int smc_pnet_net_init(struct net *net) +{ + struct smc_net *sn = net_generic(net, smc_net_id); + struct smc_pnettable *pnettable = &sn->pnettable; + + INIT_LIST_HEAD(&pnettable->pnetlist); + rwlock_init(&pnettable->lock); + + return 0; +} + int __init smc_pnet_init(void) { int rc; @@ -504,9 +692,15 @@ int __init smc_pnet_init(void) return rc; } +/* exit network namespace */ +void smc_pnet_net_exit(struct net *net) +{ + /* flush pnet table */ + smc_pnet_remove_by_pnetid(net, NULL); +} + void smc_pnet_exit(void) { - smc_pnet_flush(NULL, NULL); unregister_netdevice_notifier(&smc_netdev_notifier); genl_unregister_family(&smc_pnet_nl_family); } @@ -534,9 +728,73 @@ static struct net_device *pnet_find_base_ndev(struct net_device *ndev) return ndev; } +static int smc_pnet_find_ndev_pnetid_by_table(struct net_device *ndev, + u8 *pnetid) +{ + struct smc_pnettable *pnettable; + struct net *net = dev_net(ndev); + struct smc_pnetentry *pnetelem; + struct smc_net *sn; + int rc = -ENOENT; + + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + + read_lock(&pnettable->lock); + list_for_each_entry(pnetelem, &pnettable->pnetlist, list) { + if (ndev == pnetelem->ndev) { + /* get pnetid of netdev device */ + memcpy(pnetid, pnetelem->pnet_name, SMC_MAX_PNETID_LEN); + rc = 0; + break; + } + } + read_unlock(&pnettable->lock); + return rc; +} + +/* if handshake network device belongs to a roce device, return its + * IB device and port + */ +static void smc_pnet_find_rdma_dev(struct net_device *netdev, + struct smc_ib_device **smcibdev, + u8 *ibport, unsigned short vlan_id, u8 gid[]) +{ + struct smc_ib_device *ibdev; + + spin_lock(&smc_ib_devices.lock); + list_for_each_entry(ibdev, &smc_ib_devices.list, list) { + struct net_device *ndev; + int i; + + for (i = 1; i <= SMC_MAX_PORTS; i++) { + if (!rdma_is_port_valid(ibdev->ibdev, i)) + continue; + if (!ibdev->ibdev->ops.get_netdev) + continue; + ndev = ibdev->ibdev->ops.get_netdev(ibdev->ibdev, i); + if (!ndev) + continue; + dev_put(ndev); + if (netdev == ndev && + smc_ib_port_active(ibdev, i) && + !smc_ib_determine_gid(ibdev, i, vlan_id, gid, + NULL)) { + *smcibdev = ibdev; + *ibport = i; + break; + } + } + } + spin_unlock(&smc_ib_devices.lock); +} + /* Determine the corresponding IB device port based on the hardware PNETID. * Searching stops at the first matching active IB device port with vlan_id * configured. + * If nothing found, check pnetid table. + * If nothing found, try to use handshake device */ static void smc_pnet_find_roce_by_pnetid(struct net_device *ndev, struct smc_ib_device **smcibdev, @@ -549,16 +807,18 @@ static void smc_pnet_find_roce_by_pnetid(struct net_device *ndev, ndev = pnet_find_base_ndev(ndev); if (smc_pnetid_by_dev_port(ndev->dev.parent, ndev->dev_port, - ndev_pnetid)) + ndev_pnetid) && + smc_pnet_find_ndev_pnetid_by_table(ndev, ndev_pnetid)) { + smc_pnet_find_rdma_dev(ndev, smcibdev, ibport, vlan_id, gid); return; /* pnetid could not be determined */ + } spin_lock(&smc_ib_devices.lock); list_for_each_entry(ibdev, &smc_ib_devices.list, list) { for (i = 1; i <= SMC_MAX_PORTS; i++) { if (!rdma_is_port_valid(ibdev->ibdev, i)) continue; - if (!memcmp(ibdev->pnetid[i - 1], ndev_pnetid, - SMC_MAX_PNETID_LEN) && + if (smc_pnet_match(ibdev->pnetid[i - 1], ndev_pnetid) && smc_ib_port_active(ibdev, i) && !smc_ib_determine_gid(ibdev, i, vlan_id, gid, NULL)) { @@ -580,12 +840,13 @@ static void smc_pnet_find_ism_by_pnetid(struct net_device *ndev, ndev = pnet_find_base_ndev(ndev); if (smc_pnetid_by_dev_port(ndev->dev.parent, ndev->dev_port, - ndev_pnetid)) + ndev_pnetid) && + smc_pnet_find_ndev_pnetid_by_table(ndev, ndev_pnetid)) return; /* pnetid could not be determined */ spin_lock(&smcd_dev_list.lock); list_for_each_entry(ismdev, &smcd_dev_list.list, list) { - if (!memcmp(ismdev->pnetid, ndev_pnetid, SMC_MAX_PNETID_LEN)) { + if (smc_pnet_match(ismdev->pnetid, ndev_pnetid)) { *smcismdev = ismdev; break; } @@ -593,31 +854,6 @@ static void smc_pnet_find_ism_by_pnetid(struct net_device *ndev, spin_unlock(&smcd_dev_list.lock); } -/* Lookup of coupled ib_device via SMC pnet table */ -static void smc_pnet_find_roce_by_table(struct net_device *netdev, - struct smc_ib_device **smcibdev, - u8 *ibport, unsigned short vlan_id, - u8 gid[]) -{ - struct smc_pnetentry *pnetelem; - - read_lock(&smc_pnettable.lock); - list_for_each_entry(pnetelem, &smc_pnettable.pnetlist, list) { - if (netdev == pnetelem->ndev) { - if (smc_ib_port_active(pnetelem->smcibdev, - pnetelem->ib_port) && - !smc_ib_determine_gid(pnetelem->smcibdev, - pnetelem->ib_port, vlan_id, - gid, NULL)) { - *smcibdev = pnetelem->smcibdev; - *ibport = pnetelem->ib_port; - } - break; - } - } - read_unlock(&smc_pnettable.lock); -} - /* PNET table analysis for a given sock: * determine ib_device and port belonging to used internal TCP socket * ethernet interface. @@ -636,13 +872,7 @@ void smc_pnet_find_roce_resource(struct sock *sk, if (!dst->dev) goto out_rel; - /* if possible, lookup via hardware-defined pnetid */ smc_pnet_find_roce_by_pnetid(dst->dev, smcibdev, ibport, vlan_id, gid); - if (*smcibdev) - goto out_rel; - - /* lookup via SMC PNET table */ - smc_pnet_find_roce_by_table(dst->dev, smcibdev, ibport, vlan_id, gid); out_rel: dst_release(dst); @@ -660,7 +890,6 @@ void smc_pnet_find_ism_resource(struct sock *sk, struct smcd_dev **smcismdev) if (!dst->dev) goto out_rel; - /* if possible, lookup via hardware-defined pnetid */ smc_pnet_find_ism_by_pnetid(dst->dev, smcismdev); out_rel: diff --git a/net/smc/smc_pnet.h b/net/smc/smc_pnet.h index 8ff777636e32..5eac42fb45d0 100644 --- a/net/smc/smc_pnet.h +++ b/net/smc/smc_pnet.h @@ -19,6 +19,16 @@ struct smc_ib_device; struct smcd_dev; +/** + * struct smc_pnettable - SMC PNET table anchor + * @lock: Lock for list action + * @pnetlist: List of PNETIDs + */ +struct smc_pnettable { + rwlock_t lock; + struct list_head pnetlist; +}; + static inline int smc_pnetid_by_dev_port(struct device *dev, unsigned short port, u8 *pnetid) { @@ -30,8 +40,9 @@ static inline int smc_pnetid_by_dev_port(struct device *dev, } int smc_pnet_init(void) __init; +int smc_pnet_net_init(struct net *net); void smc_pnet_exit(void); -int smc_pnet_remove_by_ibdev(struct smc_ib_device *ibdev); +void smc_pnet_net_exit(struct net *net); void smc_pnet_find_roce_resource(struct sock *sk, struct smc_ib_device **smcibdev, u8 *ibport, unsigned short vlan_id, u8 gid[]); diff --git a/net/smc/smc_rx.c b/net/smc/smc_rx.c index bbcf0fe4ae10..413a6abf227e 100644 --- a/net/smc/smc_rx.c +++ b/net/smc/smc_rx.c @@ -136,7 +136,6 @@ static int smc_rx_pipe_buf_nosteal(struct pipe_inode_info *pipe, } static const struct pipe_buf_operations smc_pipe_ops = { - .can_merge = 0, .confirm = generic_pipe_buf_confirm, .release = smc_rx_pipe_buf_release, .steal = smc_rx_pipe_buf_nosteal, diff --git a/net/smc/smc_tx.c b/net/smc/smc_tx.c index d8366ed51757..f0de323d15d6 100644 --- a/net/smc/smc_tx.c +++ b/net/smc/smc_tx.c @@ -24,10 +24,11 @@ #include "smc.h" #include "smc_wr.h" #include "smc_cdc.h" +#include "smc_close.h" #include "smc_ism.h" #include "smc_tx.h" -#define SMC_TX_WORK_DELAY HZ +#define SMC_TX_WORK_DELAY 0 #define SMC_TX_CORK_DELAY (HZ >> 2) /* 250 ms */ /***************************** sndbuf producer *******************************/ @@ -165,12 +166,11 @@ int smc_tx_sendmsg(struct smc_sock *smc, struct msghdr *msg, size_t len) conn->local_tx_ctrl.prod_flags.urg_data_pending = 1; if (!atomic_read(&conn->sndbuf_space) || conn->urg_tx_pend) { + if (send_done) + return send_done; rc = smc_tx_wait(smc, msg->msg_flags); - if (rc) { - if (send_done) - return send_done; + if (rc) goto out_err; - } continue; } @@ -267,27 +267,23 @@ int smcd_tx_ism_write(struct smc_connection *conn, void *data, size_t len, /* sndbuf consumer: actual data transfer of one target chunk with RDMA write */ static int smc_tx_rdma_write(struct smc_connection *conn, int peer_rmbe_offset, - int num_sges, struct ib_sge sges[]) + int num_sges, struct ib_rdma_wr *rdma_wr) { struct smc_link_group *lgr = conn->lgr; - struct ib_rdma_wr rdma_wr; struct smc_link *link; int rc; - memset(&rdma_wr, 0, sizeof(rdma_wr)); link = &lgr->lnk[SMC_SINGLE_LINK]; - rdma_wr.wr.wr_id = smc_wr_tx_get_next_wr_id(link); - rdma_wr.wr.sg_list = sges; - rdma_wr.wr.num_sge = num_sges; - rdma_wr.wr.opcode = IB_WR_RDMA_WRITE; - rdma_wr.remote_addr = + rdma_wr->wr.wr_id = smc_wr_tx_get_next_wr_id(link); + rdma_wr->wr.num_sge = num_sges; + rdma_wr->remote_addr = lgr->rtokens[conn->rtoken_idx][SMC_SINGLE_LINK].dma_addr + /* RMBE within RMB */ conn->tx_off + /* offset within RMBE */ peer_rmbe_offset; - rdma_wr.rkey = lgr->rtokens[conn->rtoken_idx][SMC_SINGLE_LINK].rkey; - rc = ib_post_send(link->roce_qp, &rdma_wr.wr, NULL); + rdma_wr->rkey = lgr->rtokens[conn->rtoken_idx][SMC_SINGLE_LINK].rkey; + rc = ib_post_send(link->roce_qp, &rdma_wr->wr, NULL); if (rc) { conn->local_tx_ctrl.conn_state_flags.peer_conn_abort = 1; smc_lgr_terminate(lgr); @@ -314,24 +310,25 @@ static inline void smc_tx_advance_cursors(struct smc_connection *conn, /* SMC-R helper for smc_tx_rdma_writes() */ static int smcr_tx_rdma_writes(struct smc_connection *conn, size_t len, size_t src_off, size_t src_len, - size_t dst_off, size_t dst_len) + size_t dst_off, size_t dst_len, + struct smc_rdma_wr *wr_rdma_buf) { dma_addr_t dma_addr = sg_dma_address(conn->sndbuf_desc->sgt[SMC_SINGLE_LINK].sgl); - struct smc_link *link = &conn->lgr->lnk[SMC_SINGLE_LINK]; int src_len_sum = src_len, dst_len_sum = dst_len; - struct ib_sge sges[SMC_IB_MAX_SEND_SGE]; int sent_count = src_off; int srcchunk, dstchunk; int num_sges; int rc; for (dstchunk = 0; dstchunk < 2; dstchunk++) { + struct ib_sge *sge = + wr_rdma_buf->wr_tx_rdma[dstchunk].wr.sg_list; + num_sges = 0; for (srcchunk = 0; srcchunk < 2; srcchunk++) { - sges[srcchunk].addr = dma_addr + src_off; - sges[srcchunk].length = src_len; - sges[srcchunk].lkey = link->roce_pd->local_dma_lkey; + sge[srcchunk].addr = dma_addr + src_off; + sge[srcchunk].length = src_len; num_sges++; src_off += src_len; @@ -344,7 +341,8 @@ static int smcr_tx_rdma_writes(struct smc_connection *conn, size_t len, src_len = dst_len - src_len; /* remainder */ src_len_sum += src_len; } - rc = smc_tx_rdma_write(conn, dst_off, num_sges, sges); + rc = smc_tx_rdma_write(conn, dst_off, num_sges, + &wr_rdma_buf->wr_tx_rdma[dstchunk]); if (rc) return rc; if (dst_len_sum == len) @@ -403,7 +401,8 @@ static int smcd_tx_rdma_writes(struct smc_connection *conn, size_t len, /* sndbuf consumer: prepare all necessary (src&dst) chunks of data transmit; * usable snd_wnd as max transmit */ -static int smc_tx_rdma_writes(struct smc_connection *conn) +static int smc_tx_rdma_writes(struct smc_connection *conn, + struct smc_rdma_wr *wr_rdma_buf) { size_t len, src_len, dst_off, dst_len; /* current chunk values */ union smc_host_cursor sent, prep, prod, cons; @@ -464,7 +463,7 @@ static int smc_tx_rdma_writes(struct smc_connection *conn) dst_off, dst_len); else rc = smcr_tx_rdma_writes(conn, len, sent.count, src_len, - dst_off, dst_len); + dst_off, dst_len, wr_rdma_buf); if (rc) return rc; @@ -484,32 +483,31 @@ static int smc_tx_rdma_writes(struct smc_connection *conn) */ static int smcr_tx_sndbuf_nonempty(struct smc_connection *conn) { - struct smc_cdc_producer_flags *pflags; + struct smc_cdc_producer_flags *pflags = &conn->local_tx_ctrl.prod_flags; + struct smc_rdma_wr *wr_rdma_buf; struct smc_cdc_tx_pend *pend; struct smc_wr_buf *wr_buf; int rc; - spin_lock_bh(&conn->send_lock); - rc = smc_cdc_get_free_slot(conn, &wr_buf, &pend); + rc = smc_cdc_get_free_slot(conn, &wr_buf, &wr_rdma_buf, &pend); if (rc < 0) { if (rc == -EBUSY) { struct smc_sock *smc = container_of(conn, struct smc_sock, conn); - if (smc->sk.sk_err == ECONNABORTED) { - rc = sock_error(&smc->sk); - goto out_unlock; - } + if (smc->sk.sk_err == ECONNABORTED) + return sock_error(&smc->sk); rc = 0; if (conn->alert_token_local) /* connection healthy */ mod_delayed_work(system_wq, &conn->tx_work, SMC_TX_WORK_DELAY); } - goto out_unlock; + return rc; } - if (!conn->local_tx_ctrl.prod_flags.urg_data_present) { - rc = smc_tx_rdma_writes(conn); + spin_lock_bh(&conn->send_lock); + if (!pflags->urg_data_present) { + rc = smc_tx_rdma_writes(conn, wr_rdma_buf); if (rc) { smc_wr_tx_put_slot(&conn->lgr->lnk[SMC_SINGLE_LINK], (struct smc_wr_tx_pend_priv *)pend); @@ -518,7 +516,6 @@ static int smcr_tx_sndbuf_nonempty(struct smc_connection *conn) } rc = smc_cdc_msg_send(conn, wr_buf, pend); - pflags = &conn->local_tx_ctrl.prod_flags; if (!rc && pflags->urg_data_present) { pflags->urg_data_pending = 0; pflags->urg_data_present = 0; @@ -536,7 +533,7 @@ static int smcd_tx_sndbuf_nonempty(struct smc_connection *conn) spin_lock_bh(&conn->send_lock); if (!pflags->urg_data_present) - rc = smc_tx_rdma_writes(conn); + rc = smc_tx_rdma_writes(conn, NULL); if (!rc) rc = smcd_cdc_msg_send(conn); @@ -557,6 +554,12 @@ int smc_tx_sndbuf_nonempty(struct smc_connection *conn) else rc = smcr_tx_sndbuf_nonempty(conn); + if (!rc) { + /* trigger socket release if connection is closing */ + struct smc_sock *smc = container_of(conn, struct smc_sock, + conn); + smc_close_wake_tx_prepared(smc); + } return rc; } @@ -598,7 +601,8 @@ void smc_tx_consumer_update(struct smc_connection *conn, bool force) if (to_confirm > conn->rmbe_update_limit) { smc_curs_copy(&prod, &conn->local_rx_ctrl.prod, conn); sender_free = conn->rmb_desc->len - - smc_curs_diff(conn->rmb_desc->len, &prod, &cfed); + smc_curs_diff_large(conn->rmb_desc->len, + &cfed, &prod); } if (conn->local_rx_ctrl.prod_flags.cons_curs_upd_req || @@ -612,9 +616,6 @@ void smc_tx_consumer_update(struct smc_connection *conn, bool force) SMC_TX_WORK_DELAY); return; } - smc_curs_copy(&conn->rx_curs_confirmed, - &conn->local_tx_ctrl.cons, conn); - conn->local_rx_ctrl.prod_flags.cons_curs_upd_req = 0; } if (conn->local_rx_ctrl.prod_flags.write_blocked && !atomic_read(&conn->bytes_to_rcv)) diff --git a/net/smc/smc_wr.c b/net/smc/smc_wr.c index c2694750a6a8..253aa75dc2b6 100644 --- a/net/smc/smc_wr.c +++ b/net/smc/smc_wr.c @@ -160,6 +160,7 @@ static inline int smc_wr_tx_get_free_slot_index(struct smc_link *link, u32 *idx) * @link: Pointer to smc_link used to later send the message. * @handler: Send completion handler function pointer. * @wr_buf: Out value returns pointer to message buffer. + * @wr_rdma_buf: Out value returns pointer to rdma work request. * @wr_pend_priv: Out value returns pointer serving as handler context. * * Return: 0 on success, or -errno on error. @@ -167,6 +168,7 @@ static inline int smc_wr_tx_get_free_slot_index(struct smc_link *link, u32 *idx) int smc_wr_tx_get_free_slot(struct smc_link *link, smc_wr_tx_handler handler, struct smc_wr_buf **wr_buf, + struct smc_rdma_wr **wr_rdma_buf, struct smc_wr_tx_pend_priv **wr_pend_priv) { struct smc_wr_tx_pend *wr_pend; @@ -204,6 +206,8 @@ int smc_wr_tx_get_free_slot(struct smc_link *link, wr_ib = &link->wr_tx_ibs[idx]; wr_ib->wr_id = wr_id; *wr_buf = &link->wr_tx_bufs[idx]; + if (wr_rdma_buf) + *wr_rdma_buf = &link->wr_tx_rdmas[idx]; *wr_pend_priv = &wr_pend->priv; return 0; } @@ -218,10 +222,10 @@ int smc_wr_tx_put_slot(struct smc_link *link, u32 idx = pend->idx; /* clear the full struct smc_wr_tx_pend including .priv */ - memset(&link->wr_tx_pends[pend->idx], 0, - sizeof(link->wr_tx_pends[pend->idx])); - memset(&link->wr_tx_bufs[pend->idx], 0, - sizeof(link->wr_tx_bufs[pend->idx])); + memset(&link->wr_tx_pends[idx], 0, + sizeof(link->wr_tx_pends[idx])); + memset(&link->wr_tx_bufs[idx], 0, + sizeof(link->wr_tx_bufs[idx])); test_and_clear_bit(idx, link->wr_tx_mask); return 1; } @@ -465,12 +469,26 @@ static void smc_wr_init_sge(struct smc_link *lnk) lnk->wr_tx_dma_addr + i * SMC_WR_BUF_SIZE; lnk->wr_tx_sges[i].length = SMC_WR_TX_SIZE; lnk->wr_tx_sges[i].lkey = lnk->roce_pd->local_dma_lkey; + lnk->wr_tx_rdma_sges[i].tx_rdma_sge[0].wr_tx_rdma_sge[0].lkey = + lnk->roce_pd->local_dma_lkey; + lnk->wr_tx_rdma_sges[i].tx_rdma_sge[0].wr_tx_rdma_sge[1].lkey = + lnk->roce_pd->local_dma_lkey; + lnk->wr_tx_rdma_sges[i].tx_rdma_sge[1].wr_tx_rdma_sge[0].lkey = + lnk->roce_pd->local_dma_lkey; + lnk->wr_tx_rdma_sges[i].tx_rdma_sge[1].wr_tx_rdma_sge[1].lkey = + lnk->roce_pd->local_dma_lkey; lnk->wr_tx_ibs[i].next = NULL; lnk->wr_tx_ibs[i].sg_list = &lnk->wr_tx_sges[i]; lnk->wr_tx_ibs[i].num_sge = 1; lnk->wr_tx_ibs[i].opcode = IB_WR_SEND; lnk->wr_tx_ibs[i].send_flags = IB_SEND_SIGNALED | IB_SEND_SOLICITED; + lnk->wr_tx_rdmas[i].wr_tx_rdma[0].wr.opcode = IB_WR_RDMA_WRITE; + lnk->wr_tx_rdmas[i].wr_tx_rdma[1].wr.opcode = IB_WR_RDMA_WRITE; + lnk->wr_tx_rdmas[i].wr_tx_rdma[0].wr.sg_list = + lnk->wr_tx_rdma_sges[i].tx_rdma_sge[0].wr_tx_rdma_sge; + lnk->wr_tx_rdmas[i].wr_tx_rdma[1].wr.sg_list = + lnk->wr_tx_rdma_sges[i].tx_rdma_sge[1].wr_tx_rdma_sge; } for (i = 0; i < lnk->wr_rx_cnt; i++) { lnk->wr_rx_sges[i].addr = @@ -521,8 +539,12 @@ void smc_wr_free_link_mem(struct smc_link *lnk) lnk->wr_tx_mask = NULL; kfree(lnk->wr_tx_sges); lnk->wr_tx_sges = NULL; + kfree(lnk->wr_tx_rdma_sges); + lnk->wr_tx_rdma_sges = NULL; kfree(lnk->wr_rx_sges); lnk->wr_rx_sges = NULL; + kfree(lnk->wr_tx_rdmas); + lnk->wr_tx_rdmas = NULL; kfree(lnk->wr_rx_ibs); lnk->wr_rx_ibs = NULL; kfree(lnk->wr_tx_ibs); @@ -552,10 +574,20 @@ int smc_wr_alloc_link_mem(struct smc_link *link) GFP_KERNEL); if (!link->wr_rx_ibs) goto no_mem_wr_tx_ibs; + link->wr_tx_rdmas = kcalloc(SMC_WR_BUF_CNT, + sizeof(link->wr_tx_rdmas[0]), + GFP_KERNEL); + if (!link->wr_tx_rdmas) + goto no_mem_wr_rx_ibs; + link->wr_tx_rdma_sges = kcalloc(SMC_WR_BUF_CNT, + sizeof(link->wr_tx_rdma_sges[0]), + GFP_KERNEL); + if (!link->wr_tx_rdma_sges) + goto no_mem_wr_tx_rdmas; link->wr_tx_sges = kcalloc(SMC_WR_BUF_CNT, sizeof(link->wr_tx_sges[0]), GFP_KERNEL); if (!link->wr_tx_sges) - goto no_mem_wr_rx_ibs; + goto no_mem_wr_tx_rdma_sges; link->wr_rx_sges = kcalloc(SMC_WR_BUF_CNT * 3, sizeof(link->wr_rx_sges[0]), GFP_KERNEL); @@ -579,6 +611,10 @@ no_mem_wr_rx_sges: kfree(link->wr_rx_sges); no_mem_wr_tx_sges: kfree(link->wr_tx_sges); +no_mem_wr_tx_rdma_sges: + kfree(link->wr_tx_rdma_sges); +no_mem_wr_tx_rdmas: + kfree(link->wr_tx_rdmas); no_mem_wr_rx_ibs: kfree(link->wr_rx_ibs); no_mem_wr_tx_ibs: diff --git a/net/smc/smc_wr.h b/net/smc/smc_wr.h index 1d85bb14fd6f..09bf32fd3959 100644 --- a/net/smc/smc_wr.h +++ b/net/smc/smc_wr.h @@ -85,6 +85,7 @@ void smc_wr_add_dev(struct smc_ib_device *smcibdev); int smc_wr_tx_get_free_slot(struct smc_link *link, smc_wr_tx_handler handler, struct smc_wr_buf **wr_buf, + struct smc_rdma_wr **wrs, struct smc_wr_tx_pend_priv **wr_pend_priv); int smc_wr_tx_put_slot(struct smc_link *link, struct smc_wr_tx_pend_priv *wr_pend_priv); diff --git a/net/socket.c b/net/socket.c index e89884e2197b..3c176a12fe48 100644 --- a/net/socket.c +++ b/net/socket.c @@ -577,6 +577,7 @@ static void __sock_release(struct socket *sock, struct inode *inode) if (inode) inode_lock(inode); sock->ops->release(sock); + sock->sk = NULL; if (inode) inode_unlock(inode); sock->ops = NULL; @@ -669,7 +670,7 @@ static bool skb_is_err_queue(const struct sk_buff *skb) * before the software timestamp is received, a hardware TX timestamp may be * returned only if there is no software TX timestamp. Ignore false software * timestamps, which may be made in the __sock_recv_timestamp() call when the - * option SO_TIMESTAMP(NS) is enabled on the socket, even when the skb has a + * option SO_TIMESTAMP_OLD(NS) is enabled on the socket, even when the skb has a * hardware timestamp. */ static bool skb_is_swtx_tstamp(const struct sk_buff *skb, int false_tstamp) @@ -705,7 +706,9 @@ void __sock_recv_timestamp(struct msghdr *msg, struct sock *sk, struct sk_buff *skb) { int need_software_tstamp = sock_flag(sk, SOCK_RCVTSTAMP); - struct scm_timestamping tss; + int new_tstamp = sock_flag(sk, SOCK_TSTAMP_NEW); + struct scm_timestamping_internal tss; + int empty = 1, false_tstamp = 0; struct skb_shared_hwtstamps *shhwtstamps = skb_hwtstamps(skb); @@ -719,34 +722,54 @@ void __sock_recv_timestamp(struct msghdr *msg, struct sock *sk, if (need_software_tstamp) { if (!sock_flag(sk, SOCK_RCVTSTAMPNS)) { - struct timeval tv; - skb_get_timestamp(skb, &tv); - put_cmsg(msg, SOL_SOCKET, SCM_TIMESTAMP, - sizeof(tv), &tv); + if (new_tstamp) { + struct __kernel_sock_timeval tv; + + skb_get_new_timestamp(skb, &tv); + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMP_NEW, + sizeof(tv), &tv); + } else { + struct __kernel_old_timeval tv; + + skb_get_timestamp(skb, &tv); + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMP_OLD, + sizeof(tv), &tv); + } } else { - struct timespec ts; - skb_get_timestampns(skb, &ts); - put_cmsg(msg, SOL_SOCKET, SCM_TIMESTAMPNS, - sizeof(ts), &ts); + if (new_tstamp) { + struct __kernel_timespec ts; + + skb_get_new_timestampns(skb, &ts); + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMPNS_NEW, + sizeof(ts), &ts); + } else { + struct timespec ts; + + skb_get_timestampns(skb, &ts); + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMPNS_OLD, + sizeof(ts), &ts); + } } } memset(&tss, 0, sizeof(tss)); if ((sk->sk_tsflags & SOF_TIMESTAMPING_SOFTWARE) && - ktime_to_timespec_cond(skb->tstamp, tss.ts + 0)) + ktime_to_timespec64_cond(skb->tstamp, tss.ts + 0)) empty = 0; if (shhwtstamps && (sk->sk_tsflags & SOF_TIMESTAMPING_RAW_HARDWARE) && !skb_is_swtx_tstamp(skb, false_tstamp) && - ktime_to_timespec_cond(shhwtstamps->hwtstamp, tss.ts + 2)) { + ktime_to_timespec64_cond(shhwtstamps->hwtstamp, tss.ts + 2)) { empty = 0; if ((sk->sk_tsflags & SOF_TIMESTAMPING_OPT_PKTINFO) && !skb_is_err_queue(skb)) put_ts_pktinfo(msg, skb); } if (!empty) { - put_cmsg(msg, SOL_SOCKET, - SCM_TIMESTAMPING, sizeof(tss), &tss); + if (sock_flag(sk, SOCK_TSTAMP_NEW)) + put_cmsg_scm_timestamping64(msg, &tss); + else + put_cmsg_scm_timestamping(msg, &tss); if (skb_is_err_queue(skb) && skb->len && SKB_EXT_ERR(skb)->opt_stats) @@ -941,8 +964,7 @@ void dlci_ioctl_set(int (*hook) (unsigned int, void __user *)) EXPORT_SYMBOL(dlci_ioctl_set); static long sock_do_ioctl(struct net *net, struct socket *sock, - unsigned int cmd, unsigned long arg, - unsigned int ifreq_size) + unsigned int cmd, unsigned long arg) { int err; void __user *argp = (void __user *)arg; @@ -968,11 +990,11 @@ static long sock_do_ioctl(struct net *net, struct socket *sock, } else { struct ifreq ifr; bool need_copyout; - if (copy_from_user(&ifr, argp, ifreq_size)) + if (copy_from_user(&ifr, argp, sizeof(struct ifreq))) return -EFAULT; err = dev_ioctl(net, cmd, &ifr, &need_copyout); if (!err && need_copyout) - if (copy_to_user(argp, &ifr, ifreq_size)) + if (copy_to_user(argp, &ifr, sizeof(struct ifreq))) return -EFAULT; } return err; @@ -1071,8 +1093,7 @@ static long sock_ioctl(struct file *file, unsigned cmd, unsigned long arg) err = open_related_ns(&net->ns, get_net_ns); break; default: - err = sock_do_ioctl(net, sock, cmd, arg, - sizeof(struct ifreq)); + err = sock_do_ioctl(net, sock, cmd, arg); break; } return err; @@ -2780,8 +2801,7 @@ static int do_siocgstamp(struct net *net, struct socket *sock, int err; set_fs(KERNEL_DS); - err = sock_do_ioctl(net, sock, cmd, (unsigned long)&ktv, - sizeof(struct compat_ifreq)); + err = sock_do_ioctl(net, sock, cmd, (unsigned long)&ktv); set_fs(old_fs); if (!err) err = compat_put_timeval(&ktv, up); @@ -2797,8 +2817,7 @@ static int do_siocgstampns(struct net *net, struct socket *sock, int err; set_fs(KERNEL_DS); - err = sock_do_ioctl(net, sock, cmd, (unsigned long)&kts, - sizeof(struct compat_ifreq)); + err = sock_do_ioctl(net, sock, cmd, (unsigned long)&kts); set_fs(old_fs); if (!err) err = compat_put_timespec(&kts, up); @@ -2994,6 +3013,54 @@ static int compat_ifr_data_ioctl(struct net *net, unsigned int cmd, return dev_ioctl(net, cmd, &ifreq, NULL); } +static int compat_ifreq_ioctl(struct net *net, struct socket *sock, + unsigned int cmd, + struct compat_ifreq __user *uifr32) +{ + struct ifreq __user *uifr; + int err; + + /* Handle the fact that while struct ifreq has the same *layout* on + * 32/64 for everything but ifreq::ifru_ifmap and ifreq::ifru_data, + * which are handled elsewhere, it still has different *size* due to + * ifreq::ifru_ifmap (which is 16 bytes on 32 bit, 24 bytes on 64-bit, + * resulting in struct ifreq being 32 and 40 bytes respectively). + * As a result, if the struct happens to be at the end of a page and + * the next page isn't readable/writable, we get a fault. To prevent + * that, copy back and forth to the full size. + */ + + uifr = compat_alloc_user_space(sizeof(*uifr)); + if (copy_in_user(uifr, uifr32, sizeof(*uifr32))) + return -EFAULT; + + err = sock_do_ioctl(net, sock, cmd, (unsigned long)uifr); + + if (!err) { + switch (cmd) { + case SIOCGIFFLAGS: + case SIOCGIFMETRIC: + case SIOCGIFMTU: + case SIOCGIFMEM: + case SIOCGIFHWADDR: + case SIOCGIFINDEX: + case SIOCGIFADDR: + case SIOCGIFBRDADDR: + case SIOCGIFDSTADDR: + case SIOCGIFNETMASK: + case SIOCGIFPFLAGS: + case SIOCGIFTXQLEN: + case SIOCGMIIPHY: + case SIOCGMIIREG: + case SIOCGIFNAME: + if (copy_in_user(uifr32, uifr, sizeof(*uifr32))) + err = -EFAULT; + break; + } + } + return err; +} + static int compat_sioc_ifmap(struct net *net, unsigned int cmd, struct compat_ifreq __user *uifr32) { @@ -3109,8 +3176,7 @@ static int routing_ioctl(struct net *net, struct socket *sock, } set_fs(KERNEL_DS); - ret = sock_do_ioctl(net, sock, cmd, (unsigned long) r, - sizeof(struct compat_ifreq)); + ret = sock_do_ioctl(net, sock, cmd, (unsigned long) r); set_fs(old_fs); out: @@ -3210,21 +3276,22 @@ static int compat_sock_ioctl_trans(struct file *file, struct socket *sock, case SIOCSIFTXQLEN: case SIOCBRADDIF: case SIOCBRDELIF: + case SIOCGIFNAME: case SIOCSIFNAME: case SIOCGMIIPHY: case SIOCGMIIREG: case SIOCSMIIREG: - case SIOCSARP: - case SIOCGARP: - case SIOCDARP: - case SIOCATMARK: case SIOCBONDENSLAVE: case SIOCBONDRELEASE: case SIOCBONDSETHWADDR: case SIOCBONDCHANGEACTIVE: - case SIOCGIFNAME: - return sock_do_ioctl(net, sock, cmd, arg, - sizeof(struct compat_ifreq)); + return compat_ifreq_ioctl(net, sock, cmd, argp); + + case SIOCSARP: + case SIOCGARP: + case SIOCDARP: + case SIOCATMARK: + return sock_do_ioctl(net, sock, cmd, arg); } return -ENOIOCTLCMD; diff --git a/net/sunrpc/Kconfig b/net/sunrpc/Kconfig index ac09ca803296..83f5617bae07 100644 --- a/net/sunrpc/Kconfig +++ b/net/sunrpc/Kconfig @@ -34,6 +34,22 @@ config RPCSEC_GSS_KRB5 If unsure, say Y. +config CONFIG_SUNRPC_DISABLE_INSECURE_ENCTYPES + bool "Secure RPC: Disable insecure Kerberos encryption types" + depends on RPCSEC_GSS_KRB5 + default n + help + Choose Y here to disable the use of deprecated encryption types + with the Kerberos version 5 GSS-API mechanism (RFC 1964). The + deprecated encryption types include DES-CBC-MD5, DES-CBC-CRC, + and DES-CBC-MD4. These types were deprecated by RFC 6649 because + they were found to be insecure. + + N is the default because many sites have deployed KDCs and + keytabs that contain only these deprecated encryption types. + Choosing Y prevents the use of known-insecure encryption types + but might result in compatibility problems. + config SUNRPC_DEBUG bool "RPC: Enable dprintk debugging" depends on SUNRPC && SYSCTL diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c index f3023bbc0b7f..e7861026b9e5 100644 --- a/net/sunrpc/auth.c +++ b/net/sunrpc/auth.c @@ -17,9 +17,7 @@ #include <linux/sunrpc/gss_api.h> #include <linux/spinlock.h> -#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) -# define RPCDBG_FACILITY RPCDBG_AUTH -#endif +#include <trace/events/sunrpc.h> #define RPC_CREDCACHE_DEFAULT_HASHBITS (4) struct rpc_cred_cache { @@ -267,8 +265,6 @@ rpcauth_list_flavors(rpc_authflavor_t *array, int size) } } rcu_read_unlock(); - - dprintk("RPC: %s returns %d\n", __func__, result); return result; } EXPORT_SYMBOL_GPL(rpcauth_list_flavors); @@ -636,9 +632,6 @@ rpcauth_lookupcred(struct rpc_auth *auth, int flags) struct rpc_cred *ret; const struct cred *cred = current_cred(); - dprintk("RPC: looking up %s cred\n", - auth->au_ops->au_name); - memset(&acred, 0, sizeof(acred)); acred.cred = cred; ret = auth->au_ops->lookup_cred(auth, &acred, flags); @@ -670,8 +663,6 @@ rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags) }; struct rpc_cred *ret; - dprintk("RPC: %5u looking up %s cred\n", - task->tk_pid, task->tk_client->cl_auth->au_ops->au_name); ret = auth->au_ops->lookup_cred(auth, &acred, lookupflags); put_cred(acred.cred); return ret; @@ -688,8 +679,6 @@ rpcauth_bind_machine_cred(struct rpc_task *task, int lookupflags) if (!acred.principal) return NULL; - dprintk("RPC: %5u looking up %s machine cred\n", - task->tk_pid, task->tk_client->cl_auth->au_ops->au_name); return auth->au_ops->lookup_cred(auth, &acred, lookupflags); } @@ -698,8 +687,6 @@ rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags) { struct rpc_auth *auth = task->tk_client->cl_auth; - dprintk("RPC: %5u looking up %s cred\n", - task->tk_pid, auth->au_ops->au_name); return rpcauth_lookupcred(auth, lookupflags); } @@ -771,75 +758,102 @@ destroy: } EXPORT_SYMBOL_GPL(put_rpccred); -__be32 * -rpcauth_marshcred(struct rpc_task *task, __be32 *p) +/** + * rpcauth_marshcred - Append RPC credential to end of @xdr + * @task: controlling RPC task + * @xdr: xdr_stream containing initial portion of RPC Call header + * + * On success, an appropriate verifier is added to @xdr, @xdr is + * updated to point past the verifier, and zero is returned. + * Otherwise, @xdr is in an undefined state and a negative errno + * is returned. + */ +int rpcauth_marshcred(struct rpc_task *task, struct xdr_stream *xdr) { - struct rpc_cred *cred = task->tk_rqstp->rq_cred; + const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops; - dprintk("RPC: %5u marshaling %s cred %p\n", - task->tk_pid, cred->cr_auth->au_ops->au_name, cred); - - return cred->cr_ops->crmarshal(task, p); + return ops->crmarshal(task, xdr); } -__be32 * -rpcauth_checkverf(struct rpc_task *task, __be32 *p) +/** + * rpcauth_wrap_req_encode - XDR encode the RPC procedure + * @task: controlling RPC task + * @xdr: stream where on-the-wire bytes are to be marshalled + * + * On success, @xdr contains the encoded and wrapped message. + * Otherwise, @xdr is in an undefined state. + */ +int rpcauth_wrap_req_encode(struct rpc_task *task, struct xdr_stream *xdr) { - struct rpc_cred *cred = task->tk_rqstp->rq_cred; + kxdreproc_t encode = task->tk_msg.rpc_proc->p_encode; - dprintk("RPC: %5u validating %s cred %p\n", - task->tk_pid, cred->cr_auth->au_ops->au_name, cred); - - return cred->cr_ops->crvalidate(task, p); + encode(task->tk_rqstp, xdr, task->tk_msg.rpc_argp); + return 0; } +EXPORT_SYMBOL_GPL(rpcauth_wrap_req_encode); -static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp, - __be32 *data, void *obj) +/** + * rpcauth_wrap_req - XDR encode and wrap the RPC procedure + * @task: controlling RPC task + * @xdr: stream where on-the-wire bytes are to be marshalled + * + * On success, @xdr contains the encoded and wrapped message, + * and zero is returned. Otherwise, @xdr is in an undefined + * state and a negative errno is returned. + */ +int rpcauth_wrap_req(struct rpc_task *task, struct xdr_stream *xdr) { - struct xdr_stream xdr; + const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops; - xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data); - encode(rqstp, &xdr, obj); + return ops->crwrap_req(task, xdr); } +/** + * rpcauth_checkverf - Validate verifier in RPC Reply header + * @task: controlling RPC task + * @xdr: xdr_stream containing RPC Reply header + * + * On success, @xdr is updated to point past the verifier and + * zero is returned. Otherwise, @xdr is in an undefined state + * and a negative errno is returned. + */ int -rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp, - __be32 *data, void *obj) +rpcauth_checkverf(struct rpc_task *task, struct xdr_stream *xdr) { - struct rpc_cred *cred = task->tk_rqstp->rq_cred; + const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops; - dprintk("RPC: %5u using %s cred %p to wrap rpc data\n", - task->tk_pid, cred->cr_ops->cr_name, cred); - if (cred->cr_ops->crwrap_req) - return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj); - /* By default, we encode the arguments normally. */ - rpcauth_wrap_req_encode(encode, rqstp, data, obj); - return 0; + return ops->crvalidate(task, xdr); } -static int -rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp, - __be32 *data, void *obj) +/** + * rpcauth_unwrap_resp_decode - Invoke XDR decode function + * @task: controlling RPC task + * @xdr: stream where the Reply message resides + * + * Returns zero on success; otherwise a negative errno is returned. + */ +int +rpcauth_unwrap_resp_decode(struct rpc_task *task, struct xdr_stream *xdr) { - struct xdr_stream xdr; + kxdrdproc_t decode = task->tk_msg.rpc_proc->p_decode; - xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data); - return decode(rqstp, &xdr, obj); + return decode(task->tk_rqstp, xdr, task->tk_msg.rpc_resp); } +EXPORT_SYMBOL_GPL(rpcauth_unwrap_resp_decode); +/** + * rpcauth_unwrap_resp - Invoke unwrap and decode function for the cred + * @task: controlling RPC task + * @xdr: stream where the Reply message resides + * + * Returns zero on success; otherwise a negative errno is returned. + */ int -rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp, - __be32 *data, void *obj) +rpcauth_unwrap_resp(struct rpc_task *task, struct xdr_stream *xdr) { - struct rpc_cred *cred = task->tk_rqstp->rq_cred; + const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops; - dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n", - task->tk_pid, cred->cr_ops->cr_name, cred); - if (cred->cr_ops->crunwrap_resp) - return cred->cr_ops->crunwrap_resp(task, decode, rqstp, - data, obj); - /* By default, we decode the arguments normally. */ - return rpcauth_unwrap_req_decode(decode, rqstp, data, obj); + return ops->crunwrap_resp(task, xdr); } bool @@ -865,8 +879,6 @@ rpcauth_refreshcred(struct rpc_task *task) goto out; cred = task->tk_rqstp->rq_cred; } - dprintk("RPC: %5u refreshing %s cred %p\n", - task->tk_pid, cred->cr_auth->au_ops->au_name, cred); err = cred->cr_ops->crrefresh(task); out: @@ -880,8 +892,6 @@ rpcauth_invalcred(struct rpc_task *task) { struct rpc_cred *cred = task->tk_rqstp->rq_cred; - dprintk("RPC: %5u invalidating %s cred %p\n", - task->tk_pid, cred->cr_auth->au_ops->au_name, cred); if (cred) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); } diff --git a/net/sunrpc/auth_gss/Makefile b/net/sunrpc/auth_gss/Makefile index c374268b008f..4a29f4c5dac4 100644 --- a/net/sunrpc/auth_gss/Makefile +++ b/net/sunrpc/auth_gss/Makefile @@ -7,7 +7,7 @@ obj-$(CONFIG_SUNRPC_GSS) += auth_rpcgss.o auth_rpcgss-y := auth_gss.o gss_generic_token.o \ gss_mech_switch.o svcauth_gss.o \ - gss_rpc_upcall.o gss_rpc_xdr.o + gss_rpc_upcall.o gss_rpc_xdr.o trace.o obj-$(CONFIG_RPCSEC_GSS_KRB5) += rpcsec_gss_krb5.o diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c index 1531b0219344..3fd56c0c90ae 100644 --- a/net/sunrpc/auth_gss/auth_gss.c +++ b/net/sunrpc/auth_gss/auth_gss.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: BSD-3-Clause /* * linux/net/sunrpc/auth_gss/auth_gss.c * @@ -8,34 +9,8 @@ * * Dug Song <dugsong@monkey.org> * Andy Adamson <andros@umich.edu> - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * 3. Neither the name of the University nor the names of its - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED - * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR - * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ - #include <linux/module.h> #include <linux/init.h> #include <linux/types.h> @@ -55,6 +30,8 @@ #include "../netns.h" +#include <trace/events/rpcgss.h> + static const struct rpc_authops authgss_ops; static const struct rpc_credops gss_credops; @@ -260,6 +237,7 @@ gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct } ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx, NULL, GFP_NOFS); if (ret < 0) { + trace_rpcgss_import_ctx(ret); p = ERR_PTR(ret); goto err; } @@ -275,12 +253,9 @@ gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct if (IS_ERR(p)) goto err; done: - dprintk("RPC: %s Success. gc_expiry %lu now %lu timeout %u acceptor %.*s\n", - __func__, ctx->gc_expiry, now, timeout, ctx->gc_acceptor.len, - ctx->gc_acceptor.data); - return p; + trace_rpcgss_context(ctx->gc_expiry, now, timeout, + ctx->gc_acceptor.len, ctx->gc_acceptor.data); err: - dprintk("RPC: %s returns error %ld\n", __func__, -PTR_ERR(p)); return p; } @@ -354,10 +329,8 @@ __gss_find_upcall(struct rpc_pipe *pipe, kuid_t uid, const struct gss_auth *auth if (auth && pos->auth->service != auth->service) continue; refcount_inc(&pos->count); - dprintk("RPC: %s found msg %p\n", __func__, pos); return pos; } - dprintk("RPC: %s found nothing\n", __func__); return NULL; } @@ -456,7 +429,7 @@ static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg, size_t buflen = sizeof(gss_msg->databuf); int len; - len = scnprintf(p, buflen, "mech=%s uid=%d ", mech->gm_name, + len = scnprintf(p, buflen, "mech=%s uid=%d", mech->gm_name, from_kuid(&init_user_ns, gss_msg->uid)); buflen -= len; p += len; @@ -467,7 +440,7 @@ static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg, * identity that we are authenticating to. */ if (target_name) { - len = scnprintf(p, buflen, "target=%s ", target_name); + len = scnprintf(p, buflen, " target=%s", target_name); buflen -= len; p += len; gss_msg->msg.len += len; @@ -487,11 +460,11 @@ static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg, char *c = strchr(service_name, '@'); if (!c) - len = scnprintf(p, buflen, "service=%s ", + len = scnprintf(p, buflen, " service=%s", service_name); else len = scnprintf(p, buflen, - "service=%.*s srchost=%s ", + " service=%.*s srchost=%s", (int)(c - service_name), service_name, c + 1); buflen -= len; @@ -500,17 +473,17 @@ static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg, } if (mech->gm_upcall_enctypes) { - len = scnprintf(p, buflen, "enctypes=%s ", + len = scnprintf(p, buflen, " enctypes=%s", mech->gm_upcall_enctypes); buflen -= len; p += len; gss_msg->msg.len += len; } + trace_rpcgss_upcall_msg(gss_msg->databuf); len = scnprintf(p, buflen, "\n"); if (len == 0) goto out_overflow; gss_msg->msg.len += len; - gss_msg->msg.data = gss_msg->databuf; return 0; out_overflow: @@ -603,8 +576,6 @@ gss_refresh_upcall(struct rpc_task *task) struct rpc_pipe *pipe; int err = 0; - dprintk("RPC: %5u %s for uid %u\n", - task->tk_pid, __func__, from_kuid(&init_user_ns, cred->cr_cred->fsuid)); gss_msg = gss_setup_upcall(gss_auth, cred); if (PTR_ERR(gss_msg) == -EAGAIN) { /* XXX: warning on the first, under the assumption we @@ -612,7 +583,8 @@ gss_refresh_upcall(struct rpc_task *task) warn_gssd(); task->tk_timeout = 15*HZ; rpc_sleep_on(&pipe_version_rpc_waitqueue, task, NULL); - return -EAGAIN; + err = -EAGAIN; + goto out; } if (IS_ERR(gss_msg)) { err = PTR_ERR(gss_msg); @@ -635,9 +607,8 @@ gss_refresh_upcall(struct rpc_task *task) spin_unlock(&pipe->lock); gss_release_msg(gss_msg); out: - dprintk("RPC: %5u %s for uid %u result %d\n", - task->tk_pid, __func__, - from_kuid(&init_user_ns, cred->cr_cred->fsuid), err); + trace_rpcgss_upcall_result(from_kuid(&init_user_ns, + cred->cr_cred->fsuid), err); return err; } @@ -652,14 +623,13 @@ gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred) DEFINE_WAIT(wait); int err; - dprintk("RPC: %s for uid %u\n", - __func__, from_kuid(&init_user_ns, cred->cr_cred->fsuid)); retry: err = 0; /* if gssd is down, just skip upcalling altogether */ if (!gssd_running(net)) { warn_gssd(); - return -EACCES; + err = -EACCES; + goto out; } gss_msg = gss_setup_upcall(gss_auth, cred); if (PTR_ERR(gss_msg) == -EAGAIN) { @@ -700,8 +670,8 @@ out_intr: finish_wait(&gss_msg->waitqueue, &wait); gss_release_msg(gss_msg); out: - dprintk("RPC: %s for uid %u result %d\n", - __func__, from_kuid(&init_user_ns, cred->cr_cred->fsuid), err); + trace_rpcgss_upcall_result(from_kuid(&init_user_ns, + cred->cr_cred->fsuid), err); return err; } @@ -794,7 +764,6 @@ err_put_ctx: err: kfree(buf); out: - dprintk("RPC: %s returning %zd\n", __func__, err); return err; } @@ -863,8 +832,6 @@ gss_pipe_destroy_msg(struct rpc_pipe_msg *msg) struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg); if (msg->errno < 0) { - dprintk("RPC: %s releasing msg %p\n", - __func__, gss_msg); refcount_inc(&gss_msg->count); gss_unhash_msg(gss_msg); if (msg->errno == -ETIMEDOUT) @@ -1024,8 +991,6 @@ gss_create_new(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) struct rpc_auth * auth; int err = -ENOMEM; /* XXX? */ - dprintk("RPC: creating GSS authenticator for client %p\n", clnt); - if (!try_module_get(THIS_MODULE)) return ERR_PTR(err); if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL))) @@ -1041,10 +1006,8 @@ gss_create_new(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) gss_auth->net = get_net(rpc_net_ns(clnt)); err = -EINVAL; gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor); - if (!gss_auth->mech) { - dprintk("RPC: Pseudoflavor %d not found!\n", flavor); + if (!gss_auth->mech) goto err_put_net; - } gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor); if (gss_auth->service == 0) goto err_put_mech; @@ -1053,6 +1016,8 @@ gss_create_new(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) auth = &gss_auth->rpc_auth; auth->au_cslack = GSS_CRED_SLACK >> 2; auth->au_rslack = GSS_VERF_SLACK >> 2; + auth->au_verfsize = GSS_VERF_SLACK >> 2; + auth->au_ralign = GSS_VERF_SLACK >> 2; auth->au_flags = 0; auth->au_ops = &authgss_ops; auth->au_flavor = flavor; @@ -1099,6 +1064,7 @@ err_free: kfree(gss_auth); out_dec: module_put(THIS_MODULE); + trace_rpcgss_createauth(flavor, err); return ERR_PTR(err); } @@ -1135,9 +1101,6 @@ gss_destroy(struct rpc_auth *auth) struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth); - dprintk("RPC: destroying GSS authenticator %p flavor %d\n", - auth, auth->au_flavor); - if (hash_hashed(&gss_auth->hash)) { spin_lock(&gss_auth_hash_lock); hash_del(&gss_auth->hash); @@ -1245,7 +1208,7 @@ gss_dup_cred(struct gss_auth *gss_auth, struct gss_cred *gss_cred) struct gss_cred *new; /* Make a copy of the cred so that we can reference count it */ - new = kzalloc(sizeof(*gss_cred), GFP_NOIO); + new = kzalloc(sizeof(*gss_cred), GFP_NOFS); if (new) { struct auth_cred acred = { .cred = gss_cred->gc_base.cr_cred, @@ -1300,8 +1263,6 @@ gss_send_destroy_context(struct rpc_cred *cred) static void gss_do_free_ctx(struct gss_cl_ctx *ctx) { - dprintk("RPC: %s\n", __func__); - gss_delete_sec_context(&ctx->gc_gss_ctx); kfree(ctx->gc_wire_ctx.data); kfree(ctx->gc_acceptor.data); @@ -1324,7 +1285,6 @@ gss_free_ctx(struct gss_cl_ctx *ctx) static void gss_free_cred(struct gss_cred *gss_cred) { - dprintk("RPC: %s cred=%p\n", __func__, gss_cred); kfree(gss_cred); } @@ -1381,10 +1341,6 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags, gfp_t struct gss_cred *cred = NULL; int err = -ENOMEM; - dprintk("RPC: %s for uid %d, flavor %d\n", - __func__, from_kuid(&init_user_ns, acred->cred->fsuid), - auth->au_flavor); - if (!(cred = kzalloc(sizeof(*cred), gfp))) goto out_err; @@ -1400,7 +1356,6 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags, gfp_t return &cred->gc_base; out_err: - dprintk("RPC: %s failed with error %d\n", __func__, err); return ERR_PTR(err); } @@ -1526,69 +1481,84 @@ out: } /* -* Marshal credentials. -* Maybe we should keep a cached credential for performance reasons. -*/ -static __be32 * -gss_marshal(struct rpc_task *task, __be32 *p) + * Marshal credentials. + * + * The expensive part is computing the verifier. We can't cache a + * pre-computed version of the verifier because the seqno, which + * is different every time, is included in the MIC. + */ +static int gss_marshal(struct rpc_task *task, struct xdr_stream *xdr) { struct rpc_rqst *req = task->tk_rqstp; struct rpc_cred *cred = req->rq_cred; struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); - __be32 *cred_len; + __be32 *p, *cred_len; u32 maj_stat = 0; struct xdr_netobj mic; struct kvec iov; struct xdr_buf verf_buf; + int status; - dprintk("RPC: %5u %s\n", task->tk_pid, __func__); + /* Credential */ - *p++ = htonl(RPC_AUTH_GSS); + p = xdr_reserve_space(xdr, 7 * sizeof(*p) + + ctx->gc_wire_ctx.len); + if (!p) + goto marshal_failed; + *p++ = rpc_auth_gss; cred_len = p++; spin_lock(&ctx->gc_seq_lock); req->rq_seqno = (ctx->gc_seq < MAXSEQ) ? ctx->gc_seq++ : MAXSEQ; spin_unlock(&ctx->gc_seq_lock); if (req->rq_seqno == MAXSEQ) - goto out_expired; + goto expired; + trace_rpcgss_seqno(task); - *p++ = htonl((u32) RPC_GSS_VERSION); - *p++ = htonl((u32) ctx->gc_proc); - *p++ = htonl((u32) req->rq_seqno); - *p++ = htonl((u32) gss_cred->gc_service); + *p++ = cpu_to_be32(RPC_GSS_VERSION); + *p++ = cpu_to_be32(ctx->gc_proc); + *p++ = cpu_to_be32(req->rq_seqno); + *p++ = cpu_to_be32(gss_cred->gc_service); p = xdr_encode_netobj(p, &ctx->gc_wire_ctx); - *cred_len = htonl((p - (cred_len + 1)) << 2); + *cred_len = cpu_to_be32((p - (cred_len + 1)) << 2); + + /* Verifier */ /* We compute the checksum for the verifier over the xdr-encoded bytes * starting with the xid and ending at the end of the credential: */ - iov.iov_base = xprt_skip_transport_header(req->rq_xprt, - req->rq_snd_buf.head[0].iov_base); + iov.iov_base = req->rq_snd_buf.head[0].iov_base; iov.iov_len = (u8 *)p - (u8 *)iov.iov_base; xdr_buf_from_iov(&iov, &verf_buf); - /* set verifier flavor*/ - *p++ = htonl(RPC_AUTH_GSS); - + p = xdr_reserve_space(xdr, sizeof(*p)); + if (!p) + goto marshal_failed; + *p++ = rpc_auth_gss; mic.data = (u8 *)(p + 1); maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic); - if (maj_stat == GSS_S_CONTEXT_EXPIRED) { - goto out_expired; - } else if (maj_stat != 0) { - pr_warn("gss_marshal: gss_get_mic FAILED (%d)\n", maj_stat); - task->tk_status = -EIO; - goto out_put_ctx; - } - p = xdr_encode_opaque(p, NULL, mic.len); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + goto expired; + else if (maj_stat != 0) + goto bad_mic; + if (xdr_stream_encode_opaque_inline(xdr, (void **)&p, mic.len) < 0) + goto marshal_failed; + status = 0; +out: gss_put_ctx(ctx); - return p; -out_expired: + return status; +expired: clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); - task->tk_status = -EKEYEXPIRED; -out_put_ctx: - gss_put_ctx(ctx); - return NULL; + status = -EKEYEXPIRED; + goto out; +marshal_failed: + status = -EMSGSIZE; + goto out; +bad_mic: + trace_rpcgss_get_mic(task, maj_stat); + status = -EIO; + goto out; } static int gss_renew_cred(struct rpc_task *task) @@ -1662,116 +1632,105 @@ gss_refresh_null(struct rpc_task *task) return 0; } -static __be32 * -gss_validate(struct rpc_task *task, __be32 *p) +static int +gss_validate(struct rpc_task *task, struct xdr_stream *xdr) { struct rpc_cred *cred = task->tk_rqstp->rq_cred; struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); - __be32 *seq = NULL; + __be32 *p, *seq = NULL; struct kvec iov; struct xdr_buf verf_buf; struct xdr_netobj mic; - u32 flav,len; - u32 maj_stat; - __be32 *ret = ERR_PTR(-EIO); + u32 len, maj_stat; + int status; - dprintk("RPC: %5u %s\n", task->tk_pid, __func__); + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (!p) + goto validate_failed; + if (*p++ != rpc_auth_gss) + goto validate_failed; + len = be32_to_cpup(p); + if (len > RPC_MAX_AUTH_SIZE) + goto validate_failed; + p = xdr_inline_decode(xdr, len); + if (!p) + goto validate_failed; - flav = ntohl(*p++); - if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE) - goto out_bad; - if (flav != RPC_AUTH_GSS) - goto out_bad; seq = kmalloc(4, GFP_NOFS); if (!seq) - goto out_bad; - *seq = htonl(task->tk_rqstp->rq_seqno); + goto validate_failed; + *seq = cpu_to_be32(task->tk_rqstp->rq_seqno); iov.iov_base = seq; iov.iov_len = 4; xdr_buf_from_iov(&iov, &verf_buf); mic.data = (u8 *)p; mic.len = len; - - ret = ERR_PTR(-EACCES); maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic); if (maj_stat == GSS_S_CONTEXT_EXPIRED) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); - if (maj_stat) { - dprintk("RPC: %5u %s: gss_verify_mic returned error 0x%08x\n", - task->tk_pid, __func__, maj_stat); - goto out_bad; - } + if (maj_stat) + goto bad_mic; + /* We leave it to unwrap to calculate au_rslack. For now we just * calculate the length of the verifier: */ cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2; + status = 0; +out: gss_put_ctx(ctx); - dprintk("RPC: %5u %s: gss_verify_mic succeeded.\n", - task->tk_pid, __func__); - kfree(seq); - return p + XDR_QUADLEN(len); -out_bad: - gss_put_ctx(ctx); - dprintk("RPC: %5u %s failed ret %ld.\n", task->tk_pid, __func__, - PTR_ERR(ret)); kfree(seq); - return ret; -} - -static void gss_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp, - __be32 *p, void *obj) -{ - struct xdr_stream xdr; + return status; - xdr_init_encode(&xdr, &rqstp->rq_snd_buf, p); - encode(rqstp, &xdr, obj); +validate_failed: + status = -EIO; + goto out; +bad_mic: + trace_rpcgss_verify_mic(task, maj_stat); + status = -EACCES; + goto out; } -static inline int -gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx, - kxdreproc_t encode, struct rpc_rqst *rqstp, - __be32 *p, void *obj) +static int gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx, + struct rpc_task *task, struct xdr_stream *xdr) { - struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; - struct xdr_buf integ_buf; - __be32 *integ_len = NULL; + struct rpc_rqst *rqstp = task->tk_rqstp; + struct xdr_buf integ_buf, *snd_buf = &rqstp->rq_snd_buf; struct xdr_netobj mic; - u32 offset; - __be32 *q; - struct kvec *iov; - u32 maj_stat = 0; - int status = -EIO; + __be32 *p, *integ_len; + u32 offset, maj_stat; + p = xdr_reserve_space(xdr, 2 * sizeof(*p)); + if (!p) + goto wrap_failed; integ_len = p++; - offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; - *p++ = htonl(rqstp->rq_seqno); + *p = cpu_to_be32(rqstp->rq_seqno); - gss_wrap_req_encode(encode, rqstp, p, obj); + if (rpcauth_wrap_req_encode(task, xdr)) + goto wrap_failed; + offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; if (xdr_buf_subsegment(snd_buf, &integ_buf, offset, snd_buf->len - offset)) - return status; - *integ_len = htonl(integ_buf.len); + goto wrap_failed; + *integ_len = cpu_to_be32(integ_buf.len); - /* guess whether we're in the head or the tail: */ - if (snd_buf->page_len || snd_buf->tail[0].iov_len) - iov = snd_buf->tail; - else - iov = snd_buf->head; - p = iov->iov_base + iov->iov_len; + p = xdr_reserve_space(xdr, 0); + if (!p) + goto wrap_failed; mic.data = (u8 *)(p + 1); - maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic); - status = -EIO; /* XXX? */ if (maj_stat == GSS_S_CONTEXT_EXPIRED) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); else if (maj_stat) - return status; - q = xdr_encode_opaque(p, NULL, mic.len); - - offset = (u8 *)q - (u8 *)p; - iov->iov_len += offset; - snd_buf->len += offset; + goto bad_mic; + /* Check that the trailing MIC fit in the buffer, after the fact */ + if (xdr_stream_encode_opaque_inline(xdr, (void **)&p, mic.len) < 0) + goto wrap_failed; return 0; +wrap_failed: + return -EMSGSIZE; +bad_mic: + trace_rpcgss_get_mic(task, maj_stat); + return -EIO; } static void @@ -1822,61 +1781,62 @@ out: return -EAGAIN; } -static inline int -gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, - kxdreproc_t encode, struct rpc_rqst *rqstp, - __be32 *p, void *obj) +static int gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, + struct rpc_task *task, struct xdr_stream *xdr) { + struct rpc_rqst *rqstp = task->tk_rqstp; struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; - u32 offset; - u32 maj_stat; + u32 pad, offset, maj_stat; int status; - __be32 *opaque_len; + __be32 *p, *opaque_len; struct page **inpages; int first; - int pad; struct kvec *iov; - char *tmp; + status = -EIO; + p = xdr_reserve_space(xdr, 2 * sizeof(*p)); + if (!p) + goto wrap_failed; opaque_len = p++; - offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; - *p++ = htonl(rqstp->rq_seqno); + *p = cpu_to_be32(rqstp->rq_seqno); - gss_wrap_req_encode(encode, rqstp, p, obj); + if (rpcauth_wrap_req_encode(task, xdr)) + goto wrap_failed; status = alloc_enc_pages(rqstp); - if (status) - return status; + if (unlikely(status)) + goto wrap_failed; first = snd_buf->page_base >> PAGE_SHIFT; inpages = snd_buf->pages + first; snd_buf->pages = rqstp->rq_enc_pages; snd_buf->page_base -= first << PAGE_SHIFT; /* - * Give the tail its own page, in case we need extra space in the - * head when wrapping: + * Move the tail into its own page, in case gss_wrap needs + * more space in the head when wrapping. * - * call_allocate() allocates twice the slack space required - * by the authentication flavor to rq_callsize. - * For GSS, slack is GSS_CRED_SLACK. + * Still... Why can't gss_wrap just slide the tail down? */ if (snd_buf->page_len || snd_buf->tail[0].iov_len) { + char *tmp; + tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]); memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len); snd_buf->tail[0].iov_base = tmp; } + offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages); /* slack space should prevent this ever happening: */ - BUG_ON(snd_buf->len > snd_buf->buflen); - status = -EIO; + if (unlikely(snd_buf->len > snd_buf->buflen)) + goto wrap_failed; /* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was * done anyway, so it's safe to put the request on the wire: */ if (maj_stat == GSS_S_CONTEXT_EXPIRED) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); else if (maj_stat) - return status; + goto bad_wrap; - *opaque_len = htonl(snd_buf->len - offset); - /* guess whether we're in the head or the tail: */ + *opaque_len = cpu_to_be32(snd_buf->len - offset); + /* guess whether the pad goes into the head or the tail: */ if (snd_buf->page_len || snd_buf->tail[0].iov_len) iov = snd_buf->tail; else @@ -1888,118 +1848,154 @@ gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, snd_buf->len += pad; return 0; +wrap_failed: + return status; +bad_wrap: + trace_rpcgss_wrap(task, maj_stat); + return -EIO; } -static int -gss_wrap_req(struct rpc_task *task, - kxdreproc_t encode, void *rqstp, __be32 *p, void *obj) +static int gss_wrap_req(struct rpc_task *task, struct xdr_stream *xdr) { struct rpc_cred *cred = task->tk_rqstp->rq_cred; struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); - int status = -EIO; + int status; - dprintk("RPC: %5u %s\n", task->tk_pid, __func__); + status = -EIO; if (ctx->gc_proc != RPC_GSS_PROC_DATA) { /* The spec seems a little ambiguous here, but I think that not * wrapping context destruction requests makes the most sense. */ - gss_wrap_req_encode(encode, rqstp, p, obj); - status = 0; + status = rpcauth_wrap_req_encode(task, xdr); goto out; } switch (gss_cred->gc_service) { case RPC_GSS_SVC_NONE: - gss_wrap_req_encode(encode, rqstp, p, obj); - status = 0; + status = rpcauth_wrap_req_encode(task, xdr); break; case RPC_GSS_SVC_INTEGRITY: - status = gss_wrap_req_integ(cred, ctx, encode, rqstp, p, obj); + status = gss_wrap_req_integ(cred, ctx, task, xdr); break; case RPC_GSS_SVC_PRIVACY: - status = gss_wrap_req_priv(cred, ctx, encode, rqstp, p, obj); + status = gss_wrap_req_priv(cred, ctx, task, xdr); break; + default: + status = -EIO; } out: gss_put_ctx(ctx); - dprintk("RPC: %5u %s returning %d\n", task->tk_pid, __func__, status); return status; } -static inline int -gss_unwrap_resp_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx, - struct rpc_rqst *rqstp, __be32 **p) +static int +gss_unwrap_resp_auth(struct rpc_cred *cred) { - struct xdr_buf *rcv_buf = &rqstp->rq_rcv_buf; - struct xdr_buf integ_buf; + struct rpc_auth *auth = cred->cr_auth; + + auth->au_rslack = auth->au_verfsize; + auth->au_ralign = auth->au_verfsize; + return 0; +} + +static int +gss_unwrap_resp_integ(struct rpc_task *task, struct rpc_cred *cred, + struct gss_cl_ctx *ctx, struct rpc_rqst *rqstp, + struct xdr_stream *xdr) +{ + struct xdr_buf integ_buf, *rcv_buf = &rqstp->rq_rcv_buf; + u32 data_offset, mic_offset, integ_len, maj_stat; + struct rpc_auth *auth = cred->cr_auth; struct xdr_netobj mic; - u32 data_offset, mic_offset; - u32 integ_len; - u32 maj_stat; - int status = -EIO; + __be32 *p; - integ_len = ntohl(*(*p)++); + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (unlikely(!p)) + goto unwrap_failed; + integ_len = be32_to_cpup(p++); if (integ_len & 3) - return status; - data_offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base; + goto unwrap_failed; + data_offset = (u8 *)(p) - (u8 *)rcv_buf->head[0].iov_base; mic_offset = integ_len + data_offset; if (mic_offset > rcv_buf->len) - return status; - if (ntohl(*(*p)++) != rqstp->rq_seqno) - return status; - - if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset, - mic_offset - data_offset)) - return status; + goto unwrap_failed; + if (be32_to_cpup(p) != rqstp->rq_seqno) + goto bad_seqno; + if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset, integ_len)) + goto unwrap_failed; if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset)) - return status; - + goto unwrap_failed; maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf, &mic); if (maj_stat == GSS_S_CONTEXT_EXPIRED) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); if (maj_stat != GSS_S_COMPLETE) - return status; + goto bad_mic; + + auth->au_rslack = auth->au_verfsize + 2 + 1 + XDR_QUADLEN(mic.len); + auth->au_ralign = auth->au_verfsize + 2; return 0; +unwrap_failed: + trace_rpcgss_unwrap_failed(task); + return -EIO; +bad_seqno: + trace_rpcgss_bad_seqno(task, rqstp->rq_seqno, be32_to_cpup(p)); + return -EIO; +bad_mic: + trace_rpcgss_verify_mic(task, maj_stat); + return -EIO; } -static inline int -gss_unwrap_resp_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, - struct rpc_rqst *rqstp, __be32 **p) -{ - struct xdr_buf *rcv_buf = &rqstp->rq_rcv_buf; - u32 offset; - u32 opaque_len; - u32 maj_stat; - int status = -EIO; - - opaque_len = ntohl(*(*p)++); - offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base; +static int +gss_unwrap_resp_priv(struct rpc_task *task, struct rpc_cred *cred, + struct gss_cl_ctx *ctx, struct rpc_rqst *rqstp, + struct xdr_stream *xdr) +{ + struct xdr_buf *rcv_buf = &rqstp->rq_rcv_buf; + struct kvec *head = rqstp->rq_rcv_buf.head; + struct rpc_auth *auth = cred->cr_auth; + unsigned int savedlen = rcv_buf->len; + u32 offset, opaque_len, maj_stat; + __be32 *p; + + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (unlikely(!p)) + goto unwrap_failed; + opaque_len = be32_to_cpup(p++); + offset = (u8 *)(p) - (u8 *)head->iov_base; if (offset + opaque_len > rcv_buf->len) - return status; - /* remove padding: */ + goto unwrap_failed; rcv_buf->len = offset + opaque_len; maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, rcv_buf); if (maj_stat == GSS_S_CONTEXT_EXPIRED) clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); if (maj_stat != GSS_S_COMPLETE) - return status; - if (ntohl(*(*p)++) != rqstp->rq_seqno) - return status; + goto bad_unwrap; + /* gss_unwrap decrypted the sequence number */ + if (be32_to_cpup(p++) != rqstp->rq_seqno) + goto bad_seqno; - return 0; -} - -static int -gss_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp, - __be32 *p, void *obj) -{ - struct xdr_stream xdr; + /* gss_unwrap redacts the opaque blob from the head iovec. + * rcv_buf has changed, thus the stream needs to be reset. + */ + xdr_init_decode(xdr, rcv_buf, p, rqstp); - xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, p); - return decode(rqstp, &xdr, obj); + auth->au_rslack = auth->au_verfsize + 2 + + XDR_QUADLEN(savedlen - rcv_buf->len); + auth->au_ralign = auth->au_verfsize + 2 + + XDR_QUADLEN(savedlen - rcv_buf->len); + return 0; +unwrap_failed: + trace_rpcgss_unwrap_failed(task); + return -EIO; +bad_seqno: + trace_rpcgss_bad_seqno(task, rqstp->rq_seqno, be32_to_cpup(--p)); + return -EIO; +bad_unwrap: + trace_rpcgss_unwrap(task, maj_stat); + return -EIO; } static bool @@ -2014,14 +2010,14 @@ gss_xmit_need_reencode(struct rpc_task *task) struct rpc_rqst *req = task->tk_rqstp; struct rpc_cred *cred = req->rq_cred; struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); - u32 win, seq_xmit; + u32 win, seq_xmit = 0; bool ret = true; if (!ctx) - return true; + goto out; if (gss_seq_is_newer(req->rq_seqno, READ_ONCE(ctx->gc_seq))) - goto out; + goto out_ctx; seq_xmit = READ_ONCE(ctx->gc_seq_xmit); while (gss_seq_is_newer(req->rq_seqno, seq_xmit)) { @@ -2030,56 +2026,51 @@ gss_xmit_need_reencode(struct rpc_task *task) seq_xmit = cmpxchg(&ctx->gc_seq_xmit, tmp, req->rq_seqno); if (seq_xmit == tmp) { ret = false; - goto out; + goto out_ctx; } } win = ctx->gc_win; if (win > 0) ret = !gss_seq_is_newer(req->rq_seqno, seq_xmit - win); -out: + +out_ctx: gss_put_ctx(ctx); +out: + trace_rpcgss_need_reencode(task, seq_xmit, ret); return ret; } static int -gss_unwrap_resp(struct rpc_task *task, - kxdrdproc_t decode, void *rqstp, __be32 *p, void *obj) +gss_unwrap_resp(struct rpc_task *task, struct xdr_stream *xdr) { - struct rpc_cred *cred = task->tk_rqstp->rq_cred; + struct rpc_rqst *rqstp = task->tk_rqstp; + struct rpc_cred *cred = rqstp->rq_cred; struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); - __be32 *savedp = p; - struct kvec *head = ((struct rpc_rqst *)rqstp)->rq_rcv_buf.head; - int savedlen = head->iov_len; - int status = -EIO; + int status = -EIO; if (ctx->gc_proc != RPC_GSS_PROC_DATA) goto out_decode; switch (gss_cred->gc_service) { case RPC_GSS_SVC_NONE: + status = gss_unwrap_resp_auth(cred); break; case RPC_GSS_SVC_INTEGRITY: - status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p); - if (status) - goto out; + status = gss_unwrap_resp_integ(task, cred, ctx, rqstp, xdr); break; case RPC_GSS_SVC_PRIVACY: - status = gss_unwrap_resp_priv(cred, ctx, rqstp, &p); - if (status) - goto out; + status = gss_unwrap_resp_priv(task, cred, ctx, rqstp, xdr); break; } - /* take into account extra slack for integrity and privacy cases: */ - cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp) - + (savedlen - head->iov_len); + if (status) + goto out; + out_decode: - status = gss_unwrap_req_decode(decode, rqstp, p, obj); + status = rpcauth_unwrap_resp_decode(task, xdr); out: gss_put_ctx(ctx); - dprintk("RPC: %5u %s returning %d\n", - task->tk_pid, __func__, status); return status; } diff --git a/net/sunrpc/auth_gss/gss_krb5_mech.c b/net/sunrpc/auth_gss/gss_krb5_mech.c index eab71fc7af3e..56cc85c5bc06 100644 --- a/net/sunrpc/auth_gss/gss_krb5_mech.c +++ b/net/sunrpc/auth_gss/gss_krb5_mech.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: BSD-3-Clause /* * linux/net/sunrpc/gss_krb5_mech.c * @@ -6,32 +7,6 @@ * * Andy Adamson <andros@umich.edu> * J. Bruce Fields <bfields@umich.edu> - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * 3. Neither the name of the University nor the names of its - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED - * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR - * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * */ #include <crypto/hash.h> @@ -53,6 +28,7 @@ static struct gss_api_mech gss_kerberos_mech; /* forward declaration */ static const struct gss_krb5_enctype supported_gss_krb5_enctypes[] = { +#ifndef CONFIG_SUNRPC_DISABLE_INSECURE_ENCTYPES /* * DES (All DES enctypes are mapped to the same gss functionality) */ @@ -74,6 +50,7 @@ static const struct gss_krb5_enctype supported_gss_krb5_enctypes[] = { .cksumlength = 8, .keyed_cksum = 0, }, +#endif /* CONFIG_SUNRPC_DISABLE_INSECURE_ENCTYPES */ /* * RC4-HMAC */ diff --git a/net/sunrpc/auth_gss/gss_krb5_seqnum.c b/net/sunrpc/auth_gss/gss_krb5_seqnum.c index fb6656295204..507105127095 100644 --- a/net/sunrpc/auth_gss/gss_krb5_seqnum.c +++ b/net/sunrpc/auth_gss/gss_krb5_seqnum.c @@ -44,7 +44,7 @@ krb5_make_rc4_seq_num(struct krb5_ctx *kctx, int direction, s32 seqnum, unsigned char *cksum, unsigned char *buf) { struct crypto_sync_skcipher *cipher; - unsigned char plain[8]; + unsigned char *plain; s32 code; dprintk("RPC: %s:\n", __func__); @@ -52,6 +52,10 @@ krb5_make_rc4_seq_num(struct krb5_ctx *kctx, int direction, s32 seqnum, if (IS_ERR(cipher)) return PTR_ERR(cipher); + plain = kmalloc(8, GFP_NOFS); + if (!plain) + return -ENOMEM; + plain[0] = (unsigned char) ((seqnum >> 24) & 0xff); plain[1] = (unsigned char) ((seqnum >> 16) & 0xff); plain[2] = (unsigned char) ((seqnum >> 8) & 0xff); @@ -67,6 +71,7 @@ krb5_make_rc4_seq_num(struct krb5_ctx *kctx, int direction, s32 seqnum, code = krb5_encrypt(cipher, cksum, plain, buf, 8); out: + kfree(plain); crypto_free_sync_skcipher(cipher); return code; } @@ -77,12 +82,17 @@ krb5_make_seq_num(struct krb5_ctx *kctx, u32 seqnum, unsigned char *cksum, unsigned char *buf) { - unsigned char plain[8]; + unsigned char *plain; + s32 code; if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) return krb5_make_rc4_seq_num(kctx, direction, seqnum, cksum, buf); + plain = kmalloc(8, GFP_NOFS); + if (!plain) + return -ENOMEM; + plain[0] = (unsigned char) (seqnum & 0xff); plain[1] = (unsigned char) ((seqnum >> 8) & 0xff); plain[2] = (unsigned char) ((seqnum >> 16) & 0xff); @@ -93,7 +103,9 @@ krb5_make_seq_num(struct krb5_ctx *kctx, plain[6] = direction; plain[7] = direction; - return krb5_encrypt(key, cksum, plain, buf, 8); + code = krb5_encrypt(key, cksum, plain, buf, 8); + kfree(plain); + return code; } static s32 @@ -101,7 +113,7 @@ krb5_get_rc4_seq_num(struct krb5_ctx *kctx, unsigned char *cksum, unsigned char *buf, int *direction, s32 *seqnum) { struct crypto_sync_skcipher *cipher; - unsigned char plain[8]; + unsigned char *plain; s32 code; dprintk("RPC: %s:\n", __func__); @@ -113,20 +125,28 @@ krb5_get_rc4_seq_num(struct krb5_ctx *kctx, unsigned char *cksum, if (code) goto out; + plain = kmalloc(8, GFP_NOFS); + if (!plain) { + code = -ENOMEM; + goto out; + } + code = krb5_decrypt(cipher, cksum, buf, plain, 8); if (code) - goto out; + goto out_plain; if ((plain[4] != plain[5]) || (plain[4] != plain[6]) || (plain[4] != plain[7])) { code = (s32)KG_BAD_SEQ; - goto out; + goto out_plain; } *direction = plain[4]; *seqnum = ((plain[0] << 24) | (plain[1] << 16) | (plain[2] << 8) | (plain[3])); +out_plain: + kfree(plain); out: crypto_free_sync_skcipher(cipher); return code; @@ -139,7 +159,7 @@ krb5_get_seq_num(struct krb5_ctx *kctx, int *direction, u32 *seqnum) { s32 code; - unsigned char plain[8]; + unsigned char *plain; struct crypto_sync_skcipher *key = kctx->seq; dprintk("RPC: krb5_get_seq_num:\n"); @@ -147,18 +167,25 @@ krb5_get_seq_num(struct krb5_ctx *kctx, if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) return krb5_get_rc4_seq_num(kctx, cksum, buf, direction, seqnum); + plain = kmalloc(8, GFP_NOFS); + if (!plain) + return -ENOMEM; if ((code = krb5_decrypt(key, cksum, buf, plain, 8))) - return code; + goto out; if ((plain[4] != plain[5]) || (plain[4] != plain[6]) || - (plain[4] != plain[7])) - return (s32)KG_BAD_SEQ; + (plain[4] != plain[7])) { + code = (s32)KG_BAD_SEQ; + goto out; + } *direction = plain[4]; *seqnum = ((plain[0]) | (plain[1] << 8) | (plain[2] << 16) | (plain[3] << 24)); - return 0; +out: + kfree(plain); + return code; } diff --git a/net/sunrpc/auth_gss/gss_krb5_wrap.c b/net/sunrpc/auth_gss/gss_krb5_wrap.c index 5cdde6cb703a..14a0aff0cd84 100644 --- a/net/sunrpc/auth_gss/gss_krb5_wrap.c +++ b/net/sunrpc/auth_gss/gss_krb5_wrap.c @@ -570,14 +570,16 @@ gss_unwrap_kerberos_v2(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) */ movelen = min_t(unsigned int, buf->head[0].iov_len, buf->len); movelen -= offset + GSS_KRB5_TOK_HDR_LEN + headskip; - BUG_ON(offset + GSS_KRB5_TOK_HDR_LEN + headskip + movelen > - buf->head[0].iov_len); + if (offset + GSS_KRB5_TOK_HDR_LEN + headskip + movelen > + buf->head[0].iov_len) + return GSS_S_FAILURE; memmove(ptr, ptr + GSS_KRB5_TOK_HDR_LEN + headskip, movelen); buf->head[0].iov_len -= GSS_KRB5_TOK_HDR_LEN + headskip; buf->len -= GSS_KRB5_TOK_HDR_LEN + headskip; /* Trim off the trailing "extra count" and checksum blob */ - xdr_buf_trim(buf, ec + GSS_KRB5_TOK_HDR_LEN + tailskip); + buf->len -= ec + GSS_KRB5_TOK_HDR_LEN + tailskip; + return GSS_S_COMPLETE; } diff --git a/net/sunrpc/auth_gss/gss_mech_switch.c b/net/sunrpc/auth_gss/gss_mech_switch.c index 379318dff534..82060099a429 100644 --- a/net/sunrpc/auth_gss/gss_mech_switch.c +++ b/net/sunrpc/auth_gss/gss_mech_switch.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: BSD-3-Clause /* * linux/net/sunrpc/gss_mech_switch.c * @@ -5,32 +6,6 @@ * All rights reserved. * * J. Bruce Fields <bfields@umich.edu> - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * 3. Neither the name of the University nor the names of its - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED - * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR - * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * */ #include <linux/types.h> diff --git a/net/sunrpc/auth_gss/gss_rpc_upcall.c b/net/sunrpc/auth_gss/gss_rpc_upcall.c index 73dcda060335..0349f455a862 100644 --- a/net/sunrpc/auth_gss/gss_rpc_upcall.c +++ b/net/sunrpc/auth_gss/gss_rpc_upcall.c @@ -1,21 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0+ /* * linux/net/sunrpc/gss_rpc_upcall.c * * Copyright (C) 2012 Simo Sorce <simo@redhat.com> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ #include <linux/types.h> diff --git a/net/sunrpc/auth_gss/gss_rpc_upcall.h b/net/sunrpc/auth_gss/gss_rpc_upcall.h index 1e542aded90a..31e96344167e 100644 --- a/net/sunrpc/auth_gss/gss_rpc_upcall.h +++ b/net/sunrpc/auth_gss/gss_rpc_upcall.h @@ -1,21 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0+ */ /* * linux/net/sunrpc/gss_rpc_upcall.h * * Copyright (C) 2012 Simo Sorce <simo@redhat.com> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ #ifndef _GSS_RPC_UPCALL_H @@ -45,4 +32,5 @@ void gssp_free_upcall_data(struct gssp_upcall_data *data); void init_gssp_clnt(struct sunrpc_net *); int set_gssp_clnt(struct net *); void clear_gssp_clnt(struct sunrpc_net *); + #endif /* _GSS_RPC_UPCALL_H */ diff --git a/net/sunrpc/auth_gss/gss_rpc_xdr.c b/net/sunrpc/auth_gss/gss_rpc_xdr.c index 006062ad5f58..2ff7b7083eba 100644 --- a/net/sunrpc/auth_gss/gss_rpc_xdr.c +++ b/net/sunrpc/auth_gss/gss_rpc_xdr.c @@ -1,21 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0+ /* * GSS Proxy upcall module * * Copyright (C) 2012 Simo Sorce <simo@redhat.com> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ #include <linux/sunrpc/svcauth.h> diff --git a/net/sunrpc/auth_gss/gss_rpc_xdr.h b/net/sunrpc/auth_gss/gss_rpc_xdr.h index 146c31032917..3f17411b7e65 100644 --- a/net/sunrpc/auth_gss/gss_rpc_xdr.h +++ b/net/sunrpc/auth_gss/gss_rpc_xdr.h @@ -1,21 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0+ */ /* * GSS Proxy upcall module * * Copyright (C) 2012 Simo Sorce <simo@redhat.com> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to the Free Software - * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ #ifndef _LINUX_GSS_RPC_XDR_H @@ -262,6 +249,4 @@ int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp, #define GSSX_ARG_wrap_size_limit_sz 0 #define GSSX_RES_wrap_size_limit_sz 0 - - #endif /* _LINUX_GSS_RPC_XDR_H */ diff --git a/net/sunrpc/auth_gss/svcauth_gss.c b/net/sunrpc/auth_gss/svcauth_gss.c index 152790ed309c..0c5d7896d6dd 100644 --- a/net/sunrpc/auth_gss/svcauth_gss.c +++ b/net/sunrpc/auth_gss/svcauth_gss.c @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: GPL-2.0 /* * Neil Brown <neilb@cse.unsw.edu.au> * J. Bruce Fields <bfields@umich.edu> @@ -896,7 +897,7 @@ unwrap_integ_data(struct svc_rqst *rqstp, struct xdr_buf *buf, u32 seq, struct g if (svc_getnl(&buf->head[0]) != seq) goto out; /* trim off the mic and padding at the end before returning */ - xdr_buf_trim(buf, round_up_to_quad(mic.len) + 4); + buf->len -= 4 + round_up_to_quad(mic.len); stat = 0; out: kfree(mic.data); diff --git a/net/sunrpc/auth_gss/trace.c b/net/sunrpc/auth_gss/trace.c new file mode 100644 index 000000000000..5576f1e66de9 --- /dev/null +++ b/net/sunrpc/auth_gss/trace.c @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2018, 2019 Oracle. All rights reserved. + */ + +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/gss_err.h> + +#define CREATE_TRACE_POINTS +#include <trace/events/rpcgss.h> diff --git a/net/sunrpc/auth_null.c b/net/sunrpc/auth_null.c index d0ceac57c06e..41a633a4049e 100644 --- a/net/sunrpc/auth_null.c +++ b/net/sunrpc/auth_null.c @@ -59,15 +59,21 @@ nul_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags) /* * Marshal credential. */ -static __be32 * -nul_marshal(struct rpc_task *task, __be32 *p) +static int +nul_marshal(struct rpc_task *task, struct xdr_stream *xdr) { - *p++ = htonl(RPC_AUTH_NULL); - *p++ = 0; - *p++ = htonl(RPC_AUTH_NULL); - *p++ = 0; - - return p; + __be32 *p; + + p = xdr_reserve_space(xdr, 4 * sizeof(*p)); + if (!p) + return -EMSGSIZE; + /* Credential */ + *p++ = rpc_auth_null; + *p++ = xdr_zero; + /* Verifier */ + *p++ = rpc_auth_null; + *p = xdr_zero; + return 0; } /* @@ -80,25 +86,19 @@ nul_refresh(struct rpc_task *task) return 0; } -static __be32 * -nul_validate(struct rpc_task *task, __be32 *p) +static int +nul_validate(struct rpc_task *task, struct xdr_stream *xdr) { - rpc_authflavor_t flavor; - u32 size; - - flavor = ntohl(*p++); - if (flavor != RPC_AUTH_NULL) { - printk("RPC: bad verf flavor: %u\n", flavor); - return ERR_PTR(-EIO); - } - - size = ntohl(*p++); - if (size != 0) { - printk("RPC: bad verf size: %u\n", size); - return ERR_PTR(-EIO); - } - - return p; + __be32 *p; + + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (!p) + return -EIO; + if (*p++ != rpc_auth_null) + return -EIO; + if (*p != xdr_zero) + return -EIO; + return 0; } const struct rpc_authops authnull_ops = { @@ -114,6 +114,8 @@ static struct rpc_auth null_auth = { .au_cslack = NUL_CALLSLACK, .au_rslack = NUL_REPLYSLACK, + .au_verfsize = NUL_REPLYSLACK, + .au_ralign = NUL_REPLYSLACK, .au_ops = &authnull_ops, .au_flavor = RPC_AUTH_NULL, .au_count = REFCOUNT_INIT(1), @@ -125,8 +127,10 @@ const struct rpc_credops null_credops = { .crdestroy = nul_destroy_cred, .crmatch = nul_match, .crmarshal = nul_marshal, + .crwrap_req = rpcauth_wrap_req_encode, .crrefresh = nul_refresh, .crvalidate = nul_validate, + .crunwrap_resp = rpcauth_unwrap_resp_decode, }; static diff --git a/net/sunrpc/auth_unix.c b/net/sunrpc/auth_unix.c index 387f6b3ffbea..d4018e5a24c5 100644 --- a/net/sunrpc/auth_unix.c +++ b/net/sunrpc/auth_unix.c @@ -28,8 +28,6 @@ static mempool_t *unix_pool; static struct rpc_auth * unx_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) { - dprintk("RPC: creating UNIX authenticator for client %p\n", - clnt); refcount_inc(&unix_auth.au_count); return &unix_auth; } @@ -37,7 +35,6 @@ unx_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) static void unx_destroy(struct rpc_auth *auth) { - dprintk("RPC: destroying UNIX authenticator %p\n", auth); } /* @@ -48,10 +45,6 @@ unx_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) { struct rpc_cred *ret = mempool_alloc(unix_pool, GFP_NOFS); - dprintk("RPC: allocating UNIX cred for uid %d gid %d\n", - from_kuid(&init_user_ns, acred->cred->fsuid), - from_kgid(&init_user_ns, acred->cred->fsgid)); - rpcauth_init_cred(ret, acred, auth, &unix_credops); ret->cr_flags = 1UL << RPCAUTH_CRED_UPTODATE; return ret; @@ -61,7 +54,7 @@ static void unx_free_cred_callback(struct rcu_head *head) { struct rpc_cred *rpc_cred = container_of(head, struct rpc_cred, cr_rcu); - dprintk("RPC: unx_free_cred %p\n", rpc_cred); + put_cred(rpc_cred->cr_cred); mempool_free(rpc_cred, unix_pool); } @@ -87,7 +80,7 @@ unx_match(struct auth_cred *acred, struct rpc_cred *cred, int flags) if (!uid_eq(cred->cr_cred->fsuid, acred->cred->fsuid) || !gid_eq(cred->cr_cred->fsgid, acred->cred->fsgid)) return 0; - if (acred->cred && acred->cred->group_info != NULL) + if (acred->cred->group_info != NULL) groups = acred->cred->group_info->ngroups; if (groups > UNX_NGROUPS) groups = UNX_NGROUPS; @@ -106,37 +99,55 @@ unx_match(struct auth_cred *acred, struct rpc_cred *cred, int flags) * Marshal credentials. * Maybe we should keep a cached credential for performance reasons. */ -static __be32 * -unx_marshal(struct rpc_task *task, __be32 *p) +static int +unx_marshal(struct rpc_task *task, struct xdr_stream *xdr) { struct rpc_clnt *clnt = task->tk_client; struct rpc_cred *cred = task->tk_rqstp->rq_cred; - __be32 *base, *hold; + __be32 *p, *cred_len, *gidarr_len; int i; struct group_info *gi = cred->cr_cred->group_info; - *p++ = htonl(RPC_AUTH_UNIX); - base = p++; - *p++ = htonl(jiffies/HZ); - - /* - * Copy the UTS nodename captured when the client was created. - */ - p = xdr_encode_array(p, clnt->cl_nodename, clnt->cl_nodelen); - - *p++ = htonl((u32) from_kuid(&init_user_ns, cred->cr_cred->fsuid)); - *p++ = htonl((u32) from_kgid(&init_user_ns, cred->cr_cred->fsgid)); - hold = p++; + /* Credential */ + + p = xdr_reserve_space(xdr, 3 * sizeof(*p)); + if (!p) + goto marshal_failed; + *p++ = rpc_auth_unix; + cred_len = p++; + *p++ = xdr_zero; /* stamp */ + if (xdr_stream_encode_opaque(xdr, clnt->cl_nodename, + clnt->cl_nodelen) < 0) + goto marshal_failed; + p = xdr_reserve_space(xdr, 3 * sizeof(*p)); + if (!p) + goto marshal_failed; + *p++ = cpu_to_be32(from_kuid(&init_user_ns, cred->cr_cred->fsuid)); + *p++ = cpu_to_be32(from_kgid(&init_user_ns, cred->cr_cred->fsgid)); + + gidarr_len = p++; if (gi) for (i = 0; i < UNX_NGROUPS && i < gi->ngroups; i++) - *p++ = htonl((u32) from_kgid(&init_user_ns, gi->gid[i])); - *hold = htonl(p - hold - 1); /* gid array length */ - *base = htonl((p - base - 1) << 2); /* cred length */ + *p++ = cpu_to_be32(from_kgid(&init_user_ns, + gi->gid[i])); + *gidarr_len = cpu_to_be32(p - gidarr_len - 1); + *cred_len = cpu_to_be32((p - cred_len - 1) << 2); + p = xdr_reserve_space(xdr, (p - gidarr_len - 1) << 2); + if (!p) + goto marshal_failed; + + /* Verifier */ + + p = xdr_reserve_space(xdr, 2 * sizeof(*p)); + if (!p) + goto marshal_failed; + *p++ = rpc_auth_null; + *p = xdr_zero; - *p++ = htonl(RPC_AUTH_NULL); - *p++ = htonl(0); + return 0; - return p; +marshal_failed: + return -EMSGSIZE; } /* @@ -149,29 +160,35 @@ unx_refresh(struct rpc_task *task) return 0; } -static __be32 * -unx_validate(struct rpc_task *task, __be32 *p) +static int +unx_validate(struct rpc_task *task, struct xdr_stream *xdr) { - rpc_authflavor_t flavor; - u32 size; - - flavor = ntohl(*p++); - if (flavor != RPC_AUTH_NULL && - flavor != RPC_AUTH_UNIX && - flavor != RPC_AUTH_SHORT) { - printk("RPC: bad verf flavor: %u\n", flavor); - return ERR_PTR(-EIO); - } - - size = ntohl(*p++); - if (size > RPC_MAX_AUTH_SIZE) { - printk("RPC: giant verf size: %u\n", size); - return ERR_PTR(-EIO); + struct rpc_auth *auth = task->tk_rqstp->rq_cred->cr_auth; + __be32 *p; + u32 size; + + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (!p) + return -EIO; + switch (*p++) { + case rpc_auth_null: + case rpc_auth_unix: + case rpc_auth_short: + break; + default: + return -EIO; } - task->tk_rqstp->rq_cred->cr_auth->au_rslack = (size >> 2) + 2; - p += (size >> 2); - - return p; + size = be32_to_cpup(p); + if (size > RPC_MAX_AUTH_SIZE) + return -EIO; + p = xdr_inline_decode(xdr, size); + if (!p) + return -EIO; + + auth->au_verfsize = XDR_QUADLEN(size) + 2; + auth->au_rslack = XDR_QUADLEN(size) + 2; + auth->au_ralign = XDR_QUADLEN(size) + 2; + return 0; } int __init rpc_init_authunix(void) @@ -198,6 +215,7 @@ static struct rpc_auth unix_auth = { .au_cslack = UNX_CALLSLACK, .au_rslack = NUL_REPLYSLACK, + .au_verfsize = NUL_REPLYSLACK, .au_ops = &authunix_ops, .au_flavor = RPC_AUTH_UNIX, .au_count = REFCOUNT_INIT(1), @@ -209,6 +227,8 @@ const struct rpc_credops unix_credops = { .crdestroy = unx_destroy_cred, .crmatch = unx_match, .crmarshal = unx_marshal, + .crwrap_req = rpcauth_wrap_req_encode, .crrefresh = unx_refresh, .crvalidate = unx_validate, + .crunwrap_resp = rpcauth_unwrap_resp_decode, }; diff --git a/net/sunrpc/backchannel_rqst.c b/net/sunrpc/backchannel_rqst.c index ec451b8114b0..c47d82622fd1 100644 --- a/net/sunrpc/backchannel_rqst.c +++ b/net/sunrpc/backchannel_rqst.c @@ -235,7 +235,8 @@ out: list_empty(&xprt->bc_pa_list) ? "true" : "false"); } -static struct rpc_rqst *xprt_alloc_bc_request(struct rpc_xprt *xprt, __be32 xid) +static struct rpc_rqst *xprt_get_bc_request(struct rpc_xprt *xprt, __be32 xid, + struct rpc_rqst *new) { struct rpc_rqst *req = NULL; @@ -243,22 +244,20 @@ static struct rpc_rqst *xprt_alloc_bc_request(struct rpc_xprt *xprt, __be32 xid) if (atomic_read(&xprt->bc_free_slots) <= 0) goto not_found; if (list_empty(&xprt->bc_pa_list)) { - req = xprt_alloc_bc_req(xprt, GFP_ATOMIC); - if (!req) + if (!new) goto not_found; - list_add_tail(&req->rq_bc_pa_list, &xprt->bc_pa_list); + list_add_tail(&new->rq_bc_pa_list, &xprt->bc_pa_list); xprt->bc_alloc_count++; } req = list_first_entry(&xprt->bc_pa_list, struct rpc_rqst, rq_bc_pa_list); req->rq_reply_bytes_recvd = 0; - req->rq_bytes_sent = 0; memcpy(&req->rq_private_buf, &req->rq_rcv_buf, sizeof(req->rq_private_buf)); req->rq_xid = xid; req->rq_connect_cookie = xprt->connect_cookie; -not_found: dprintk("RPC: backchannel req=%p\n", req); +not_found: return req; } @@ -321,18 +320,27 @@ void xprt_free_bc_rqst(struct rpc_rqst *req) */ struct rpc_rqst *xprt_lookup_bc_request(struct rpc_xprt *xprt, __be32 xid) { - struct rpc_rqst *req; - - spin_lock(&xprt->bc_pa_lock); - list_for_each_entry(req, &xprt->bc_pa_list, rq_bc_pa_list) { - if (req->rq_connect_cookie != xprt->connect_cookie) - continue; - if (req->rq_xid == xid) - goto found; - } - req = xprt_alloc_bc_request(xprt, xid); + struct rpc_rqst *req, *new = NULL; + + do { + spin_lock(&xprt->bc_pa_lock); + list_for_each_entry(req, &xprt->bc_pa_list, rq_bc_pa_list) { + if (req->rq_connect_cookie != xprt->connect_cookie) + continue; + if (req->rq_xid == xid) + goto found; + } + req = xprt_get_bc_request(xprt, xid, new); found: - spin_unlock(&xprt->bc_pa_lock); + spin_unlock(&xprt->bc_pa_lock); + if (new) { + if (req != new) + xprt_free_bc_rqst(new); + break; + } else if (req) + break; + new = xprt_alloc_bc_req(xprt, GFP_KERNEL); + } while (new); return req; } diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c index d7ec6132c046..228970e6e52b 100644 --- a/net/sunrpc/clnt.c +++ b/net/sunrpc/clnt.c @@ -66,20 +66,19 @@ static void call_decode(struct rpc_task *task); static void call_bind(struct rpc_task *task); static void call_bind_status(struct rpc_task *task); static void call_transmit(struct rpc_task *task); -#if defined(CONFIG_SUNRPC_BACKCHANNEL) -static void call_bc_transmit(struct rpc_task *task); -#endif /* CONFIG_SUNRPC_BACKCHANNEL */ static void call_status(struct rpc_task *task); static void call_transmit_status(struct rpc_task *task); static void call_refresh(struct rpc_task *task); static void call_refreshresult(struct rpc_task *task); -static void call_timeout(struct rpc_task *task); static void call_connect(struct rpc_task *task); static void call_connect_status(struct rpc_task *task); -static __be32 *rpc_encode_header(struct rpc_task *task); -static __be32 *rpc_verify_header(struct rpc_task *task); +static int rpc_encode_header(struct rpc_task *task, + struct xdr_stream *xdr); +static int rpc_decode_header(struct rpc_task *task, + struct xdr_stream *xdr); static int rpc_ping(struct rpc_clnt *clnt); +static void rpc_check_timeout(struct rpc_task *task); static void rpc_register_client(struct rpc_clnt *clnt) { @@ -834,9 +833,6 @@ void rpc_killall_tasks(struct rpc_clnt *clnt) if (!(rovr->tk_flags & RPC_TASK_KILLED)) { rovr->tk_flags |= RPC_TASK_KILLED; rpc_exit(rovr, -EIO); - if (RPC_IS_QUEUED(rovr)) - rpc_wake_up_queued_task(rovr->tk_waitqueue, - rovr); } } spin_unlock(&clnt->cl_lock); @@ -1131,6 +1127,8 @@ rpc_call_async(struct rpc_clnt *clnt, const struct rpc_message *msg, int flags, EXPORT_SYMBOL_GPL(rpc_call_async); #if defined(CONFIG_SUNRPC_BACKCHANNEL) +static void call_bc_encode(struct rpc_task *task); + /** * rpc_run_bc_task - Allocate a new RPC task for backchannel use, then run * rpc_execute against it @@ -1152,7 +1150,7 @@ struct rpc_task *rpc_run_bc_task(struct rpc_rqst *req) task = rpc_new_task(&task_setup_data); xprt_init_bc_request(req, task); - task->tk_action = call_bc_transmit; + task->tk_action = call_bc_encode; atomic_inc(&task->tk_count); WARN_ON_ONCE(atomic_read(&task->tk_count) != 2); rpc_execute(task); @@ -1162,6 +1160,29 @@ struct rpc_task *rpc_run_bc_task(struct rpc_rqst *req) } #endif /* CONFIG_SUNRPC_BACKCHANNEL */ +/** + * rpc_prepare_reply_pages - Prepare to receive a reply data payload into pages + * @req: RPC request to prepare + * @pages: vector of struct page pointers + * @base: offset in first page where receive should start, in bytes + * @len: expected size of the upper layer data payload, in bytes + * @hdrsize: expected size of upper layer reply header, in XDR words + * + */ +void rpc_prepare_reply_pages(struct rpc_rqst *req, struct page **pages, + unsigned int base, unsigned int len, + unsigned int hdrsize) +{ + /* Subtract one to force an extra word of buffer space for the + * payload's XDR pad to fall into the rcv_buf's tail iovec. + */ + hdrsize += RPC_REPHDRSIZE + req->rq_cred->cr_auth->au_ralign - 1; + + xdr_inline_pages(&req->rq_rcv_buf, hdrsize << 2, pages, base, len); + trace_rpc_reply_pages(req); +} +EXPORT_SYMBOL_GPL(rpc_prepare_reply_pages); + void rpc_call_start(struct rpc_task *task) { @@ -1519,6 +1540,7 @@ call_start(struct rpc_task *task) clnt->cl_stats->rpccnt++; task->tk_action = call_reserve; rpc_task_set_transport(task, clnt); + call_reserve(task); } /* @@ -1532,6 +1554,9 @@ call_reserve(struct rpc_task *task) task->tk_status = 0; task->tk_action = call_reserveresult; xprt_reserve(task); + if (rpc_task_need_resched(task)) + return; + call_reserveresult(task); } static void call_retry_reserve(struct rpc_task *task); @@ -1554,6 +1579,7 @@ call_reserveresult(struct rpc_task *task) if (status >= 0) { if (task->tk_rqstp) { task->tk_action = call_refresh; + call_refresh(task); return; } @@ -1579,6 +1605,7 @@ call_reserveresult(struct rpc_task *task) /* fall through */ case -EAGAIN: /* woken up; retry */ task->tk_action = call_retry_reserve; + call_retry_reserve(task); return; case -EIO: /* probably a shutdown */ break; @@ -1601,6 +1628,9 @@ call_retry_reserve(struct rpc_task *task) task->tk_status = 0; task->tk_action = call_reserveresult; xprt_retry_reserve(task); + if (rpc_task_need_resched(task)) + return; + call_reserveresult(task); } /* @@ -1615,6 +1645,9 @@ call_refresh(struct rpc_task *task) task->tk_status = 0; task->tk_client->cl_stats->rpcauthrefresh++; rpcauth_refreshcred(task); + if (rpc_task_need_resched(task)) + return; + call_refreshresult(task); } /* @@ -1633,6 +1666,7 @@ call_refreshresult(struct rpc_task *task) case 0: if (rpcauth_uptodatecred(task)) { task->tk_action = call_allocate; + call_allocate(task); return; } /* Use rate-limiting and a max number of retries if refresh @@ -1651,6 +1685,7 @@ call_refreshresult(struct rpc_task *task) task->tk_cred_retry--; dprintk("RPC: %5u %s: retry refresh creds\n", task->tk_pid, __func__); + call_refresh(task); return; } dprintk("RPC: %5u %s: refresh creds failed with error %d\n", @@ -1665,7 +1700,7 @@ call_refreshresult(struct rpc_task *task) static void call_allocate(struct rpc_task *task) { - unsigned int slack = task->tk_rqstp->rq_cred->cr_auth->au_cslack; + const struct rpc_auth *auth = task->tk_rqstp->rq_cred->cr_auth; struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; const struct rpc_procinfo *proc = task->tk_msg.rpc_proc; @@ -1676,8 +1711,10 @@ call_allocate(struct rpc_task *task) task->tk_status = 0; task->tk_action = call_encode; - if (req->rq_buffer) + if (req->rq_buffer) { + call_encode(task); return; + } if (proc->p_proc != 0) { BUG_ON(proc->p_arglen == 0); @@ -1690,15 +1727,25 @@ call_allocate(struct rpc_task *task) * and reply headers, and convert both values * to byte sizes. */ - req->rq_callsize = RPC_CALLHDRSIZE + (slack << 1) + proc->p_arglen; + req->rq_callsize = RPC_CALLHDRSIZE + (auth->au_cslack << 1) + + proc->p_arglen; req->rq_callsize <<= 2; - req->rq_rcvsize = RPC_REPHDRSIZE + slack + proc->p_replen; + /* + * Note: the reply buffer must at minimum allocate enough space + * for the 'struct accepted_reply' from RFC5531. + */ + req->rq_rcvsize = RPC_REPHDRSIZE + auth->au_rslack + \ + max_t(size_t, proc->p_replen, 2); req->rq_rcvsize <<= 2; status = xprt->ops->buf_alloc(task); xprt_inject_disconnect(xprt); - if (status == 0) + if (status == 0) { + if (rpc_task_need_resched(task)) + return; + call_encode(task); return; + } if (status != -ENOMEM) { rpc_exit(task, status); return; @@ -1728,10 +1775,7 @@ static void rpc_xdr_encode(struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; - kxdreproc_t encode; - __be32 *p; - - dprint_status(task); + struct xdr_stream xdr; xdr_buf_init(&req->rq_snd_buf, req->rq_buffer, @@ -1740,18 +1784,13 @@ rpc_xdr_encode(struct rpc_task *task) req->rq_rbuffer, req->rq_rcvsize); - p = rpc_encode_header(task); - if (p == NULL) - return; - - encode = task->tk_msg.rpc_proc->p_encode; - if (encode == NULL) + req->rq_snd_buf.head[0].iov_len = 0; + xdr_init_encode(&xdr, &req->rq_snd_buf, + req->rq_snd_buf.head[0].iov_base, req); + if (rpc_encode_header(task, &xdr)) return; - task->tk_status = rpcauth_wrap_req(task, encode, req, p, - task->tk_msg.rpc_argp); - if (task->tk_status == 0) - xprt_request_prepare(req); + task->tk_status = rpcauth_wrap_req(task, &xdr); } /* @@ -1762,6 +1801,7 @@ call_encode(struct rpc_task *task) { if (!rpc_task_need_encode(task)) goto out; + dprint_status(task); /* Encode here so that rpcsec_gss can use correct sequence number. */ rpc_xdr_encode(task); /* Did the encode result in an error condition? */ @@ -1779,6 +1819,8 @@ call_encode(struct rpc_task *task) rpc_exit(task, task->tk_status); } return; + } else { + xprt_request_prepare(task->tk_rqstp); } /* Add task to reply queue before transmission to avoid races */ @@ -1787,6 +1829,25 @@ call_encode(struct rpc_task *task) xprt_request_enqueue_transmit(task); out: task->tk_action = call_bind; + call_bind(task); +} + +/* + * Helpers to check if the task was already transmitted, and + * to take action when that is the case. + */ +static bool +rpc_task_transmitted(struct rpc_task *task) +{ + return !test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate); +} + +static void +rpc_task_handle_transmitted(struct rpc_task *task) +{ + xprt_end_transmit(task); + task->tk_action = call_transmit_status; + call_transmit_status(task); } /* @@ -1797,14 +1858,25 @@ call_bind(struct rpc_task *task) { struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; - dprint_status(task); + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); + return; + } - task->tk_action = call_connect; - if (!xprt_bound(xprt)) { - task->tk_action = call_bind_status; - task->tk_timeout = xprt->bind_timeout; - xprt->ops->rpcbind(task); + if (xprt_bound(xprt)) { + task->tk_action = call_connect; + call_connect(task); + return; } + + dprint_status(task); + + task->tk_action = call_bind_status; + if (!xprt_prepare_transmit(task)) + return; + + task->tk_timeout = xprt->bind_timeout; + xprt->ops->rpcbind(task); } /* @@ -1815,10 +1887,16 @@ call_bind_status(struct rpc_task *task) { int status = -EIO; + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); + return; + } + if (task->tk_status >= 0) { dprint_status(task); task->tk_status = 0; task->tk_action = call_connect; + call_connect(task); return; } @@ -1841,6 +1919,8 @@ call_bind_status(struct rpc_task *task) task->tk_rebind_retry--; rpc_delay(task, 3*HZ); goto retry_timeout; + case -EAGAIN: + goto retry_timeout; case -ETIMEDOUT: dprintk("RPC: %5u rpcbind request timed out\n", task->tk_pid); @@ -1882,7 +1962,8 @@ call_bind_status(struct rpc_task *task) retry_timeout: task->tk_status = 0; - task->tk_action = call_timeout; + task->tk_action = call_bind; + rpc_check_timeout(task); } /* @@ -1893,21 +1974,31 @@ call_connect(struct rpc_task *task) { struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); + return; + } + + if (xprt_connected(xprt)) { + task->tk_action = call_transmit; + call_transmit(task); + return; + } + dprintk("RPC: %5u call_connect xprt %p %s connected\n", task->tk_pid, xprt, (xprt_connected(xprt) ? "is" : "is not")); - task->tk_action = call_transmit; - if (!xprt_connected(xprt)) { - task->tk_action = call_connect_status; - if (task->tk_status < 0) - return; - if (task->tk_flags & RPC_TASK_NOCONNECT) { - rpc_exit(task, -ENOTCONN); - return; - } - xprt_connect(task); + task->tk_action = call_connect_status; + if (task->tk_status < 0) + return; + if (task->tk_flags & RPC_TASK_NOCONNECT) { + rpc_exit(task, -ENOTCONN); + return; } + if (!xprt_prepare_transmit(task)) + return; + xprt_connect(task); } /* @@ -1919,10 +2010,8 @@ call_connect_status(struct rpc_task *task) struct rpc_clnt *clnt = task->tk_client; int status = task->tk_status; - /* Check if the task was already transmitted */ - if (!test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) { - xprt_end_transmit(task); - task->tk_action = call_transmit_status; + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); return; } @@ -1937,8 +2026,7 @@ call_connect_status(struct rpc_task *task) break; if (clnt->cl_autobind) { rpc_force_rebind(clnt); - task->tk_action = call_bind; - return; + goto out_retry; } /* fall through */ case -ECONNRESET: @@ -1958,16 +2046,20 @@ call_connect_status(struct rpc_task *task) /* fall through */ case -ENOTCONN: case -EAGAIN: - /* Check for timeouts before looping back to call_bind */ case -ETIMEDOUT: - task->tk_action = call_timeout; - return; + goto out_retry; case 0: clnt->cl_stats->netreconn++; task->tk_action = call_transmit; + call_transmit(task); return; } rpc_exit(task, status); + return; +out_retry: + /* Check for timeouts before looping back to call_bind */ + task->tk_action = call_bind; + rpc_check_timeout(task); } /* @@ -1976,16 +2068,28 @@ call_connect_status(struct rpc_task *task) static void call_transmit(struct rpc_task *task) { + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); + return; + } + dprint_status(task); + task->tk_action = call_transmit_status; + if (!xprt_prepare_transmit(task)) + return; task->tk_status = 0; if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) { - if (!xprt_prepare_transmit(task)) + if (!xprt_connected(task->tk_xprt)) { + task->tk_status = -ENOTCONN; return; + } xprt_transmit(task); } - task->tk_action = call_transmit_status; xprt_end_transmit(task); + if (rpc_task_need_resched(task)) + return; + call_transmit_status(task); } /* @@ -2000,8 +2104,12 @@ call_transmit_status(struct rpc_task *task) * Common case: success. Force the compiler to put this * test first. */ - if (task->tk_status == 0) { - xprt_request_wait_receive(task); + if (rpc_task_transmitted(task)) { + if (task->tk_status == 0) + xprt_request_wait_receive(task); + if (rpc_task_need_resched(task)) + return; + call_status(task); return; } @@ -2038,7 +2146,7 @@ call_transmit_status(struct rpc_task *task) trace_xprt_ping(task->tk_xprt, task->tk_status); rpc_exit(task, task->tk_status); - break; + return; } /* fall through */ case -ECONNRESET: @@ -2046,11 +2154,25 @@ call_transmit_status(struct rpc_task *task) case -EADDRINUSE: case -ENOTCONN: case -EPIPE: + task->tk_action = call_bind; + task->tk_status = 0; break; } + rpc_check_timeout(task); } #if defined(CONFIG_SUNRPC_BACKCHANNEL) +static void call_bc_transmit(struct rpc_task *task); +static void call_bc_transmit_status(struct rpc_task *task); + +static void +call_bc_encode(struct rpc_task *task) +{ + xprt_request_enqueue_transmit(task); + task->tk_action = call_bc_transmit; + call_bc_transmit(task); +} + /* * 5b. Send the backchannel RPC reply. On error, drop the reply. In * addition, disconnect on connectivity errors. @@ -2058,26 +2180,23 @@ call_transmit_status(struct rpc_task *task) static void call_bc_transmit(struct rpc_task *task) { - struct rpc_rqst *req = task->tk_rqstp; - - if (rpc_task_need_encode(task)) - xprt_request_enqueue_transmit(task); - if (!test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) - goto out_wakeup; - - if (!xprt_prepare_transmit(task)) - goto out_retry; - - if (task->tk_status < 0) { - printk(KERN_NOTICE "RPC: Could not send backchannel reply " - "error: %d\n", task->tk_status); - goto out_done; + task->tk_action = call_bc_transmit_status; + if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) { + if (!xprt_prepare_transmit(task)) + return; + task->tk_status = 0; + xprt_transmit(task); } + xprt_end_transmit(task); +} - xprt_transmit(task); +static void +call_bc_transmit_status(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; - xprt_end_transmit(task); dprint_status(task); + switch (task->tk_status) { case 0: /* Success */ @@ -2091,8 +2210,14 @@ call_bc_transmit(struct rpc_task *task) case -ENOTCONN: case -EPIPE: break; + case -ENOBUFS: + rpc_delay(task, HZ>>2); + /* fall through */ + case -EBADSLT: case -EAGAIN: - goto out_retry; + task->tk_status = 0; + task->tk_action = call_bc_transmit; + return; case -ETIMEDOUT: /* * Problem reaching the server. Disconnect and let the @@ -2111,18 +2236,11 @@ call_bc_transmit(struct rpc_task *task) * We were unable to reply and will have to drop the * request. The server should reconnect and retransmit. */ - WARN_ON_ONCE(task->tk_status == -EAGAIN); printk(KERN_NOTICE "RPC: Could not send backchannel reply " "error: %d\n", task->tk_status); break; } -out_wakeup: - rpc_wake_up_queued_task(&req->rq_xprt->pending, task); -out_done: task->tk_action = rpc_exit_task; - return; -out_retry: - task->tk_status = 0; } #endif /* CONFIG_SUNRPC_BACKCHANNEL */ @@ -2143,6 +2261,7 @@ call_status(struct rpc_task *task) status = task->tk_status; if (status >= 0) { task->tk_action = call_decode; + call_decode(task); return; } @@ -2154,10 +2273,8 @@ call_status(struct rpc_task *task) case -EHOSTUNREACH: case -ENETUNREACH: case -EPERM: - if (RPC_IS_SOFTCONN(task)) { - rpc_exit(task, status); - break; - } + if (RPC_IS_SOFTCONN(task)) + goto out_exit; /* * Delay any retries for 3 seconds, then handle as if it * were a timeout. @@ -2165,7 +2282,6 @@ call_status(struct rpc_task *task) rpc_delay(task, 3*HZ); /* fall through */ case -ETIMEDOUT: - task->tk_action = call_timeout; break; case -ECONNREFUSED: case -ECONNRESET: @@ -2178,34 +2294,30 @@ call_status(struct rpc_task *task) case -EPIPE: case -ENOTCONN: case -EAGAIN: - task->tk_action = call_encode; break; case -EIO: /* shutdown or soft timeout */ - rpc_exit(task, status); - break; + goto out_exit; default: if (clnt->cl_chatty) printk("%s: RPC call returned error %d\n", clnt->cl_program->name, -status); - rpc_exit(task, status); + goto out_exit; } + task->tk_action = call_encode; + rpc_check_timeout(task); + return; +out_exit: + rpc_exit(task, status); } -/* - * 6a. Handle RPC timeout - * We do not release the request slot, so we keep using the - * same XID for all retransmits. - */ static void -call_timeout(struct rpc_task *task) +rpc_check_timeout(struct rpc_task *task) { struct rpc_clnt *clnt = task->tk_client; - if (xprt_adjust_timeout(task->tk_rqstp) == 0) { - dprintk("RPC: %5u call_timeout (minor)\n", task->tk_pid); - goto retry; - } + if (xprt_adjust_timeout(task->tk_rqstp) == 0) + return; dprintk("RPC: %5u call_timeout (major)\n", task->tk_pid); task->tk_timeouts++; @@ -2241,10 +2353,6 @@ call_timeout(struct rpc_task *task) * event? RFC2203 requires the server to drop all such requests. */ rpcauth_invalcred(task); - -retry: - task->tk_action = call_encode; - task->tk_status = 0; } /* @@ -2255,12 +2363,11 @@ call_decode(struct rpc_task *task) { struct rpc_clnt *clnt = task->tk_client; struct rpc_rqst *req = task->tk_rqstp; - kxdrdproc_t decode = task->tk_msg.rpc_proc->p_decode; - __be32 *p; + struct xdr_stream xdr; dprint_status(task); - if (!decode) { + if (!task->tk_msg.rpc_proc->p_decode) { task->tk_action = rpc_exit_task; return; } @@ -2285,223 +2392,186 @@ call_decode(struct rpc_task *task) WARN_ON(memcmp(&req->rq_rcv_buf, &req->rq_private_buf, sizeof(req->rq_rcv_buf)) != 0); - if (req->rq_rcv_buf.len < 12) { - if (!RPC_IS_SOFT(task)) { - task->tk_action = call_encode; - goto out_retry; - } - dprintk("RPC: %s: too small RPC reply size (%d bytes)\n", - clnt->cl_program->name, task->tk_status); - task->tk_action = call_timeout; - goto out_retry; - } - - p = rpc_verify_header(task); - if (IS_ERR(p)) { - if (p == ERR_PTR(-EAGAIN)) - goto out_retry; + xdr_init_decode(&xdr, &req->rq_rcv_buf, + req->rq_rcv_buf.head[0].iov_base, req); + switch (rpc_decode_header(task, &xdr)) { + case 0: + task->tk_action = rpc_exit_task; + task->tk_status = rpcauth_unwrap_resp(task, &xdr); + dprintk("RPC: %5u %s result %d\n", + task->tk_pid, __func__, task->tk_status); return; - } - task->tk_action = rpc_exit_task; - - task->tk_status = rpcauth_unwrap_resp(task, decode, req, p, - task->tk_msg.rpc_resp); - - dprintk("RPC: %5u call_decode result %d\n", task->tk_pid, - task->tk_status); - return; -out_retry: - task->tk_status = 0; - /* Note: rpc_verify_header() may have freed the RPC slot */ - if (task->tk_rqstp == req) { - xdr_free_bvec(&req->rq_rcv_buf); - req->rq_reply_bytes_recvd = req->rq_rcv_buf.len = 0; - if (task->tk_client->cl_discrtry) - xprt_conditional_disconnect(req->rq_xprt, - req->rq_connect_cookie); + case -EAGAIN: + task->tk_status = 0; + /* Note: rpc_decode_header() may have freed the RPC slot */ + if (task->tk_rqstp == req) { + xdr_free_bvec(&req->rq_rcv_buf); + req->rq_reply_bytes_recvd = 0; + req->rq_rcv_buf.len = 0; + if (task->tk_client->cl_discrtry) + xprt_conditional_disconnect(req->rq_xprt, + req->rq_connect_cookie); + } + task->tk_action = call_encode; + rpc_check_timeout(task); } } -static __be32 * -rpc_encode_header(struct rpc_task *task) +static int +rpc_encode_header(struct rpc_task *task, struct xdr_stream *xdr) { struct rpc_clnt *clnt = task->tk_client; struct rpc_rqst *req = task->tk_rqstp; - __be32 *p = req->rq_svec[0].iov_base; - - /* FIXME: check buffer size? */ - - p = xprt_skip_transport_header(req->rq_xprt, p); - *p++ = req->rq_xid; /* XID */ - *p++ = htonl(RPC_CALL); /* CALL */ - *p++ = htonl(RPC_VERSION); /* RPC version */ - *p++ = htonl(clnt->cl_prog); /* program number */ - *p++ = htonl(clnt->cl_vers); /* program version */ - *p++ = htonl(task->tk_msg.rpc_proc->p_proc); /* procedure */ - p = rpcauth_marshcred(task, p); - if (p) - req->rq_slen = xdr_adjust_iovec(&req->rq_svec[0], p); - return p; + __be32 *p; + int error; + + error = -EMSGSIZE; + p = xdr_reserve_space(xdr, RPC_CALLHDRSIZE << 2); + if (!p) + goto out_fail; + *p++ = req->rq_xid; + *p++ = rpc_call; + *p++ = cpu_to_be32(RPC_VERSION); + *p++ = cpu_to_be32(clnt->cl_prog); + *p++ = cpu_to_be32(clnt->cl_vers); + *p = cpu_to_be32(task->tk_msg.rpc_proc->p_proc); + + error = rpcauth_marshcred(task, xdr); + if (error < 0) + goto out_fail; + return 0; +out_fail: + trace_rpc_bad_callhdr(task); + rpc_exit(task, error); + return error; } -static __be32 * -rpc_verify_header(struct rpc_task *task) +static noinline int +rpc_decode_header(struct rpc_task *task, struct xdr_stream *xdr) { struct rpc_clnt *clnt = task->tk_client; - struct kvec *iov = &task->tk_rqstp->rq_rcv_buf.head[0]; - int len = task->tk_rqstp->rq_rcv_buf.len >> 2; - __be32 *p = iov->iov_base; - u32 n; - int error = -EACCES; - - if ((task->tk_rqstp->rq_rcv_buf.len & 3) != 0) { - /* RFC-1014 says that the representation of XDR data must be a - * multiple of four bytes - * - if it isn't pointer subtraction in the NFS client may give - * undefined results - */ - dprintk("RPC: %5u %s: XDR representation not a multiple of" - " 4 bytes: 0x%x\n", task->tk_pid, __func__, - task->tk_rqstp->rq_rcv_buf.len); - error = -EIO; - goto out_err; - } - if ((len -= 3) < 0) - goto out_overflow; - - p += 1; /* skip XID */ - if ((n = ntohl(*p++)) != RPC_REPLY) { - dprintk("RPC: %5u %s: not an RPC reply: %x\n", - task->tk_pid, __func__, n); - error = -EIO; - goto out_garbage; - } + int error; + __be32 *p; - if ((n = ntohl(*p++)) != RPC_MSG_ACCEPTED) { - if (--len < 0) - goto out_overflow; - switch ((n = ntohl(*p++))) { - case RPC_AUTH_ERROR: - break; - case RPC_MISMATCH: - dprintk("RPC: %5u %s: RPC call version mismatch!\n", - task->tk_pid, __func__); - error = -EPROTONOSUPPORT; - goto out_err; - default: - dprintk("RPC: %5u %s: RPC call rejected, " - "unknown error: %x\n", - task->tk_pid, __func__, n); - error = -EIO; - goto out_err; - } - if (--len < 0) - goto out_overflow; - switch ((n = ntohl(*p++))) { - case RPC_AUTH_REJECTEDCRED: - case RPC_AUTH_REJECTEDVERF: - case RPCSEC_GSS_CREDPROBLEM: - case RPCSEC_GSS_CTXPROBLEM: - if (!task->tk_cred_retry) - break; - task->tk_cred_retry--; - dprintk("RPC: %5u %s: retry stale creds\n", - task->tk_pid, __func__); - rpcauth_invalcred(task); - /* Ensure we obtain a new XID! */ - xprt_release(task); - task->tk_action = call_reserve; - goto out_retry; - case RPC_AUTH_BADCRED: - case RPC_AUTH_BADVERF: - /* possibly garbled cred/verf? */ - if (!task->tk_garb_retry) - break; - task->tk_garb_retry--; - dprintk("RPC: %5u %s: retry garbled creds\n", - task->tk_pid, __func__); - task->tk_action = call_encode; - goto out_retry; - case RPC_AUTH_TOOWEAK: - printk(KERN_NOTICE "RPC: server %s requires stronger " - "authentication.\n", - task->tk_xprt->servername); - break; - default: - dprintk("RPC: %5u %s: unknown auth error: %x\n", - task->tk_pid, __func__, n); - error = -EIO; - } - dprintk("RPC: %5u %s: call rejected %d\n", - task->tk_pid, __func__, n); - goto out_err; - } - p = rpcauth_checkverf(task, p); - if (IS_ERR(p)) { - error = PTR_ERR(p); - dprintk("RPC: %5u %s: auth check failed with %d\n", - task->tk_pid, __func__, error); - goto out_garbage; /* bad verifier, retry */ - } - len = p - (__be32 *)iov->iov_base - 1; - if (len < 0) - goto out_overflow; - switch ((n = ntohl(*p++))) { - case RPC_SUCCESS: - return p; - case RPC_PROG_UNAVAIL: - dprintk("RPC: %5u %s: program %u is unsupported " - "by server %s\n", task->tk_pid, __func__, - (unsigned int)clnt->cl_prog, - task->tk_xprt->servername); + /* RFC-1014 says that the representation of XDR data must be a + * multiple of four bytes + * - if it isn't pointer subtraction in the NFS client may give + * undefined results + */ + if (task->tk_rqstp->rq_rcv_buf.len & 3) + goto out_unparsable; + + p = xdr_inline_decode(xdr, 3 * sizeof(*p)); + if (!p) + goto out_unparsable; + p++; /* skip XID */ + if (*p++ != rpc_reply) + goto out_unparsable; + if (*p++ != rpc_msg_accepted) + goto out_msg_denied; + + error = rpcauth_checkverf(task, xdr); + if (error) + goto out_verifier; + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (!p) + goto out_unparsable; + switch (*p) { + case rpc_success: + return 0; + case rpc_prog_unavail: + trace_rpc__prog_unavail(task); error = -EPFNOSUPPORT; goto out_err; - case RPC_PROG_MISMATCH: - dprintk("RPC: %5u %s: program %u, version %u unsupported " - "by server %s\n", task->tk_pid, __func__, - (unsigned int)clnt->cl_prog, - (unsigned int)clnt->cl_vers, - task->tk_xprt->servername); + case rpc_prog_mismatch: + trace_rpc__prog_mismatch(task); error = -EPROTONOSUPPORT; goto out_err; - case RPC_PROC_UNAVAIL: - dprintk("RPC: %5u %s: proc %s unsupported by program %u, " - "version %u on server %s\n", - task->tk_pid, __func__, - rpc_proc_name(task), - clnt->cl_prog, clnt->cl_vers, - task->tk_xprt->servername); + case rpc_proc_unavail: + trace_rpc__proc_unavail(task); error = -EOPNOTSUPP; goto out_err; - case RPC_GARBAGE_ARGS: - dprintk("RPC: %5u %s: server saw garbage\n", - task->tk_pid, __func__); - break; /* retry */ + case rpc_garbage_args: + case rpc_system_err: + trace_rpc__garbage_args(task); + error = -EIO; + break; default: - dprintk("RPC: %5u %s: server accept status: %x\n", - task->tk_pid, __func__, n); - /* Also retry */ + goto out_unparsable; } out_garbage: clnt->cl_stats->rpcgarbage++; if (task->tk_garb_retry) { task->tk_garb_retry--; - dprintk("RPC: %5u %s: retrying\n", - task->tk_pid, __func__); task->tk_action = call_encode; -out_retry: - return ERR_PTR(-EAGAIN); + return -EAGAIN; } out_err: rpc_exit(task, error); - dprintk("RPC: %5u %s: call failed with error %d\n", task->tk_pid, - __func__, error); - return ERR_PTR(error); -out_overflow: - dprintk("RPC: %5u %s: server reply was truncated.\n", task->tk_pid, - __func__); + return error; + +out_unparsable: + trace_rpc__unparsable(task); + error = -EIO; + goto out_garbage; + +out_verifier: + trace_rpc_bad_verifier(task); goto out_garbage; + +out_msg_denied: + error = -EACCES; + p = xdr_inline_decode(xdr, sizeof(*p)); + if (!p) + goto out_unparsable; + switch (*p++) { + case rpc_auth_error: + break; + case rpc_mismatch: + trace_rpc__mismatch(task); + error = -EPROTONOSUPPORT; + goto out_err; + default: + goto out_unparsable; + } + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (!p) + goto out_unparsable; + switch (*p++) { + case rpc_autherr_rejectedcred: + case rpc_autherr_rejectedverf: + case rpcsec_gsserr_credproblem: + case rpcsec_gsserr_ctxproblem: + if (!task->tk_cred_retry) + break; + task->tk_cred_retry--; + trace_rpc__stale_creds(task); + rpcauth_invalcred(task); + /* Ensure we obtain a new XID! */ + xprt_release(task); + task->tk_action = call_reserve; + return -EAGAIN; + case rpc_autherr_badcred: + case rpc_autherr_badverf: + /* possibly garbled cred/verf? */ + if (!task->tk_garb_retry) + break; + task->tk_garb_retry--; + trace_rpc__bad_creds(task); + task->tk_action = call_encode; + return -EAGAIN; + case rpc_autherr_tooweak: + trace_rpc__auth_tooweak(task); + pr_warn("RPC: server %s requires stronger authentication.\n", + task->tk_xprt->servername); + break; + default: + goto out_unparsable; + } + goto out_err; } static void rpcproc_encode_null(struct rpc_rqst *rqstp, struct xdr_stream *xdr, diff --git a/net/sunrpc/debugfs.c b/net/sunrpc/debugfs.c index 45a033329cd4..19bb356230ed 100644 --- a/net/sunrpc/debugfs.c +++ b/net/sunrpc/debugfs.c @@ -146,7 +146,7 @@ rpc_clnt_debugfs_register(struct rpc_clnt *clnt) rcu_read_lock(); xprt = rcu_dereference(clnt->cl_xprt); /* no "debugfs" dentry? Don't bother with the symlink. */ - if (!xprt->debugfs) { + if (IS_ERR_OR_NULL(xprt->debugfs)) { rcu_read_unlock(); return; } diff --git a/net/sunrpc/sched.c b/net/sunrpc/sched.c index adc3c40cc733..28956c70100a 100644 --- a/net/sunrpc/sched.c +++ b/net/sunrpc/sched.c @@ -19,6 +19,7 @@ #include <linux/spinlock.h> #include <linux/mutex.h> #include <linux/freezer.h> +#include <linux/sched/mm.h> #include <linux/sunrpc/clnt.h> @@ -784,8 +785,7 @@ void rpc_exit(struct rpc_task *task, int status) { task->tk_status = status; task->tk_action = rpc_exit_task; - if (RPC_IS_QUEUED(task)) - rpc_wake_up_queued_task(task->tk_waitqueue, task); + rpc_wake_up_queued_task(task->tk_waitqueue, task); } EXPORT_SYMBOL_GPL(rpc_exit); @@ -902,7 +902,10 @@ void rpc_execute(struct rpc_task *task) static void rpc_async_schedule(struct work_struct *work) { + unsigned int pflags = memalloc_nofs_save(); + __rpc_execute(container_of(work, struct rpc_task, u.tk_work)); + memalloc_nofs_restore(pflags); } /** @@ -921,16 +924,13 @@ static void rpc_async_schedule(struct work_struct *work) * Most requests are 'small' (under 2KiB) and can be serviced from a * mempool, ensuring that NFS reads and writes can always proceed, * and that there is good locality of reference for these buffers. - * - * In order to avoid memory starvation triggering more writebacks of - * NFS requests, we avoid using GFP_KERNEL. */ int rpc_malloc(struct rpc_task *task) { struct rpc_rqst *rqst = task->tk_rqstp; size_t size = rqst->rq_callsize + rqst->rq_rcvsize; struct rpc_buffer *buf; - gfp_t gfp = GFP_NOIO | __GFP_NOWARN; + gfp_t gfp = GFP_NOFS; if (RPC_IS_SWAPPER(task)) gfp = __GFP_MEMALLOC | GFP_NOWAIT | __GFP_NOWARN; @@ -1011,7 +1011,7 @@ static void rpc_init_task(struct rpc_task *task, const struct rpc_task_setup *ta static struct rpc_task * rpc_alloc_task(void) { - return (struct rpc_task *)mempool_alloc(rpc_task_mempool, GFP_NOIO); + return (struct rpc_task *)mempool_alloc(rpc_task_mempool, GFP_NOFS); } /* @@ -1067,7 +1067,10 @@ static void rpc_free_task(struct rpc_task *task) static void rpc_async_release(struct work_struct *work) { + unsigned int pflags = memalloc_nofs_save(); + rpc_free_task(container_of(work, struct rpc_task, u.tk_work)); + memalloc_nofs_restore(pflags); } static void rpc_release_resources_task(struct rpc_task *task) diff --git a/net/sunrpc/svc.c b/net/sunrpc/svc.c index e87ddb9f7feb..dbd19697ee38 100644 --- a/net/sunrpc/svc.c +++ b/net/sunrpc/svc.c @@ -1145,17 +1145,6 @@ static __printf(2,3) void svc_printk(struct svc_rqst *rqstp, const char *fmt, .. #endif /* - * Setup response header for TCP, it has a 4B record length field. - */ -static void svc_tcp_prep_reply_hdr(struct svc_rqst *rqstp) -{ - struct kvec *resv = &rqstp->rq_res.head[0]; - - /* tcp needs a space for the record length... */ - svc_putnl(resv, 0); -} - -/* * Common routine for processing the RPC request. */ static int @@ -1182,10 +1171,6 @@ svc_process_common(struct svc_rqst *rqstp, struct kvec *argv, struct kvec *resv) set_bit(RQ_USEDEFERRAL, &rqstp->rq_flags); clear_bit(RQ_DROPME, &rqstp->rq_flags); - /* Setup reply header */ - if (rqstp->rq_prot == IPPROTO_TCP) - svc_tcp_prep_reply_hdr(rqstp); - svc_putu32(resv, rqstp->rq_xid); vers = svc_getnl(argv); @@ -1443,6 +1428,10 @@ svc_process(struct svc_rqst *rqstp) goto out_drop; } + /* Reserve space for the record marker */ + if (rqstp->rq_prot == IPPROTO_TCP) + svc_putnl(resv, 0); + /* Returns 1 for send, 0 for drop */ if (likely(svc_process_common(rqstp, argv, resv))) return svc_send(rqstp); diff --git a/net/sunrpc/svc_xprt.c b/net/sunrpc/svc_xprt.c index 4eb8fbf2508d..61530b1b7754 100644 --- a/net/sunrpc/svc_xprt.c +++ b/net/sunrpc/svc_xprt.c @@ -357,15 +357,29 @@ static void svc_xprt_release_slot(struct svc_rqst *rqstp) struct svc_xprt *xprt = rqstp->rq_xprt; if (test_and_clear_bit(RQ_DATA, &rqstp->rq_flags)) { atomic_dec(&xprt->xpt_nr_rqsts); + smp_wmb(); /* See smp_rmb() in svc_xprt_ready() */ svc_xprt_enqueue(xprt); } } -static bool svc_xprt_has_something_to_do(struct svc_xprt *xprt) +static bool svc_xprt_ready(struct svc_xprt *xprt) { - if (xprt->xpt_flags & ((1<<XPT_CONN)|(1<<XPT_CLOSE))) + unsigned long xpt_flags; + + /* + * If another cpu has recently updated xpt_flags, + * sk_sock->flags, xpt_reserved, or xpt_nr_rqsts, we need to + * know about it; otherwise it's possible that both that cpu and + * this one could call svc_xprt_enqueue() without either + * svc_xprt_enqueue() recognizing that the conditions below + * are satisfied, and we could stall indefinitely: + */ + smp_rmb(); + xpt_flags = READ_ONCE(xprt->xpt_flags); + + if (xpt_flags & (BIT(XPT_CONN) | BIT(XPT_CLOSE))) return true; - if (xprt->xpt_flags & ((1<<XPT_DATA)|(1<<XPT_DEFERRED))) { + if (xpt_flags & (BIT(XPT_DATA) | BIT(XPT_DEFERRED))) { if (xprt->xpt_ops->xpo_has_wspace(xprt) && svc_xprt_slots_in_range(xprt)) return true; @@ -381,7 +395,7 @@ void svc_xprt_do_enqueue(struct svc_xprt *xprt) struct svc_rqst *rqstp = NULL; int cpu; - if (!svc_xprt_has_something_to_do(xprt)) + if (!svc_xprt_ready(xprt)) return; /* Mark transport as busy. It will remain in this state until @@ -475,7 +489,7 @@ void svc_reserve(struct svc_rqst *rqstp, int space) if (xprt && space < rqstp->rq_reserved) { atomic_sub((rqstp->rq_reserved - space), &xprt->xpt_reserved); rqstp->rq_reserved = space; - + smp_wmb(); /* See smp_rmb() in svc_xprt_ready() */ svc_xprt_enqueue(xprt); } } diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c index a6a060925e5d..43590a968b73 100644 --- a/net/sunrpc/svcsock.c +++ b/net/sunrpc/svcsock.c @@ -349,12 +349,16 @@ static ssize_t svc_recvfrom(struct svc_rqst *rqstp, struct kvec *iov, /* * Set socket snd and rcv buffer lengths */ -static void svc_sock_setbufsize(struct socket *sock, unsigned int snd, - unsigned int rcv) +static void svc_sock_setbufsize(struct svc_sock *svsk, unsigned int nreqs) { + unsigned int max_mesg = svsk->sk_xprt.xpt_server->sv_max_mesg; + struct socket *sock = svsk->sk_sock; + + nreqs = min(nreqs, INT_MAX / 2 / max_mesg); + lock_sock(sock->sk); - sock->sk->sk_sndbuf = snd * 2; - sock->sk->sk_rcvbuf = rcv * 2; + sock->sk->sk_sndbuf = nreqs * max_mesg * 2; + sock->sk->sk_rcvbuf = nreqs * max_mesg * 2; sock->sk->sk_write_space(sock->sk); release_sock(sock->sk); } @@ -516,9 +520,7 @@ static int svc_udp_recvfrom(struct svc_rqst *rqstp) * provides an upper bound on the number of threads * which will access the socket. */ - svc_sock_setbufsize(svsk->sk_sock, - (serv->sv_nrthreads+3) * serv->sv_max_mesg, - (serv->sv_nrthreads+3) * serv->sv_max_mesg); + svc_sock_setbufsize(svsk, serv->sv_nrthreads + 3); clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); skb = NULL; @@ -681,9 +683,7 @@ static void svc_udp_init(struct svc_sock *svsk, struct svc_serv *serv) * receive and respond to one request. * svc_udp_recvfrom will re-adjust if necessary */ - svc_sock_setbufsize(svsk->sk_sock, - 3 * svsk->sk_xprt.xpt_server->sv_max_mesg, - 3 * svsk->sk_xprt.xpt_server->sv_max_mesg); + svc_sock_setbufsize(svsk, 3); /* data might have come in before data_ready set up */ set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); diff --git a/net/sunrpc/xdr.c b/net/sunrpc/xdr.c index f302c6eb8779..aa8177ddcbda 100644 --- a/net/sunrpc/xdr.c +++ b/net/sunrpc/xdr.c @@ -16,6 +16,7 @@ #include <linux/sunrpc/xdr.h> #include <linux/sunrpc/msg_prot.h> #include <linux/bvec.h> +#include <trace/events/sunrpc.h> /* * XDR functions for basic NFS types @@ -162,6 +163,15 @@ xdr_free_bvec(struct xdr_buf *buf) buf->bvec = NULL; } +/** + * xdr_inline_pages - Prepare receive buffer for a large reply + * @xdr: xdr_buf into which reply will be placed + * @offset: expected offset where data payload will start, in bytes + * @pages: vector of struct page pointers + * @base: offset in first page where receive should start, in bytes + * @len: expected size of the upper layer data payload, in bytes + * + */ void xdr_inline_pages(struct xdr_buf *xdr, unsigned int offset, struct page **pages, unsigned int base, unsigned int len) @@ -179,6 +189,8 @@ xdr_inline_pages(struct xdr_buf *xdr, unsigned int offset, tail->iov_base = buf + offset; tail->iov_len = buflen - offset; + if ((xdr->page_len & 3) == 0) + tail->iov_len -= sizeof(__be32); xdr->buflen += len; } @@ -346,13 +358,15 @@ EXPORT_SYMBOL_GPL(_copy_from_pages); * 'len' bytes. The extra data is not lost, but is instead * moved into the inlined pages and/or the tail. */ -static void +static unsigned int xdr_shrink_bufhead(struct xdr_buf *buf, size_t len) { struct kvec *head, *tail; size_t copy, offs; unsigned int pglen = buf->page_len; + unsigned int result; + result = 0; tail = buf->tail; head = buf->head; @@ -366,6 +380,7 @@ xdr_shrink_bufhead(struct xdr_buf *buf, size_t len) copy = tail->iov_len - len; memmove((char *)tail->iov_base + len, tail->iov_base, copy); + result += copy; } /* Copy from the inlined pages into the tail */ copy = len; @@ -376,11 +391,13 @@ xdr_shrink_bufhead(struct xdr_buf *buf, size_t len) copy = 0; else if (copy > tail->iov_len - offs) copy = tail->iov_len - offs; - if (copy != 0) + if (copy != 0) { _copy_from_pages((char *)tail->iov_base + offs, buf->pages, buf->page_base + pglen + offs - len, copy); + result += copy; + } /* Do we also need to copy data from the head into the tail ? */ if (len > pglen) { offs = copy = len - pglen; @@ -390,6 +407,7 @@ xdr_shrink_bufhead(struct xdr_buf *buf, size_t len) (char *)head->iov_base + head->iov_len - offs, copy); + result += copy; } } /* Now handle pages */ @@ -405,12 +423,15 @@ xdr_shrink_bufhead(struct xdr_buf *buf, size_t len) _copy_to_pages(buf->pages, buf->page_base, (char *)head->iov_base + head->iov_len - len, copy); + result += copy; } head->iov_len -= len; buf->buflen -= len; /* Have we truncated the message? */ if (buf->len > buf->buflen) buf->len = buf->buflen; + + return result; } /** @@ -422,14 +443,16 @@ xdr_shrink_bufhead(struct xdr_buf *buf, size_t len) * 'len' bytes. The extra data is not lost, but is instead * moved into the tail. */ -static void +static unsigned int xdr_shrink_pagelen(struct xdr_buf *buf, size_t len) { struct kvec *tail; size_t copy; unsigned int pglen = buf->page_len; unsigned int tailbuf_len; + unsigned int result; + result = 0; tail = buf->tail; BUG_ON (len > pglen); @@ -447,18 +470,22 @@ xdr_shrink_pagelen(struct xdr_buf *buf, size_t len) if (tail->iov_len > len) { char *p = (char *)tail->iov_base + len; memmove(p, tail->iov_base, tail->iov_len - len); + result += tail->iov_len - len; } else copy = tail->iov_len; /* Copy from the inlined pages into the tail */ _copy_from_pages((char *)tail->iov_base, buf->pages, buf->page_base + pglen - len, copy); + result += copy; } buf->page_len -= len; buf->buflen -= len; /* Have we truncated the message? */ if (buf->len > buf->buflen) buf->len = buf->buflen; + + return result; } void @@ -483,6 +510,7 @@ EXPORT_SYMBOL_GPL(xdr_stream_pos); * @xdr: pointer to xdr_stream struct * @buf: pointer to XDR buffer in which to encode data * @p: current pointer inside XDR buffer + * @rqst: pointer to controlling rpc_rqst, for debugging * * Note: at the moment the RPC client only passes the length of our * scratch buffer in the xdr_buf's header kvec. Previously this @@ -491,7 +519,8 @@ EXPORT_SYMBOL_GPL(xdr_stream_pos); * of the buffer length, and takes care of adjusting the kvec * length for us. */ -void xdr_init_encode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) +void xdr_init_encode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p, + struct rpc_rqst *rqst) { struct kvec *iov = buf->head; int scratch_len = buf->buflen - buf->page_len - buf->tail[0].iov_len; @@ -513,6 +542,7 @@ void xdr_init_encode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) buf->len += len; iov->iov_len += len; } + xdr->rqst = rqst; } EXPORT_SYMBOL_GPL(xdr_init_encode); @@ -551,9 +581,9 @@ static __be32 *xdr_get_next_encode_buffer(struct xdr_stream *xdr, int frag1bytes, frag2bytes; if (nbytes > PAGE_SIZE) - return NULL; /* Bigger buffers require special handling */ + goto out_overflow; /* Bigger buffers require special handling */ if (xdr->buf->len + nbytes > xdr->buf->buflen) - return NULL; /* Sorry, we're totally out of space */ + goto out_overflow; /* Sorry, we're totally out of space */ frag1bytes = (xdr->end - xdr->p) << 2; frag2bytes = nbytes - frag1bytes; if (xdr->iov) @@ -582,6 +612,9 @@ static __be32 *xdr_get_next_encode_buffer(struct xdr_stream *xdr, xdr->buf->page_len += frag2bytes; xdr->buf->len += nbytes; return p; +out_overflow: + trace_rpc_xdr_overflow(xdr, nbytes); + return NULL; } /** @@ -819,8 +852,10 @@ static bool xdr_set_next_buffer(struct xdr_stream *xdr) * @xdr: pointer to xdr_stream struct * @buf: pointer to XDR buffer from which to decode data * @p: current pointer inside XDR buffer + * @rqst: pointer to controlling rpc_rqst, for debugging */ -void xdr_init_decode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) +void xdr_init_decode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p, + struct rpc_rqst *rqst) { xdr->buf = buf; xdr->scratch.iov_base = NULL; @@ -836,6 +871,7 @@ void xdr_init_decode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) xdr->nwords -= p - xdr->p; xdr->p = p; } + xdr->rqst = rqst; } EXPORT_SYMBOL_GPL(xdr_init_decode); @@ -854,7 +890,7 @@ void xdr_init_decode_pages(struct xdr_stream *xdr, struct xdr_buf *buf, buf->page_len = len; buf->buflen = len; buf->len = len; - xdr_init_decode(xdr, buf, NULL); + xdr_init_decode(xdr, buf, NULL, NULL); } EXPORT_SYMBOL_GPL(xdr_init_decode_pages); @@ -896,20 +932,23 @@ static __be32 *xdr_copy_to_scratch(struct xdr_stream *xdr, size_t nbytes) size_t cplen = (char *)xdr->end - (char *)xdr->p; if (nbytes > xdr->scratch.iov_len) - return NULL; + goto out_overflow; p = __xdr_inline_decode(xdr, cplen); if (p == NULL) return NULL; memcpy(cpdest, p, cplen); + if (!xdr_set_next_buffer(xdr)) + goto out_overflow; cpdest += cplen; nbytes -= cplen; - if (!xdr_set_next_buffer(xdr)) - return NULL; p = __xdr_inline_decode(xdr, nbytes); if (p == NULL) return NULL; memcpy(cpdest, p, nbytes); return xdr->scratch.iov_base; +out_overflow: + trace_rpc_xdr_overflow(xdr, nbytes); + return NULL; } /** @@ -926,14 +965,17 @@ __be32 * xdr_inline_decode(struct xdr_stream *xdr, size_t nbytes) { __be32 *p; - if (nbytes == 0) + if (unlikely(nbytes == 0)) return xdr->p; if (xdr->p == xdr->end && !xdr_set_next_buffer(xdr)) - return NULL; + goto out_overflow; p = __xdr_inline_decode(xdr, nbytes); if (p != NULL) return p; return xdr_copy_to_scratch(xdr, nbytes); +out_overflow: + trace_rpc_xdr_overflow(xdr, nbytes); + return NULL; } EXPORT_SYMBOL_GPL(xdr_inline_decode); @@ -943,13 +985,17 @@ static unsigned int xdr_align_pages(struct xdr_stream *xdr, unsigned int len) struct kvec *iov; unsigned int nwords = XDR_QUADLEN(len); unsigned int cur = xdr_stream_pos(xdr); + unsigned int copied, offset; if (xdr->nwords == 0) return 0; + /* Realign pages to current pointer position */ - iov = buf->head; + iov = buf->head; if (iov->iov_len > cur) { - xdr_shrink_bufhead(buf, iov->iov_len - cur); + offset = iov->iov_len - cur; + copied = xdr_shrink_bufhead(buf, offset); + trace_rpc_xdr_alignment(xdr, offset, copied); xdr->nwords = XDR_QUADLEN(buf->len - cur); } @@ -961,7 +1007,9 @@ static unsigned int xdr_align_pages(struct xdr_stream *xdr, unsigned int len) len = buf->page_len; else if (nwords < xdr->nwords) { /* Truncate page data and move it into the tail */ - xdr_shrink_pagelen(buf, buf->page_len - len); + offset = buf->page_len - len; + copied = xdr_shrink_pagelen(buf, offset); + trace_rpc_xdr_alignment(xdr, offset, copied); xdr->nwords = XDR_QUADLEN(buf->len - cur); } return len; @@ -1102,47 +1150,6 @@ xdr_buf_subsegment(struct xdr_buf *buf, struct xdr_buf *subbuf, } EXPORT_SYMBOL_GPL(xdr_buf_subsegment); -/** - * xdr_buf_trim - lop at most "len" bytes off the end of "buf" - * @buf: buf to be trimmed - * @len: number of bytes to reduce "buf" by - * - * Trim an xdr_buf by the given number of bytes by fixing up the lengths. Note - * that it's possible that we'll trim less than that amount if the xdr_buf is - * too small, or if (for instance) it's all in the head and the parser has - * already read too far into it. - */ -void xdr_buf_trim(struct xdr_buf *buf, unsigned int len) -{ - size_t cur; - unsigned int trim = len; - - if (buf->tail[0].iov_len) { - cur = min_t(size_t, buf->tail[0].iov_len, trim); - buf->tail[0].iov_len -= cur; - trim -= cur; - if (!trim) - goto fix_len; - } - - if (buf->page_len) { - cur = min_t(unsigned int, buf->page_len, trim); - buf->page_len -= cur; - trim -= cur; - if (!trim) - goto fix_len; - } - - if (buf->head[0].iov_len) { - cur = min_t(size_t, buf->head[0].iov_len, trim); - buf->head[0].iov_len -= cur; - trim -= cur; - } -fix_len: - buf->len -= (len - trim); -} -EXPORT_SYMBOL_GPL(xdr_buf_trim); - static void __read_bytes_from_xdr_buf(struct xdr_buf *subbuf, void *obj, unsigned int len) { unsigned int this_len; diff --git a/net/sunrpc/xprt.c b/net/sunrpc/xprt.c index f1ec2110efeb..d7117d241460 100644 --- a/net/sunrpc/xprt.c +++ b/net/sunrpc/xprt.c @@ -49,6 +49,7 @@ #include <linux/sunrpc/metrics.h> #include <linux/sunrpc/bc_xprt.h> #include <linux/rcupdate.h> +#include <linux/sched/mm.h> #include <trace/events/sunrpc.h> @@ -643,11 +644,13 @@ static void xprt_autoclose(struct work_struct *work) { struct rpc_xprt *xprt = container_of(work, struct rpc_xprt, task_cleanup); + unsigned int pflags = memalloc_nofs_save(); clear_bit(XPRT_CLOSE_WAIT, &xprt->state); xprt->ops->close(xprt); xprt_release_write(xprt, NULL); wake_up_bit(&xprt->state, XPRT_LOCKED); + memalloc_nofs_restore(pflags); } /** @@ -661,7 +664,7 @@ void xprt_disconnect_done(struct rpc_xprt *xprt) spin_lock_bh(&xprt->transport_lock); xprt_clear_connected(xprt); xprt_clear_write_space_locked(xprt); - xprt_wake_pending_tasks(xprt, -EAGAIN); + xprt_wake_pending_tasks(xprt, -ENOTCONN); spin_unlock_bh(&xprt->transport_lock); } EXPORT_SYMBOL_GPL(xprt_disconnect_done); @@ -1165,6 +1168,7 @@ xprt_request_enqueue_transmit(struct rpc_task *task) /* Note: req is added _before_ pos */ list_add_tail(&req->rq_xmit, &pos->rq_xmit); INIT_LIST_HEAD(&req->rq_xmit2); + trace_xprt_enq_xmit(task, 1); goto out; } } else if (RPC_IS_SWAPPER(task)) { @@ -1176,6 +1180,7 @@ xprt_request_enqueue_transmit(struct rpc_task *task) /* Note: req is added _before_ pos */ list_add_tail(&req->rq_xmit, &pos->rq_xmit); INIT_LIST_HEAD(&req->rq_xmit2); + trace_xprt_enq_xmit(task, 2); goto out; } } else if (!req->rq_seqno) { @@ -1184,11 +1189,13 @@ xprt_request_enqueue_transmit(struct rpc_task *task) continue; list_add_tail(&req->rq_xmit2, &pos->rq_xmit2); INIT_LIST_HEAD(&req->rq_xmit); + trace_xprt_enq_xmit(task, 3); goto out; } } list_add_tail(&req->rq_xmit, &xprt->xmit_queue); INIT_LIST_HEAD(&req->rq_xmit2); + trace_xprt_enq_xmit(task, 4); out: set_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate); spin_unlock(&xprt->queue_lock); @@ -1313,8 +1320,6 @@ xprt_request_transmit(struct rpc_rqst *req, struct rpc_task *snd_task) int is_retrans = RPC_WAS_SENT(task); int status; - dprintk("RPC: %5u xprt_transmit(%u)\n", task->tk_pid, req->rq_slen); - if (!req->rq_bytes_sent) { if (xprt_request_data_received(task)) { status = 0; @@ -1325,6 +1330,13 @@ xprt_request_transmit(struct rpc_rqst *req, struct rpc_task *snd_task) status = -EBADMSG; goto out_dequeue; } + if (task->tk_ops->rpc_call_prepare_transmit) { + task->tk_ops->rpc_call_prepare_transmit(task, + task->tk_calldata); + status = task->tk_status; + if (status < 0) + goto out_dequeue; + } } /* @@ -1336,9 +1348,9 @@ xprt_request_transmit(struct rpc_rqst *req, struct rpc_task *snd_task) connect_cookie = xprt->connect_cookie; status = xprt->ops->send_request(req); - trace_xprt_transmit(xprt, req->rq_xid, status); if (status != 0) { req->rq_ntrans--; + trace_xprt_transmit(req, status); return status; } @@ -1347,7 +1359,6 @@ xprt_request_transmit(struct rpc_rqst *req, struct rpc_task *snd_task) xprt_inject_disconnect(xprt); - dprintk("RPC: %5u xmit complete\n", task->tk_pid); task->tk_flags |= RPC_TASK_SENT; spin_lock_bh(&xprt->transport_lock); @@ -1360,6 +1371,7 @@ xprt_request_transmit(struct rpc_rqst *req, struct rpc_task *snd_task) req->rq_connect_cookie = connect_cookie; out_dequeue: + trace_xprt_transmit(req, status); xprt_request_dequeue_transmit(task); rpc_wake_up_queued_task_set_status(&xprt->sending, task, status); return status; @@ -1599,7 +1611,6 @@ xprt_request_init(struct rpc_task *task) req->rq_buffer = NULL; req->rq_xid = xprt_alloc_xid(xprt); xprt_init_connect_cookie(req, xprt); - req->rq_bytes_sent = 0; req->rq_snd_buf.len = 0; req->rq_snd_buf.buflen = 0; req->rq_rcv_buf.len = 0; @@ -1721,6 +1732,7 @@ void xprt_release(struct rpc_task *task) xprt->ops->buf_free(task); xprt_inject_disconnect(xprt); xdr_free_bvec(&req->rq_rcv_buf); + xdr_free_bvec(&req->rq_snd_buf); if (req->rq_cred != NULL) put_rpccred(req->rq_cred); task->tk_rqstp = NULL; @@ -1749,7 +1761,6 @@ xprt_init_bc_request(struct rpc_rqst *req, struct rpc_task *task) */ xbufp->len = xbufp->head[0].iov_len + xbufp->page_len + xbufp->tail[0].iov_len; - req->rq_bytes_sent = 0; } #endif diff --git a/net/sunrpc/xprtrdma/backchannel.c b/net/sunrpc/xprtrdma/backchannel.c index 0de9b3e63770..d79b18c1f4cd 100644 --- a/net/sunrpc/xprtrdma/backchannel.c +++ b/net/sunrpc/xprtrdma/backchannel.c @@ -123,7 +123,7 @@ static int rpcrdma_bc_marshal_reply(struct rpc_rqst *rqst) rpcrdma_set_xdrlen(&req->rl_hdrbuf, 0); xdr_init_encode(&req->rl_stream, &req->rl_hdrbuf, - req->rl_rdmabuf->rg_base); + req->rl_rdmabuf->rg_base, rqst); p = xdr_reserve_space(&req->rl_stream, 28); if (unlikely(!p)) @@ -267,7 +267,6 @@ void rpcrdma_bc_receive_call(struct rpcrdma_xprt *r_xprt, /* Prepare rqst */ rqst->rq_reply_bytes_recvd = 0; - rqst->rq_bytes_sent = 0; rqst->rq_xid = *p; rqst->rq_private_buf.len = size; diff --git a/net/sunrpc/xprtrdma/frwr_ops.c b/net/sunrpc/xprtrdma/frwr_ops.c index 6a561056b538..52cb6c1b0c2b 100644 --- a/net/sunrpc/xprtrdma/frwr_ops.c +++ b/net/sunrpc/xprtrdma/frwr_ops.c @@ -391,7 +391,7 @@ frwr_wc_localinv_wake(struct ib_cq *cq, struct ib_wc *wc) */ struct rpcrdma_mr_seg *frwr_map(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr_seg *seg, - int nsegs, bool writing, u32 xid, + int nsegs, bool writing, __be32 xid, struct rpcrdma_mr **out) { struct rpcrdma_ia *ia = &r_xprt->rx_ia; @@ -446,7 +446,7 @@ struct rpcrdma_mr_seg *frwr_map(struct rpcrdma_xprt *r_xprt, goto out_mapmr_err; ibmr->iova &= 0x00000000ffffffff; - ibmr->iova |= ((u64)cpu_to_be32(xid)) << 32; + ibmr->iova |= ((u64)be32_to_cpu(xid)) << 32; key = (u8)(ibmr->rkey & 0x000000FF); ib_update_fast_reg_key(ibmr, ++key); diff --git a/net/sunrpc/xprtrdma/rpc_rdma.c b/net/sunrpc/xprtrdma/rpc_rdma.c index d18614e02b4e..6c1fb270f127 100644 --- a/net/sunrpc/xprtrdma/rpc_rdma.c +++ b/net/sunrpc/xprtrdma/rpc_rdma.c @@ -164,6 +164,21 @@ static bool rpcrdma_results_inline(struct rpcrdma_xprt *r_xprt, return rqst->rq_rcv_buf.buflen <= ia->ri_max_inline_read; } +/* The client is required to provide a Reply chunk if the maximum + * size of the non-payload part of the RPC Reply is larger than + * the inline threshold. + */ +static bool +rpcrdma_nonpayload_inline(const struct rpcrdma_xprt *r_xprt, + const struct rpc_rqst *rqst) +{ + const struct xdr_buf *buf = &rqst->rq_rcv_buf; + const struct rpcrdma_ia *ia = &r_xprt->rx_ia; + + return buf->head[0].iov_len + buf->tail[0].iov_len < + ia->ri_max_inline_read; +} + /* Split @vec on page boundaries into SGEs. FMR registers pages, not * a byte range. Other modes coalesce these SGEs into a single MR * when they can. @@ -733,7 +748,7 @@ rpcrdma_marshal_req(struct rpcrdma_xprt *r_xprt, struct rpc_rqst *rqst) rpcrdma_set_xdrlen(&req->rl_hdrbuf, 0); xdr_init_encode(xdr, &req->rl_hdrbuf, - req->rl_rdmabuf->rg_base); + req->rl_rdmabuf->rg_base, rqst); /* Fixed header fields */ ret = -EMSGSIZE; @@ -762,7 +777,8 @@ rpcrdma_marshal_req(struct rpcrdma_xprt *r_xprt, struct rpc_rqst *rqst) */ if (rpcrdma_results_inline(r_xprt, rqst)) wtype = rpcrdma_noch; - else if (ddp_allowed && rqst->rq_rcv_buf.flags & XDRBUF_READ) + else if ((ddp_allowed && rqst->rq_rcv_buf.flags & XDRBUF_READ) && + rpcrdma_nonpayload_inline(r_xprt, rqst)) wtype = rpcrdma_writech; else wtype = rpcrdma_replych; @@ -1313,7 +1329,7 @@ void rpcrdma_reply_handler(struct rpcrdma_rep *rep) /* Fixed transport header fields */ xdr_init_decode(&rep->rr_stream, &rep->rr_hdrbuf, - rep->rr_hdrbuf.head[0].iov_base); + rep->rr_hdrbuf.head[0].iov_base, NULL); p = xdr_inline_decode(&rep->rr_stream, 4 * sizeof(*p)); if (unlikely(!p)) goto out_shortreply; diff --git a/net/sunrpc/xprtrdma/svc_rdma_backchannel.c b/net/sunrpc/xprtrdma/svc_rdma_backchannel.c index b908f2ca08fd..907464c2a9f0 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_backchannel.c +++ b/net/sunrpc/xprtrdma/svc_rdma_backchannel.c @@ -304,7 +304,6 @@ xprt_setup_rdma_bc(struct xprt_create *args) xprt->idle_timeout = RPCRDMA_IDLE_DISC_TO; xprt->prot = XPRT_TRANSPORT_BC_RDMA; - xprt->tsh_size = 0; xprt->ops = &xprt_rdma_bc_procs; memcpy(&xprt->addr, args->dstaddr, args->addrlen); diff --git a/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c index 828b149eaaef..65e2fb9aac65 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c +++ b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c @@ -272,11 +272,8 @@ bool svc_rdma_post_recvs(struct svcxprt_rdma *rdma) return false; ctxt->rc_temp = true; ret = __svc_rdma_post_recv(rdma, ctxt); - if (ret) { - pr_err("svcrdma: failure posting recv buffers: %d\n", - ret); + if (ret) return false; - } } return true; } @@ -314,17 +311,14 @@ static void svc_rdma_wc_receive(struct ib_cq *cq, struct ib_wc *wc) spin_lock(&rdma->sc_rq_dto_lock); list_add_tail(&ctxt->rc_list, &rdma->sc_rq_dto_q); - spin_unlock(&rdma->sc_rq_dto_lock); + /* Note the unlock pairs with the smp_rmb in svc_xprt_ready: */ set_bit(XPT_DATA, &rdma->sc_xprt.xpt_flags); + spin_unlock(&rdma->sc_rq_dto_lock); if (!test_bit(RDMAXPRT_CONN_PENDING, &rdma->sc_flags)) svc_xprt_enqueue(&rdma->sc_xprt); goto out; flushed: - if (wc->status != IB_WC_WR_FLUSH_ERR) - pr_err("svcrdma: Recv: %s (%u/0x%x)\n", - ib_wc_status_msg(wc->status), - wc->status, wc->vendor_err); post_err: svc_rdma_recv_ctxt_put(rdma, ctxt); set_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags); diff --git a/net/sunrpc/xprtrdma/svc_rdma_rw.c b/net/sunrpc/xprtrdma/svc_rdma_rw.c index dc1951759a8e..2121c9b4d275 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_rw.c +++ b/net/sunrpc/xprtrdma/svc_rdma_rw.c @@ -64,8 +64,7 @@ svc_rdma_get_rw_ctxt(struct svcxprt_rdma *rdma, unsigned int sges) spin_unlock(&rdma->sc_rw_ctxt_lock); } else { spin_unlock(&rdma->sc_rw_ctxt_lock); - ctxt = kmalloc(sizeof(*ctxt) + - SG_CHUNK_SIZE * sizeof(struct scatterlist), + ctxt = kmalloc(struct_size(ctxt, rw_first_sgl, SG_CHUNK_SIZE), GFP_KERNEL); if (!ctxt) goto out; @@ -213,13 +212,8 @@ static void svc_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc) atomic_add(cc->cc_sqecount, &rdma->sc_sq_avail); wake_up(&rdma->sc_send_wait); - if (unlikely(wc->status != IB_WC_SUCCESS)) { + if (unlikely(wc->status != IB_WC_SUCCESS)) set_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags); - if (wc->status != IB_WC_WR_FLUSH_ERR) - pr_err("svcrdma: write ctx: %s (%u/0x%x)\n", - ib_wc_status_msg(wc->status), - wc->status, wc->vendor_err); - } svc_rdma_write_info_free(info); } @@ -278,18 +272,15 @@ static void svc_rdma_wc_read_done(struct ib_cq *cq, struct ib_wc *wc) if (unlikely(wc->status != IB_WC_SUCCESS)) { set_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags); - if (wc->status != IB_WC_WR_FLUSH_ERR) - pr_err("svcrdma: read ctx: %s (%u/0x%x)\n", - ib_wc_status_msg(wc->status), - wc->status, wc->vendor_err); svc_rdma_recv_ctxt_put(rdma, info->ri_readctxt); } else { spin_lock(&rdma->sc_rq_dto_lock); list_add_tail(&info->ri_readctxt->rc_list, &rdma->sc_read_complete_q); + /* Note the unlock pairs with the smp_rmb in svc_xprt_ready: */ + set_bit(XPT_DATA, &rdma->sc_xprt.xpt_flags); spin_unlock(&rdma->sc_rq_dto_lock); - set_bit(XPT_DATA, &rdma->sc_xprt.xpt_flags); svc_xprt_enqueue(&rdma->sc_xprt); } diff --git a/net/sunrpc/xprtrdma/svc_rdma_sendto.c b/net/sunrpc/xprtrdma/svc_rdma_sendto.c index cf51b8f9b15f..6fdba72f89f4 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_sendto.c +++ b/net/sunrpc/xprtrdma/svc_rdma_sendto.c @@ -272,10 +272,6 @@ static void svc_rdma_wc_send(struct ib_cq *cq, struct ib_wc *wc) if (unlikely(wc->status != IB_WC_SUCCESS)) { set_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags); svc_xprt_enqueue(&rdma->sc_xprt); - if (wc->status != IB_WC_WR_FLUSH_ERR) - pr_err("svcrdma: Send: %s (%u/0x%x)\n", - ib_wc_status_msg(wc->status), - wc->status, wc->vendor_err); } svc_xprt_put(&rdma->sc_xprt); @@ -537,6 +533,99 @@ void svc_rdma_sync_reply_hdr(struct svcxprt_rdma *rdma, DMA_TO_DEVICE); } +/* If the xdr_buf has more elements than the device can + * transmit in a single RDMA Send, then the reply will + * have to be copied into a bounce buffer. + */ +static bool svc_rdma_pull_up_needed(struct svcxprt_rdma *rdma, + struct xdr_buf *xdr, + __be32 *wr_lst) +{ + int elements; + + /* xdr->head */ + elements = 1; + + /* xdr->pages */ + if (!wr_lst) { + unsigned int remaining; + unsigned long pageoff; + + pageoff = xdr->page_base & ~PAGE_MASK; + remaining = xdr->page_len; + while (remaining) { + ++elements; + remaining -= min_t(u32, PAGE_SIZE - pageoff, + remaining); + pageoff = 0; + } + } + + /* xdr->tail */ + if (xdr->tail[0].iov_len) + ++elements; + + /* assume 1 SGE is needed for the transport header */ + return elements >= rdma->sc_max_send_sges; +} + +/* The device is not capable of sending the reply directly. + * Assemble the elements of @xdr into the transport header + * buffer. + */ +static int svc_rdma_pull_up_reply_msg(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt, + struct xdr_buf *xdr, __be32 *wr_lst) +{ + unsigned char *dst, *tailbase; + unsigned int taillen; + + dst = ctxt->sc_xprt_buf; + dst += ctxt->sc_sges[0].length; + + memcpy(dst, xdr->head[0].iov_base, xdr->head[0].iov_len); + dst += xdr->head[0].iov_len; + + tailbase = xdr->tail[0].iov_base; + taillen = xdr->tail[0].iov_len; + if (wr_lst) { + u32 xdrpad; + + xdrpad = xdr_padsize(xdr->page_len); + if (taillen && xdrpad) { + tailbase += xdrpad; + taillen -= xdrpad; + } + } else { + unsigned int len, remaining; + unsigned long pageoff; + struct page **ppages; + + ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT); + pageoff = xdr->page_base & ~PAGE_MASK; + remaining = xdr->page_len; + while (remaining) { + len = min_t(u32, PAGE_SIZE - pageoff, remaining); + + memcpy(dst, page_address(*ppages), len); + remaining -= len; + dst += len; + pageoff = 0; + } + } + + if (taillen) + memcpy(dst, tailbase, taillen); + + ctxt->sc_sges[0].length += xdr->len; + ib_dma_sync_single_for_device(rdma->sc_pd->device, + ctxt->sc_sges[0].addr, + ctxt->sc_sges[0].length, + DMA_TO_DEVICE); + + return 0; +} + /* svc_rdma_map_reply_msg - Map the buffer holding RPC message * @rdma: controlling transport * @ctxt: send_ctxt for the Send WR @@ -559,8 +648,10 @@ int svc_rdma_map_reply_msg(struct svcxprt_rdma *rdma, u32 xdr_pad; int ret; - if (++ctxt->sc_cur_sge_no >= rdma->sc_max_send_sges) - return -EIO; + if (svc_rdma_pull_up_needed(rdma, xdr, wr_lst)) + return svc_rdma_pull_up_reply_msg(rdma, ctxt, xdr, wr_lst); + + ++ctxt->sc_cur_sge_no; ret = svc_rdma_dma_map_buf(rdma, ctxt, xdr->head[0].iov_base, xdr->head[0].iov_len); @@ -591,8 +682,7 @@ int svc_rdma_map_reply_msg(struct svcxprt_rdma *rdma, while (remaining) { len = min_t(u32, PAGE_SIZE - page_off, remaining); - if (++ctxt->sc_cur_sge_no >= rdma->sc_max_send_sges) - return -EIO; + ++ctxt->sc_cur_sge_no; ret = svc_rdma_dma_map_page(rdma, ctxt, *ppages++, page_off, len); if (ret < 0) @@ -606,8 +696,7 @@ int svc_rdma_map_reply_msg(struct svcxprt_rdma *rdma, len = xdr->tail[0].iov_len; tail: if (len) { - if (++ctxt->sc_cur_sge_no >= rdma->sc_max_send_sges) - return -EIO; + ++ctxt->sc_cur_sge_no; ret = svc_rdma_dma_map_buf(rdma, ctxt, base, len); if (ret < 0) return ret; diff --git a/net/sunrpc/xprtrdma/svc_rdma_transport.c b/net/sunrpc/xprtrdma/svc_rdma_transport.c index 924c17d46903..027a3b07d329 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_transport.c +++ b/net/sunrpc/xprtrdma/svc_rdma_transport.c @@ -390,8 +390,8 @@ static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) struct ib_qp_init_attr qp_attr; unsigned int ctxts, rq_depth; struct ib_device *dev; - struct sockaddr *sap; int ret = 0; + RPC_IFDEBUG(struct sockaddr *sap); listen_rdma = container_of(xprt, struct svcxprt_rdma, sc_xprt); clear_bit(XPT_CONN, &xprt->xpt_flags); @@ -419,12 +419,9 @@ static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) /* Transport header, head iovec, tail iovec */ newxprt->sc_max_send_sges = 3; /* Add one SGE per page list entry */ - newxprt->sc_max_send_sges += svcrdma_max_req_size / PAGE_SIZE; - if (newxprt->sc_max_send_sges > dev->attrs.max_send_sge) { - pr_err("svcrdma: too few Send SGEs available (%d needed)\n", - newxprt->sc_max_send_sges); - goto errout; - } + newxprt->sc_max_send_sges += (svcrdma_max_req_size / PAGE_SIZE) + 1; + if (newxprt->sc_max_send_sges > dev->attrs.max_send_sge) + newxprt->sc_max_send_sges = dev->attrs.max_send_sge; newxprt->sc_max_req_size = svcrdma_max_req_size; newxprt->sc_max_requests = svcrdma_max_requests; newxprt->sc_max_bc_requests = svcrdma_max_bc_requests; @@ -528,6 +525,7 @@ static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) if (ret) goto errout; +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) dprintk("svcrdma: new connection %p accepted:\n", newxprt); sap = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.src_addr; dprintk(" local address : %pIS:%u\n", sap, rpc_get_port(sap)); @@ -538,6 +536,7 @@ static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) dprintk(" rdma_rw_ctxs : %d\n", ctxts); dprintk(" max_requests : %d\n", newxprt->sc_max_requests); dprintk(" ord : %d\n", conn_param.initiator_depth); +#endif trace_svcrdma_xprt_accept(&newxprt->sc_xprt); return &newxprt->sc_xprt; @@ -591,11 +590,6 @@ static void __svc_rdma_free(struct work_struct *work) if (rdma->sc_qp && !IS_ERR(rdma->sc_qp)) ib_drain_qp(rdma->sc_qp); - /* We should only be called from kref_put */ - if (kref_read(&xprt->xpt_ref) != 0) - pr_err("svcrdma: sc_xprt still in use? (%d)\n", - kref_read(&xprt->xpt_ref)); - svc_rdma_flush_recv_queues(rdma); /* Final put of backchannel client transport */ diff --git a/net/sunrpc/xprtrdma/transport.c b/net/sunrpc/xprtrdma/transport.c index fbc171ebfe91..5d261353bd90 100644 --- a/net/sunrpc/xprtrdma/transport.c +++ b/net/sunrpc/xprtrdma/transport.c @@ -332,7 +332,6 @@ xprt_setup_rdma(struct xprt_create *args) xprt->idle_timeout = RPCRDMA_IDLE_DISC_TO; xprt->resvport = 0; /* privileged port not needed */ - xprt->tsh_size = 0; /* RPC-RDMA handles framing */ xprt->ops = &xprt_rdma_procs; /* @@ -738,7 +737,6 @@ xprt_rdma_send_request(struct rpc_rqst *rqst) goto drop_connection; rqst->rq_xmit_bytes_sent += rqst->rq_snd_buf.len; - rqst->rq_bytes_sent = 0; /* An RPC with no reply will throw off credit accounting, * so drop the connection to reset the credit grant. diff --git a/net/sunrpc/xprtrdma/verbs.c b/net/sunrpc/xprtrdma/verbs.c index 4994e75945b8..89a63391d4d4 100644 --- a/net/sunrpc/xprtrdma/verbs.c +++ b/net/sunrpc/xprtrdma/verbs.c @@ -527,7 +527,8 @@ rpcrdma_ep_create(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia, sendcq = ib_alloc_cq(ia->ri_device, NULL, ep->rep_attr.cap.max_send_wr + 1, - 1, IB_POLL_WORKQUEUE); + ia->ri_device->num_comp_vectors > 1 ? 1 : 0, + IB_POLL_WORKQUEUE); if (IS_ERR(sendcq)) { rc = PTR_ERR(sendcq); goto out1; @@ -1480,6 +1481,8 @@ rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, bool temp) if (ep->rep_receive_count > needed) goto out; needed -= ep->rep_receive_count; + if (!temp) + needed += RPCRDMA_MAX_RECV_BATCH; count = 0; wr = NULL; diff --git a/net/sunrpc/xprtrdma/xprt_rdma.h b/net/sunrpc/xprtrdma/xprt_rdma.h index 5a18472f2c9c..10f6593e1a6a 100644 --- a/net/sunrpc/xprtrdma/xprt_rdma.h +++ b/net/sunrpc/xprtrdma/xprt_rdma.h @@ -205,6 +205,16 @@ struct rpcrdma_rep { struct ib_recv_wr rr_recv_wr; }; +/* To reduce the rate at which a transport invokes ib_post_recv + * (and thus the hardware doorbell rate), xprtrdma posts Receive + * WRs in batches. + * + * Setting this to zero disables Receive post batching. + */ +enum { + RPCRDMA_MAX_RECV_BATCH = 7, +}; + /* struct rpcrdma_sendctx - DMA mapped SGEs to unmap after Send completes */ struct rpcrdma_req; @@ -577,7 +587,7 @@ void frwr_release_mr(struct rpcrdma_mr *mr); size_t frwr_maxpages(struct rpcrdma_xprt *r_xprt); struct rpcrdma_mr_seg *frwr_map(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr_seg *seg, - int nsegs, bool writing, u32 xid, + int nsegs, bool writing, __be32 xid, struct rpcrdma_mr **mr); int frwr_send(struct rpcrdma_ia *ia, struct rpcrdma_req *req); void frwr_reminv(struct rpcrdma_rep *rep, struct list_head *mrs); diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c index 7754aa3e434f..9359539907ba 100644 --- a/net/sunrpc/xprtsock.c +++ b/net/sunrpc/xprtsock.c @@ -50,6 +50,7 @@ #include <linux/bvec.h> #include <linux/highmem.h> #include <linux/uio.h> +#include <linux/sched/mm.h> #include <trace/events/sunrpc.h> @@ -404,8 +405,8 @@ xs_read_xdr_buf(struct socket *sock, struct msghdr *msg, int flags, size_t want, seek_init = seek, offset = 0; ssize_t ret; - if (seek < buf->head[0].iov_len) { - want = min_t(size_t, count, buf->head[0].iov_len); + want = min_t(size_t, count, buf->head[0].iov_len); + if (seek < want) { ret = xs_read_kvec(sock, msg, flags, &buf->head[0], want, seek); if (ret <= 0) goto sock_err; @@ -416,13 +417,13 @@ xs_read_xdr_buf(struct socket *sock, struct msghdr *msg, int flags, goto out; seek = 0; } else { - seek -= buf->head[0].iov_len; - offset += buf->head[0].iov_len; + seek -= want; + offset += want; } want = xs_alloc_sparse_pages(buf, min_t(size_t, count - offset, buf->page_len), - GFP_NOWAIT); + GFP_KERNEL); if (seek < want) { ret = xs_read_bvec(sock, msg, flags, buf->bvec, xdr_buf_pagecount(buf), @@ -442,8 +443,8 @@ xs_read_xdr_buf(struct socket *sock, struct msghdr *msg, int flags, offset += want; } - if (seek < buf->tail[0].iov_len) { - want = min_t(size_t, count - offset, buf->tail[0].iov_len); + want = min_t(size_t, count - offset, buf->tail[0].iov_len); + if (seek < want) { ret = xs_read_kvec(sock, msg, flags, &buf->tail[0], want, seek); if (ret <= 0) goto sock_err; @@ -452,8 +453,8 @@ xs_read_xdr_buf(struct socket *sock, struct msghdr *msg, int flags, goto out; if (ret != want) goto out; - } else - offset += buf->tail[0].iov_len; + } else if (offset < seek_init) + offset = seek_init; ret = -EMSGSIZE; out: *read = offset - seek_init; @@ -481,6 +482,14 @@ xs_read_stream_request_done(struct sock_xprt *transport) return transport->recv.fraghdr & cpu_to_be32(RPC_LAST_STREAM_FRAGMENT); } +static void +xs_read_stream_check_eor(struct sock_xprt *transport, + struct msghdr *msg) +{ + if (xs_read_stream_request_done(transport)) + msg->msg_flags |= MSG_EOR; +} + static ssize_t xs_read_stream_request(struct sock_xprt *transport, struct msghdr *msg, int flags, struct rpc_rqst *req) @@ -492,17 +501,21 @@ xs_read_stream_request(struct sock_xprt *transport, struct msghdr *msg, xs_read_header(transport, buf); want = transport->recv.len - transport->recv.offset; - ret = xs_read_xdr_buf(transport->sock, msg, flags, buf, - transport->recv.copied + want, transport->recv.copied, - &read); - transport->recv.offset += read; - transport->recv.copied += read; - if (transport->recv.offset == transport->recv.len) { - if (xs_read_stream_request_done(transport)) - msg->msg_flags |= MSG_EOR; - return read; + if (want != 0) { + ret = xs_read_xdr_buf(transport->sock, msg, flags, buf, + transport->recv.copied + want, + transport->recv.copied, + &read); + transport->recv.offset += read; + transport->recv.copied += read; } + if (transport->recv.offset == transport->recv.len) + xs_read_stream_check_eor(transport, msg); + + if (want == 0) + return 0; + switch (ret) { default: break; @@ -655,13 +668,35 @@ out_err: return ret != 0 ? ret : -ESHUTDOWN; } +static __poll_t xs_poll_socket(struct sock_xprt *transport) +{ + return transport->sock->ops->poll(transport->file, transport->sock, + NULL); +} + +static bool xs_poll_socket_readable(struct sock_xprt *transport) +{ + __poll_t events = xs_poll_socket(transport); + + return (events & (EPOLLIN | EPOLLRDNORM)) && !(events & EPOLLRDHUP); +} + +static void xs_poll_check_readable(struct sock_xprt *transport) +{ + + clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); + if (!xs_poll_socket_readable(transport)) + return; + if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) + queue_work(xprtiod_workqueue, &transport->recv_worker); +} + static void xs_stream_data_receive(struct sock_xprt *transport) { size_t read = 0; ssize_t ret = 0; mutex_lock(&transport->recv_mutex); - clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); if (transport->sock == NULL) goto out; for (;;) { @@ -671,6 +706,10 @@ static void xs_stream_data_receive(struct sock_xprt *transport) read += ret; cond_resched(); } + if (ret == -ESHUTDOWN) + kernel_sock_shutdown(transport->sock, SHUT_RDWR); + else + xs_poll_check_readable(transport); out: mutex_unlock(&transport->recv_mutex); trace_xs_stream_read_data(&transport->xprt, ret, read); @@ -680,7 +719,10 @@ static void xs_stream_data_receive_workfn(struct work_struct *work) { struct sock_xprt *transport = container_of(work, struct sock_xprt, recv_worker); + unsigned int pflags = memalloc_nofs_save(); + xs_stream_data_receive(transport); + memalloc_nofs_restore(pflags); } static void @@ -690,65 +732,65 @@ xs_stream_reset_connect(struct sock_xprt *transport) transport->recv.len = 0; transport->recv.copied = 0; transport->xmit.offset = 0; +} + +static void +xs_stream_start_connect(struct sock_xprt *transport) +{ transport->xprt.stat.connect_count++; transport->xprt.stat.connect_start = jiffies; } #define XS_SENDMSG_FLAGS (MSG_DONTWAIT | MSG_NOSIGNAL) -static int xs_send_kvec(struct socket *sock, struct sockaddr *addr, int addrlen, struct kvec *vec, unsigned int base, int more) +static int xs_sendmsg(struct socket *sock, struct msghdr *msg, size_t seek) { - struct msghdr msg = { - .msg_name = addr, - .msg_namelen = addrlen, - .msg_flags = XS_SENDMSG_FLAGS | (more ? MSG_MORE : 0), - }; - struct kvec iov = { - .iov_base = vec->iov_base + base, - .iov_len = vec->iov_len - base, - }; + if (seek) + iov_iter_advance(&msg->msg_iter, seek); + return sock_sendmsg(sock, msg); +} - if (iov.iov_len != 0) - return kernel_sendmsg(sock, &msg, &iov, 1, iov.iov_len); - return kernel_sendmsg(sock, &msg, NULL, 0, 0); +static int xs_send_kvec(struct socket *sock, struct msghdr *msg, struct kvec *vec, size_t seek) +{ + iov_iter_kvec(&msg->msg_iter, WRITE, vec, 1, vec->iov_len); + return xs_sendmsg(sock, msg, seek); } -static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned int base, int more, bool zerocopy, int *sent_p) +static int xs_send_pagedata(struct socket *sock, struct msghdr *msg, struct xdr_buf *xdr, size_t base) { - ssize_t (*do_sendpage)(struct socket *sock, struct page *page, - int offset, size_t size, int flags); - struct page **ppage; - unsigned int remainder; int err; - remainder = xdr->page_len - base; - base += xdr->page_base; - ppage = xdr->pages + (base >> PAGE_SHIFT); - base &= ~PAGE_MASK; - do_sendpage = sock->ops->sendpage; - if (!zerocopy) - do_sendpage = sock_no_sendpage; - for(;;) { - unsigned int len = min_t(unsigned int, PAGE_SIZE - base, remainder); - int flags = XS_SENDMSG_FLAGS; + err = xdr_alloc_bvec(xdr, GFP_KERNEL); + if (err < 0) + return err; - remainder -= len; - if (more) - flags |= MSG_MORE; - if (remainder != 0) - flags |= MSG_SENDPAGE_NOTLAST | MSG_MORE; - err = do_sendpage(sock, *ppage, base, len, flags); - if (remainder == 0 || err != len) - break; - *sent_p += err; - ppage++; - base = 0; - } - if (err > 0) { - *sent_p += err; - err = 0; - } - return err; + iov_iter_bvec(&msg->msg_iter, WRITE, xdr->bvec, + xdr_buf_pagecount(xdr), + xdr->page_len + xdr->page_base); + return xs_sendmsg(sock, msg, base + xdr->page_base); +} + +#define xs_record_marker_len() sizeof(rpc_fraghdr) + +/* Common case: + * - stream transport + * - sending from byte 0 of the message + * - the message is wholly contained in @xdr's head iovec + */ +static int xs_send_rm_and_kvec(struct socket *sock, struct msghdr *msg, + rpc_fraghdr marker, struct kvec *vec, size_t base) +{ + struct kvec iov[2] = { + [0] = { + .iov_base = &marker, + .iov_len = sizeof(marker) + }, + [1] = *vec, + }; + size_t len = iov[0].iov_len + iov[1].iov_len; + + iov_iter_kvec(&msg->msg_iter, WRITE, iov, 2, len); + return xs_sendmsg(sock, msg, base); } /** @@ -758,49 +800,60 @@ static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned i * @addrlen: UDP only -- length of destination address * @xdr: buffer containing this request * @base: starting position in the buffer - * @zerocopy: true if it is safe to use sendpage() + * @rm: stream record marker field * @sent_p: return the total number of bytes successfully queued for sending * */ -static int xs_sendpages(struct socket *sock, struct sockaddr *addr, int addrlen, struct xdr_buf *xdr, unsigned int base, bool zerocopy, int *sent_p) +static int xs_sendpages(struct socket *sock, struct sockaddr *addr, int addrlen, struct xdr_buf *xdr, unsigned int base, rpc_fraghdr rm, int *sent_p) { - unsigned int remainder = xdr->len - base; + struct msghdr msg = { + .msg_name = addr, + .msg_namelen = addrlen, + .msg_flags = XS_SENDMSG_FLAGS | MSG_MORE, + }; + unsigned int rmsize = rm ? sizeof(rm) : 0; + unsigned int remainder = rmsize + xdr->len - base; + unsigned int want; int err = 0; - int sent = 0; if (unlikely(!sock)) return -ENOTSOCK; - if (base != 0) { - addr = NULL; - addrlen = 0; - } - - if (base < xdr->head[0].iov_len || addr != NULL) { - unsigned int len = xdr->head[0].iov_len - base; + want = xdr->head[0].iov_len + rmsize; + if (base < want) { + unsigned int len = want - base; remainder -= len; - err = xs_send_kvec(sock, addr, addrlen, &xdr->head[0], base, remainder != 0); + if (remainder == 0) + msg.msg_flags &= ~MSG_MORE; + if (rmsize) + err = xs_send_rm_and_kvec(sock, &msg, rm, + &xdr->head[0], base); + else + err = xs_send_kvec(sock, &msg, &xdr->head[0], base); if (remainder == 0 || err != len) goto out; *sent_p += err; base = 0; } else - base -= xdr->head[0].iov_len; + base -= want; if (base < xdr->page_len) { unsigned int len = xdr->page_len - base; remainder -= len; - err = xs_send_pagedata(sock, xdr, base, remainder != 0, zerocopy, &sent); - *sent_p += sent; - if (remainder == 0 || sent != len) + if (remainder == 0) + msg.msg_flags &= ~MSG_MORE; + err = xs_send_pagedata(sock, &msg, xdr, base); + if (remainder == 0 || err != len) goto out; + *sent_p += err; base = 0; } else base -= xdr->page_len; if (base >= xdr->tail[0].iov_len) return 0; - err = xs_send_kvec(sock, NULL, 0, &xdr->tail[0], base, 0); + msg.msg_flags &= ~MSG_MORE; + err = xs_send_kvec(sock, &msg, &xdr->tail[0], base); out: if (err > 0) { *sent_p += err; @@ -856,7 +909,7 @@ static int xs_nospace(struct rpc_rqst *req) static void xs_stream_prepare_request(struct rpc_rqst *req) { - req->rq_task->tk_status = xdr_alloc_bvec(&req->rq_rcv_buf, GFP_NOIO); + req->rq_task->tk_status = xdr_alloc_bvec(&req->rq_rcv_buf, GFP_KERNEL); } /* @@ -870,13 +923,14 @@ xs_send_request_was_aborted(struct sock_xprt *transport, struct rpc_rqst *req) } /* - * Construct a stream transport record marker in @buf. + * Return the stream record marker field for a record of length < 2^31-1 */ -static inline void xs_encode_stream_record_marker(struct xdr_buf *buf) +static rpc_fraghdr +xs_stream_record_marker(struct xdr_buf *xdr) { - u32 reclen = buf->len - sizeof(rpc_fraghdr); - rpc_fraghdr *base = buf->head[0].iov_base; - *base = cpu_to_be32(RPC_LAST_STREAM_FRAGMENT | reclen); + if (!xdr->len) + return 0; + return cpu_to_be32(RPC_LAST_STREAM_FRAGMENT | (u32)xdr->len); } /** @@ -905,15 +959,14 @@ static int xs_local_send_request(struct rpc_rqst *req) return -ENOTCONN; } - xs_encode_stream_record_marker(&req->rq_snd_buf); - xs_pktdump("packet data:", req->rq_svec->iov_base, req->rq_svec->iov_len); req->rq_xtime = ktime_get(); status = xs_sendpages(transport->sock, NULL, 0, xdr, transport->xmit.offset, - true, &sent); + xs_stream_record_marker(xdr), + &sent); dprintk("RPC: %s(%u) = %d\n", __func__, xdr->len - transport->xmit.offset, status); @@ -925,7 +978,6 @@ static int xs_local_send_request(struct rpc_rqst *req) req->rq_bytes_sent = transport->xmit.offset; if (likely(req->rq_bytes_sent >= req->rq_slen)) { req->rq_xmit_bytes_sent += transport->xmit.offset; - req->rq_bytes_sent = 0; transport->xmit.offset = 0; return 0; } @@ -981,7 +1033,7 @@ static int xs_udp_send_request(struct rpc_rqst *req) req->rq_xtime = ktime_get(); status = xs_sendpages(transport->sock, xs_addr(xprt), xprt->addrlen, - xdr, 0, true, &sent); + xdr, 0, 0, &sent); dprintk("RPC: xs_udp_send_request(%u) = %d\n", xdr->len, status); @@ -1045,7 +1097,6 @@ static int xs_tcp_send_request(struct rpc_rqst *req) struct rpc_xprt *xprt = req->rq_xprt; struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); struct xdr_buf *xdr = &req->rq_snd_buf; - bool zerocopy = true; bool vm_wait = false; int status; int sent; @@ -1057,17 +1108,9 @@ static int xs_tcp_send_request(struct rpc_rqst *req) return -ENOTCONN; } - xs_encode_stream_record_marker(&req->rq_snd_buf); - xs_pktdump("packet data:", req->rq_svec->iov_base, req->rq_svec->iov_len); - /* Don't use zero copy if this is a resend. If the RPC call - * completes while the socket holds a reference to the pages, - * then we may end up resending corrupted data. - */ - if (req->rq_task->tk_flags & RPC_TASK_SENT) - zerocopy = false; if (test_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state)) xs_tcp_set_socket_timeouts(xprt, transport->sock); @@ -1080,7 +1123,8 @@ static int xs_tcp_send_request(struct rpc_rqst *req) sent = 0; status = xs_sendpages(transport->sock, NULL, 0, xdr, transport->xmit.offset, - zerocopy, &sent); + xs_stream_record_marker(xdr), + &sent); dprintk("RPC: xs_tcp_send_request(%u) = %d\n", xdr->len - transport->xmit.offset, status); @@ -1091,7 +1135,6 @@ static int xs_tcp_send_request(struct rpc_rqst *req) req->rq_bytes_sent = transport->xmit.offset; if (likely(req->rq_bytes_sent >= req->rq_slen)) { req->rq_xmit_bytes_sent += transport->xmit.offset; - req->rq_bytes_sent = 0; transport->xmit.offset = 0; return 0; } @@ -1211,6 +1254,7 @@ static void xs_reset_transport(struct sock_xprt *transport) struct socket *sock = transport->sock; struct sock *sk = transport->inet; struct rpc_xprt *xprt = &transport->xprt; + struct file *filp = transport->file; if (sk == NULL) return; @@ -1224,6 +1268,7 @@ static void xs_reset_transport(struct sock_xprt *transport) write_lock_bh(&sk->sk_callback_lock); transport->inet = NULL; transport->sock = NULL; + transport->file = NULL; sk->sk_user_data = NULL; @@ -1231,10 +1276,12 @@ static void xs_reset_transport(struct sock_xprt *transport) xprt_clear_connected(xprt); write_unlock_bh(&sk->sk_callback_lock); xs_sock_reset_connection_flags(xprt); + /* Reset stream record info */ + xs_stream_reset_connect(transport); mutex_unlock(&transport->recv_mutex); trace_rpc_socket_close(xprt, sock); - sock_release(sock); + fput(filp); xprt_disconnect_done(xprt); } @@ -1358,7 +1405,6 @@ static void xs_udp_data_receive(struct sock_xprt *transport) int err; mutex_lock(&transport->recv_mutex); - clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); sk = transport->inet; if (sk == NULL) goto out; @@ -1370,6 +1416,7 @@ static void xs_udp_data_receive(struct sock_xprt *transport) consume_skb(skb); cond_resched(); } + xs_poll_check_readable(transport); out: mutex_unlock(&transport->recv_mutex); } @@ -1378,7 +1425,10 @@ static void xs_udp_data_receive_workfn(struct work_struct *work) { struct sock_xprt *transport = container_of(work, struct sock_xprt, recv_worker); + unsigned int pflags = memalloc_nofs_save(); + xs_udp_data_receive(transport); + memalloc_nofs_restore(pflags); } /** @@ -1826,6 +1876,7 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt, struct sock_xprt *transport, int family, int type, int protocol, bool reuseport) { + struct file *filp; struct socket *sock; int err; @@ -1846,6 +1897,11 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt, goto out; } + filp = sock_alloc_file(sock, O_NONBLOCK, NULL); + if (IS_ERR(filp)) + return ERR_CAST(filp); + transport->file = filp; + return sock; out: return ERR_PTR(err); @@ -1869,7 +1925,6 @@ static int xs_local_finish_connecting(struct rpc_xprt *xprt, sk->sk_write_space = xs_udp_write_space; sock_set_flag(sk, SOCK_FASYNC); sk->sk_error_report = xs_error_report; - sk->sk_allocation = GFP_NOIO; xprt_clear_connected(xprt); @@ -1880,7 +1935,7 @@ static int xs_local_finish_connecting(struct rpc_xprt *xprt, write_unlock_bh(&sk->sk_callback_lock); } - xs_stream_reset_connect(transport); + xs_stream_start_connect(transport); return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, 0); } @@ -1892,6 +1947,7 @@ static int xs_local_finish_connecting(struct rpc_xprt *xprt, static int xs_local_setup_socket(struct sock_xprt *transport) { struct rpc_xprt *xprt = &transport->xprt; + struct file *filp; struct socket *sock; int status = -EIO; @@ -1904,6 +1960,13 @@ static int xs_local_setup_socket(struct sock_xprt *transport) } xs_reclassify_socket(AF_LOCAL, sock); + filp = sock_alloc_file(sock, O_NONBLOCK, NULL); + if (IS_ERR(filp)) { + status = PTR_ERR(filp); + goto out; + } + transport->file = filp; + dprintk("RPC: worker connecting xprt %p via AF_LOCAL to %s\n", xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); @@ -2057,7 +2120,6 @@ static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) sk->sk_data_ready = xs_data_ready; sk->sk_write_space = xs_udp_write_space; sock_set_flag(sk, SOCK_FASYNC); - sk->sk_allocation = GFP_NOIO; xprt_set_connected(xprt); @@ -2220,7 +2282,6 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) sk->sk_write_space = xs_tcp_write_space; sock_set_flag(sk, SOCK_FASYNC); sk->sk_error_report = xs_error_report; - sk->sk_allocation = GFP_NOIO; /* socket options */ sock_reset_flag(sk, SOCK_LINGER); @@ -2240,8 +2301,7 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) xs_set_memalloc(xprt); - /* Reset TCP record info */ - xs_stream_reset_connect(transport); + xs_stream_start_connect(transport); /* Tell the socket layer to start connecting... */ set_bit(XPRT_SOCK_CONNECTING, &transport->sock_state); @@ -2534,26 +2594,35 @@ static int bc_sendto(struct rpc_rqst *req) { int len; struct xdr_buf *xbufp = &req->rq_snd_buf; - struct rpc_xprt *xprt = req->rq_xprt; struct sock_xprt *transport = - container_of(xprt, struct sock_xprt, xprt); - struct socket *sock = transport->sock; + container_of(req->rq_xprt, struct sock_xprt, xprt); unsigned long headoff; unsigned long tailoff; + struct page *tailpage; + struct msghdr msg = { + .msg_flags = MSG_MORE + }; + rpc_fraghdr marker = cpu_to_be32(RPC_LAST_STREAM_FRAGMENT | + (u32)xbufp->len); + struct kvec iov = { + .iov_base = &marker, + .iov_len = sizeof(marker), + }; - xs_encode_stream_record_marker(xbufp); + len = kernel_sendmsg(transport->sock, &msg, &iov, 1, iov.iov_len); + if (len != iov.iov_len) + return -EAGAIN; + tailpage = NULL; + if (xbufp->tail[0].iov_len) + tailpage = virt_to_page(xbufp->tail[0].iov_base); tailoff = (unsigned long)xbufp->tail[0].iov_base & ~PAGE_MASK; headoff = (unsigned long)xbufp->head[0].iov_base & ~PAGE_MASK; - len = svc_send_common(sock, xbufp, + len = svc_send_common(transport->sock, xbufp, virt_to_page(xbufp->head[0].iov_base), headoff, - xbufp->tail[0].iov_base, tailoff); - - if (len != xbufp->len) { - printk(KERN_NOTICE "Error sending entire callback!\n"); - len = -EAGAIN; - } - + tailpage, tailoff); + if (len != xbufp->len) + return -EAGAIN; return len; } @@ -2793,7 +2862,6 @@ static struct rpc_xprt *xs_setup_local(struct xprt_create *args) transport = container_of(xprt, struct sock_xprt, xprt); xprt->prot = 0; - xprt->tsh_size = sizeof(rpc_fraghdr) / sizeof(u32); xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; xprt->bind_timeout = XS_BIND_TO; @@ -2862,7 +2930,6 @@ static struct rpc_xprt *xs_setup_udp(struct xprt_create *args) transport = container_of(xprt, struct sock_xprt, xprt); xprt->prot = IPPROTO_UDP; - xprt->tsh_size = 0; /* XXX: header size can vary due to auth type, IPv6, etc. */ xprt->max_payload = (1U << 16) - (MAX_HEADER << 3); @@ -2942,7 +3009,6 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args) transport = container_of(xprt, struct sock_xprt, xprt); xprt->prot = IPPROTO_TCP; - xprt->tsh_size = sizeof(rpc_fraghdr) / sizeof(u32); xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; xprt->bind_timeout = XS_BIND_TO; @@ -3015,7 +3081,6 @@ static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args) transport = container_of(xprt, struct sock_xprt, xprt); xprt->prot = IPPROTO_TCP; - xprt->tsh_size = sizeof(rpc_fraghdr) / sizeof(u32); xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; xprt->timeout = &xs_tcp_default_timeout; diff --git a/net/switchdev/switchdev.c b/net/switchdev/switchdev.c index 5df9d1138ac9..90ba4a1f0a6d 100644 --- a/net/switchdev/switchdev.c +++ b/net/switchdev/switchdev.c @@ -23,78 +23,6 @@ #include <linux/rtnetlink.h> #include <net/switchdev.h> -/** - * switchdev_trans_item_enqueue - Enqueue data item to transaction queue - * - * @trans: transaction - * @data: pointer to data being queued - * @destructor: data destructor - * @tritem: transaction item being queued - * - * Enqeueue data item to transaction queue. tritem is typically placed in - * cointainter pointed at by data pointer. Destructor is called on - * transaction abort and after successful commit phase in case - * the caller did not dequeue the item before. - */ -void switchdev_trans_item_enqueue(struct switchdev_trans *trans, - void *data, void (*destructor)(void const *), - struct switchdev_trans_item *tritem) -{ - tritem->data = data; - tritem->destructor = destructor; - list_add_tail(&tritem->list, &trans->item_list); -} -EXPORT_SYMBOL_GPL(switchdev_trans_item_enqueue); - -static struct switchdev_trans_item * -__switchdev_trans_item_dequeue(struct switchdev_trans *trans) -{ - struct switchdev_trans_item *tritem; - - if (list_empty(&trans->item_list)) - return NULL; - tritem = list_first_entry(&trans->item_list, - struct switchdev_trans_item, list); - list_del(&tritem->list); - return tritem; -} - -/** - * switchdev_trans_item_dequeue - Dequeue data item from transaction queue - * - * @trans: transaction - */ -void *switchdev_trans_item_dequeue(struct switchdev_trans *trans) -{ - struct switchdev_trans_item *tritem; - - tritem = __switchdev_trans_item_dequeue(trans); - BUG_ON(!tritem); - return tritem->data; -} -EXPORT_SYMBOL_GPL(switchdev_trans_item_dequeue); - -static void switchdev_trans_init(struct switchdev_trans *trans) -{ - INIT_LIST_HEAD(&trans->item_list); -} - -static void switchdev_trans_items_destroy(struct switchdev_trans *trans) -{ - struct switchdev_trans_item *tritem; - - while ((tritem = __switchdev_trans_item_dequeue(trans))) - tritem->destructor(tritem->data); -} - -static void switchdev_trans_items_warn_destroy(struct net_device *dev, - struct switchdev_trans *trans) -{ - WARN(!list_empty(&trans->item_list), "%s: transaction item queue is not empty.\n", - dev->name); - switchdev_trans_items_destroy(trans); -} - static LIST_HEAD(deferred); static DEFINE_SPINLOCK(deferred_lock); @@ -174,81 +102,32 @@ static int switchdev_deferred_enqueue(struct net_device *dev, return 0; } -/** - * switchdev_port_attr_get - Get port attribute - * - * @dev: port device - * @attr: attribute to get - */ -int switchdev_port_attr_get(struct net_device *dev, struct switchdev_attr *attr) +static int switchdev_port_attr_notify(enum switchdev_notifier_type nt, + struct net_device *dev, + const struct switchdev_attr *attr, + struct switchdev_trans *trans) { - const struct switchdev_ops *ops = dev->switchdev_ops; - struct net_device *lower_dev; - struct list_head *iter; - struct switchdev_attr first = { - .id = SWITCHDEV_ATTR_ID_UNDEFINED - }; - int err = -EOPNOTSUPP; + int err; + int rc; - if (ops && ops->switchdev_port_attr_get) - return ops->switchdev_port_attr_get(dev, attr); + struct switchdev_notifier_port_attr_info attr_info = { + .attr = attr, + .trans = trans, + .handled = false, + }; - if (attr->flags & SWITCHDEV_F_NO_RECURSE) + rc = call_switchdev_blocking_notifiers(nt, dev, + &attr_info.info, NULL); + err = notifier_to_errno(rc); + if (err) { + WARN_ON(!attr_info.handled); return err; - - /* Switch device port(s) may be stacked under - * bond/team/vlan dev, so recurse down to get attr on - * each port. Return -ENODATA if attr values don't - * compare across ports. - */ - - netdev_for_each_lower_dev(dev, lower_dev, iter) { - err = switchdev_port_attr_get(lower_dev, attr); - if (err) - break; - if (first.id == SWITCHDEV_ATTR_ID_UNDEFINED) - first = *attr; - else if (memcmp(&first, attr, sizeof(*attr))) - return -ENODATA; - } - - return err; -} -EXPORT_SYMBOL_GPL(switchdev_port_attr_get); - -static int __switchdev_port_attr_set(struct net_device *dev, - const struct switchdev_attr *attr, - struct switchdev_trans *trans) -{ - const struct switchdev_ops *ops = dev->switchdev_ops; - struct net_device *lower_dev; - struct list_head *iter; - int err = -EOPNOTSUPP; - - if (ops && ops->switchdev_port_attr_set) { - err = ops->switchdev_port_attr_set(dev, attr, trans); - goto done; } - if (attr->flags & SWITCHDEV_F_NO_RECURSE) - goto done; - - /* Switch device port(s) may be stacked under - * bond/team/vlan dev, so recurse down to set attr on - * each port. - */ - - netdev_for_each_lower_dev(dev, lower_dev, iter) { - err = __switchdev_port_attr_set(lower_dev, attr, trans); - if (err) - break; - } - -done: - if (err == -EOPNOTSUPP && attr->flags & SWITCHDEV_F_SKIP_EOPNOTSUPP) - err = 0; + if (!attr_info.handled) + return -EOPNOTSUPP; - return err; + return 0; } static int switchdev_port_attr_set_now(struct net_device *dev, @@ -257,8 +136,6 @@ static int switchdev_port_attr_set_now(struct net_device *dev, struct switchdev_trans trans; int err; - switchdev_trans_init(&trans); - /* Phase I: prepare for attr set. Driver/device should fail * here if there are going to be issues in the commit phase, * such as lack of resources or support. The driver/device @@ -267,18 +144,10 @@ static int switchdev_port_attr_set_now(struct net_device *dev, */ trans.ph_prepare = true; - err = __switchdev_port_attr_set(dev, attr, &trans); - if (err) { - /* Prepare phase failed: abort the transaction. Any - * resources reserved in the prepare phase are - * released. - */ - - if (err != -EOPNOTSUPP) - switchdev_trans_items_destroy(&trans); - + err = switchdev_port_attr_notify(SWITCHDEV_PORT_ATTR_SET, dev, attr, + &trans); + if (err) return err; - } /* Phase II: commit attr set. This cannot fail as a fault * of driver/device. If it does, it's a bug in the driver/device @@ -286,10 +155,10 @@ static int switchdev_port_attr_set_now(struct net_device *dev, */ trans.ph_prepare = false; - err = __switchdev_port_attr_set(dev, attr, &trans); + err = switchdev_port_attr_notify(SWITCHDEV_PORT_ATTR_SET, dev, attr, + &trans); WARN(err, "%s: Commit of attribute (id=%d) failed.\n", dev->name, attr->id); - switchdev_trans_items_warn_destroy(dev, &trans); return err; } @@ -388,8 +257,6 @@ static int switchdev_port_obj_add_now(struct net_device *dev, ASSERT_RTNL(); - switchdev_trans_init(&trans); - /* Phase I: prepare for obj add. Driver/device should fail * here if there are going to be issues in the commit phase, * such as lack of resources or support. The driver/device @@ -400,17 +267,8 @@ static int switchdev_port_obj_add_now(struct net_device *dev, trans.ph_prepare = true; err = switchdev_port_obj_notify(SWITCHDEV_PORT_OBJ_ADD, dev, obj, &trans, extack); - if (err) { - /* Prepare phase failed: abort the transaction. Any - * resources reserved in the prepare phase are - * released. - */ - - if (err != -EOPNOTSUPP) - switchdev_trans_items_destroy(&trans); - + if (err) return err; - } /* Phase II: commit obj add. This cannot fail as a fault * of driver/device. If it does, it's a bug in the driver/device @@ -421,7 +279,6 @@ static int switchdev_port_obj_add_now(struct net_device *dev, err = switchdev_port_obj_notify(SWITCHDEV_PORT_OBJ_ADD, dev, obj, &trans, extack); WARN(err, "%s: Commit of object (id=%d) failed.\n", dev->name, obj->id); - switchdev_trans_items_warn_destroy(dev, &trans); return err; } @@ -556,10 +413,11 @@ EXPORT_SYMBOL_GPL(unregister_switchdev_notifier); * Call all network notifier blocks. */ int call_switchdev_notifiers(unsigned long val, struct net_device *dev, - struct switchdev_notifier_info *info) + struct switchdev_notifier_info *info, + struct netlink_ext_ack *extack) { info->dev = dev; - info->extack = NULL; + info->extack = extack; return atomic_notifier_call_chain(&switchdev_notif_chain, val, info); } EXPORT_SYMBOL_GPL(call_switchdev_notifiers); @@ -591,26 +449,6 @@ int call_switchdev_blocking_notifiers(unsigned long val, struct net_device *dev, } EXPORT_SYMBOL_GPL(call_switchdev_blocking_notifiers); -bool switchdev_port_same_parent_id(struct net_device *a, - struct net_device *b) -{ - struct switchdev_attr a_attr = { - .orig_dev = a, - .id = SWITCHDEV_ATTR_ID_PORT_PARENT_ID, - }; - struct switchdev_attr b_attr = { - .orig_dev = b, - .id = SWITCHDEV_ATTR_ID_PORT_PARENT_ID, - }; - - if (switchdev_port_attr_get(a, &a_attr) || - switchdev_port_attr_get(b, &b_attr)) - return false; - - return netdev_phys_item_id_same(&a_attr.u.ppid, &b_attr.u.ppid); -} -EXPORT_SYMBOL_GPL(switchdev_port_same_parent_id); - static int __switchdev_handle_port_obj_add(struct net_device *dev, struct switchdev_notifier_port_obj_info *port_obj_info, bool (*check_cb)(const struct net_device *dev), @@ -716,3 +554,54 @@ int switchdev_handle_port_obj_del(struct net_device *dev, return err; } EXPORT_SYMBOL_GPL(switchdev_handle_port_obj_del); + +static int __switchdev_handle_port_attr_set(struct net_device *dev, + struct switchdev_notifier_port_attr_info *port_attr_info, + bool (*check_cb)(const struct net_device *dev), + int (*set_cb)(struct net_device *dev, + const struct switchdev_attr *attr, + struct switchdev_trans *trans)) +{ + struct net_device *lower_dev; + struct list_head *iter; + int err = -EOPNOTSUPP; + + if (check_cb(dev)) { + port_attr_info->handled = true; + return set_cb(dev, port_attr_info->attr, + port_attr_info->trans); + } + + /* Switch ports might be stacked under e.g. a LAG. Ignore the + * unsupported devices, another driver might be able to handle them. But + * propagate to the callers any hard errors. + * + * If the driver does its own bookkeeping of stacked ports, it's not + * necessary to go through this helper. + */ + netdev_for_each_lower_dev(dev, lower_dev, iter) { + err = __switchdev_handle_port_attr_set(lower_dev, port_attr_info, + check_cb, set_cb); + if (err && err != -EOPNOTSUPP) + return err; + } + + return err; +} + +int switchdev_handle_port_attr_set(struct net_device *dev, + struct switchdev_notifier_port_attr_info *port_attr_info, + bool (*check_cb)(const struct net_device *dev), + int (*set_cb)(struct net_device *dev, + const struct switchdev_attr *attr, + struct switchdev_trans *trans)) +{ + int err; + + err = __switchdev_handle_port_attr_set(dev, port_attr_info, check_cb, + set_cb); + if (err == -EOPNOTSUPP) + err = 0; + return err; +} +EXPORT_SYMBOL_GPL(switchdev_handle_port_attr_set); diff --git a/net/tipc/link.c b/net/tipc/link.c index 2792a3cae682..341ecd796aa4 100644 --- a/net/tipc/link.c +++ b/net/tipc/link.c @@ -1126,7 +1126,7 @@ static bool tipc_data_input(struct tipc_link *l, struct sk_buff *skb, skb_queue_tail(mc_inputq, skb); return true; } - /* else: fall through */ + /* fall through */ case CONN_MANAGER: skb_queue_tail(inputq, skb); return true; @@ -1145,7 +1145,7 @@ static bool tipc_data_input(struct tipc_link *l, struct sk_buff *skb, default: pr_warn("Dropping received illegal msg type\n"); kfree_skb(skb); - return false; + return true; }; } @@ -1425,6 +1425,10 @@ static void tipc_link_build_proto_msg(struct tipc_link *l, int mtyp, bool probe, l->rcv_unacked = 0; } else { /* RESET_MSG or ACTIVATE_MSG */ + if (mtyp == ACTIVATE_MSG) { + msg_set_dest_session_valid(hdr, 1); + msg_set_dest_session(hdr, l->peer_session); + } msg_set_max_pkt(hdr, l->advertised_mtu); strcpy(data, l->if_name); msg_set_size(hdr, INT_H_SIZE + TIPC_MAX_IF_NAME); @@ -1642,6 +1646,17 @@ static int tipc_link_proto_rcv(struct tipc_link *l, struct sk_buff *skb, rc = tipc_link_fsm_evt(l, LINK_FAILURE_EVT); break; } + + /* If this endpoint was re-created while peer was ESTABLISHING + * it doesn't know current session number. Force re-synch. + */ + if (mtyp == ACTIVATE_MSG && msg_dest_session_valid(hdr) && + l->session != msg_dest_session(hdr)) { + if (less(l->session, msg_dest_session(hdr))) + l->session = msg_dest_session(hdr) + 1; + break; + } + /* ACTIVATE_MSG serves as PEER_RESET if link is already down */ if (mtyp == RESET_MSG || !link_is_up(l)) rc = tipc_link_fsm_evt(l, LINK_PEER_RESET_EVT); diff --git a/net/tipc/msg.h b/net/tipc/msg.h index a0924956bb61..d7e4b8b93f9d 100644 --- a/net/tipc/msg.h +++ b/net/tipc/msg.h @@ -360,6 +360,28 @@ static inline void msg_set_bcast_ack(struct tipc_msg *m, u16 n) msg_set_bits(m, 1, 0, 0xffff, n); } +/* Note: reusing bits in word 1 for ACTIVATE_MSG only, to re-synch + * link peer session number + */ +static inline bool msg_dest_session_valid(struct tipc_msg *m) +{ + return msg_bits(m, 1, 16, 0x1); +} + +static inline void msg_set_dest_session_valid(struct tipc_msg *m, bool valid) +{ + msg_set_bits(m, 1, 16, 0x1, valid); +} + +static inline u16 msg_dest_session(struct tipc_msg *m) +{ + return msg_bits(m, 1, 0, 0xffff); +} + +static inline void msg_set_dest_session(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 1, 0, 0xffff, n); +} /* * Word 2 diff --git a/net/tipc/node.c b/net/tipc/node.c index db2a6c3e0be9..2dc4919ab23c 100644 --- a/net/tipc/node.c +++ b/net/tipc/node.c @@ -830,15 +830,16 @@ static void tipc_node_link_down(struct tipc_node *n, int bearer_id, bool delete) tipc_node_write_lock(n); if (!tipc_link_is_establishing(l)) { __tipc_node_link_down(n, &bearer_id, &xmitq, &maddr); - if (delete) { - kfree(l); - le->link = NULL; - n->link_cnt--; - } } else { /* Defuse pending tipc_node_link_up() */ + tipc_link_reset(l); tipc_link_fsm_evt(l, LINK_RESET_EVT); } + if (delete) { + kfree(l); + le->link = NULL; + n->link_cnt--; + } trace_tipc_node_link_down(n, true, "node link down or deleted!"); tipc_node_write_unlock(n); if (delete) diff --git a/net/tipc/socket.c b/net/tipc/socket.c index 1217c90a363b..3274ef625dba 100644 --- a/net/tipc/socket.c +++ b/net/tipc/socket.c @@ -379,16 +379,18 @@ static int tipc_sk_sock_err(struct socket *sock, long *timeout) #define tipc_wait_for_cond(sock_, timeo_, condition_) \ ({ \ + DEFINE_WAIT_FUNC(wait_, woken_wake_function); \ struct sock *sk_; \ int rc_; \ \ while ((rc_ = !(condition_))) { \ - DEFINE_WAIT_FUNC(wait_, woken_wake_function); \ + /* coupled with smp_wmb() in tipc_sk_proto_rcv() */ \ + smp_rmb(); \ sk_ = (sock_)->sk; \ rc_ = tipc_sk_sock_err((sock_), timeo_); \ if (rc_) \ break; \ - prepare_to_wait(sk_sleep(sk_), &wait_, TASK_INTERRUPTIBLE); \ + add_wait_queue(sk_sleep(sk_), &wait_); \ release_sock(sk_); \ *(timeo_) = wait_woken(&wait_, TASK_INTERRUPTIBLE, *(timeo_)); \ sched_annotate_sleep(); \ @@ -735,7 +737,7 @@ static __poll_t tipc_poll(struct file *file, struct socket *sock, case TIPC_CONNECTING: if (!tsk->cong_link_cnt && !tsk_conn_cong(tsk)) revents |= EPOLLOUT; - /* fall thru' */ + /* fall through */ case TIPC_LISTEN: if (!skb_queue_empty(&sk->sk_receive_queue)) revents |= EPOLLIN | EPOLLRDNORM; @@ -1331,7 +1333,7 @@ static int __tipc_sendmsg(struct socket *sock, struct msghdr *m, size_t dlen) if (unlikely(!dest)) { dest = &tsk->peer; - if (!syn || dest->family != AF_TIPC) + if (!syn && dest->family != AF_TIPC) return -EDESTADDRREQ; } @@ -1677,7 +1679,7 @@ static void tipc_sk_send_ack(struct tipc_sock *tsk) static int tipc_wait_for_rcvmsg(struct socket *sock, long *timeop) { struct sock *sk = sock->sk; - DEFINE_WAIT(wait); + DEFINE_WAIT_FUNC(wait, woken_wake_function); long timeo = *timeop; int err = sock_error(sk); @@ -1685,15 +1687,17 @@ static int tipc_wait_for_rcvmsg(struct socket *sock, long *timeop) return err; for (;;) { - prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); if (timeo && skb_queue_empty(&sk->sk_receive_queue)) { if (sk->sk_shutdown & RCV_SHUTDOWN) { err = -ENOTCONN; break; } + add_wait_queue(sk_sleep(sk), &wait); release_sock(sk); - timeo = schedule_timeout(timeo); + timeo = wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); + sched_annotate_sleep(); lock_sock(sk); + remove_wait_queue(sk_sleep(sk), &wait); } err = 0; if (!skb_queue_empty(&sk->sk_receive_queue)) @@ -1709,7 +1713,6 @@ static int tipc_wait_for_rcvmsg(struct socket *sock, long *timeop) if (err) break; } - finish_wait(sk_sleep(sk), &wait); *timeop = timeo; return err; } @@ -1982,6 +1985,8 @@ static void tipc_sk_proto_rcv(struct sock *sk, return; case SOCK_WAKEUP: tipc_dest_del(&tsk->cong_links, msg_orignode(hdr), 0); + /* coupled with smp_rmb() in tipc_wait_for_cond() */ + smp_wmb(); tsk->cong_link_cnt--; wakeup = true; break; @@ -2416,7 +2421,7 @@ static int tipc_connect(struct socket *sock, struct sockaddr *dest, * case is EINPROGRESS, rather than EALREADY. */ res = -EINPROGRESS; - /* fall thru' */ + /* fall through */ case TIPC_CONNECTING: if (!timeout) { if (previous == TIPC_CONNECTING) diff --git a/net/tipc/topsrv.c b/net/tipc/topsrv.c index a457c0fbbef1..4a708a4e8583 100644 --- a/net/tipc/topsrv.c +++ b/net/tipc/topsrv.c @@ -60,7 +60,6 @@ * @awork: accept work item * @rcv_wq: receive workqueue * @send_wq: send workqueue - * @max_rcvbuf_size: maximum permitted receive message length * @listener: topsrv listener socket * @name: server name */ @@ -72,7 +71,6 @@ struct tipc_topsrv { struct work_struct awork; struct workqueue_struct *rcv_wq; struct workqueue_struct *send_wq; - int max_rcvbuf_size; struct socket *listener; char name[TIPC_SERVER_NAME_LEN]; }; @@ -648,7 +646,6 @@ int tipc_topsrv_start(struct net *net) return -ENOMEM; srv->net = net; - srv->max_rcvbuf_size = sizeof(struct tipc_subscr); INIT_WORK(&srv->awork, tipc_topsrv_accept); strscpy(srv->name, name, sizeof(srv->name)); diff --git a/net/tipc/trace.c b/net/tipc/trace.c index 964823841efe..265f6a26aa3d 100644 --- a/net/tipc/trace.c +++ b/net/tipc/trace.c @@ -111,7 +111,7 @@ int tipc_skb_dump(struct sk_buff *skb, bool more, char *buf) break; default: break; - }; + } i += scnprintf(buf + i, sz - i, " | %u", msg_src_droppable(hdr)); i += scnprintf(buf + i, sz - i, " %u", @@ -122,7 +122,7 @@ int tipc_skb_dump(struct sk_buff *skb, bool more, char *buf) default: /* need more? */ break; - }; + } i += scnprintf(buf + i, sz - i, "\n"); if (!more) diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index d753e362d2d9..135a7ee9db03 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -247,6 +247,7 @@ static int tls_push_record(struct sock *sk, int flags, unsigned char record_type) { + struct tls_prot_info *prot = &ctx->prot_info; struct tcp_sock *tp = tcp_sk(sk); struct page_frag dummy_tag_frag; skb_frag_t *frag; @@ -256,21 +257,21 @@ static int tls_push_record(struct sock *sk, frag = &record->frags[0]; tls_fill_prepend(ctx, skb_frag_address(frag), - record->len - ctx->tx.prepend_size, - record_type); + record->len - prot->prepend_size, + record_type, + ctx->crypto_send.info.version); /* HW doesn't care about the data in the tag, because it fills it. */ dummy_tag_frag.page = skb_frag_page(frag); dummy_tag_frag.offset = 0; - tls_append_frag(record, &dummy_tag_frag, ctx->tx.tag_size); + tls_append_frag(record, &dummy_tag_frag, prot->tag_size); record->end_seq = tp->write_seq + record->len; spin_lock_irq(&offload_ctx->lock); list_add_tail(&record->list, &offload_ctx->records_list); spin_unlock_irq(&offload_ctx->lock); offload_ctx->open_record = NULL; - set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); - tls_advance_record_sn(sk, &ctx->tx); + tls_advance_record_sn(sk, &ctx->tx, ctx->crypto_send.info.version); for (i = 0; i < record->num_frags; i++) { frag = &record->frags[i]; @@ -346,6 +347,7 @@ static int tls_push_data(struct sock *sk, unsigned char record_type) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx); int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST; int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE); @@ -365,9 +367,11 @@ static int tls_push_data(struct sock *sk, return -sk->sk_err; timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); - rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo); - if (rc < 0) - return rc; + if (tls_is_partially_sent_record(tls_ctx)) { + rc = tls_push_partial_record(sk, tls_ctx, flags); + if (rc < 0) + return rc; + } pfrag = sk_page_frag(sk); @@ -375,10 +379,10 @@ static int tls_push_data(struct sock *sk, * we need to leave room for an authentication tag. */ max_open_record_len = TLS_MAX_PAYLOAD_SIZE + - tls_ctx->tx.prepend_size; + prot->prepend_size; do { rc = tls_do_allocation(sk, ctx, pfrag, - tls_ctx->tx.prepend_size); + prot->prepend_size); if (rc) { rc = sk_stream_wait_memory(sk, &timeo); if (!rc) @@ -396,7 +400,7 @@ handle_error: size = orig_size; destroy_record(record); ctx->open_record = NULL; - } else if (record->len > tls_ctx->tx.prepend_size) { + } else if (record->len > prot->prepend_size) { goto last_record; } @@ -542,6 +546,20 @@ static int tls_device_push_pending_record(struct sock *sk, int flags) return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA); } +void tls_device_write_space(struct sock *sk, struct tls_context *ctx) +{ + int rc = 0; + + if (!sk->sk_write_pending && tls_is_partially_sent_record(ctx)) { + gfp_t sk_allocation = sk->sk_allocation; + + sk->sk_allocation = GFP_ATOMIC; + rc = tls_push_partial_record(sk, ctx, + MSG_DONTWAIT | MSG_NOSIGNAL); + sk->sk_allocation = sk_allocation; + } +} + void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn) { struct tls_context *tls_ctx = tls_get_ctx(sk); @@ -657,6 +675,8 @@ int tls_device_decrypted(struct sock *sk, struct sk_buff *skb) int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) { u16 nonce_size, tag_size, iv_size, rec_seq_size; + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_record_info *start_marker_record; struct tls_offload_context_tx *offload_ctx; struct tls_crypto_info *crypto_info; @@ -702,10 +722,10 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) goto free_offload_ctx; } - ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size; - ctx->tx.tag_size = tag_size; - ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size; - ctx->tx.iv_size = iv_size; + prot->prepend_size = TLS_HEADER_SIZE + nonce_size; + prot->tag_size = tag_size; + prot->overhead_size = prot->prepend_size + prot->tag_size; + prot->iv_size = iv_size; ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, GFP_KERNEL); if (!ctx->tx.iv) { @@ -715,7 +735,7 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); - ctx->tx.rec_seq_size = rec_seq_size; + prot->rec_seq_size = rec_seq_size; ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); if (!ctx->tx.rec_seq) { rc = -ENOMEM; diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c index 450a6dbc5a88..54c3a758f2a7 100644 --- a/net/tls/tls_device_fallback.c +++ b/net/tls/tls_device_fallback.c @@ -73,7 +73,8 @@ static int tls_enc_record(struct aead_request *aead_req, len -= TLS_CIPHER_AES_GCM_128_IV_SIZE; tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE, - (char *)&rcd_sn, sizeof(rcd_sn), buf[0]); + (char *)&rcd_sn, sizeof(rcd_sn), buf[0], + TLS_1_2_VERSION); memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE, TLS_CIPHER_AES_GCM_128_IV_SIZE); diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 78cb4a584080..df921a2904b9 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -61,6 +61,8 @@ static LIST_HEAD(device_list); static DEFINE_SPINLOCK(device_spinlock); static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; static struct proto_ops tls_sw_proto_ops; +static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], + struct proto *base); static void update_sk_prot(struct sock *sk, struct tls_context *ctx) { @@ -144,7 +146,6 @@ retry: } ctx->in_tcp_sendpages = false; - ctx->sk_write_space(sk); return 0; } @@ -207,23 +208,9 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, return tls_push_sg(sk, ctx, sg, offset, flags); } -int tls_push_pending_closed_record(struct sock *sk, - struct tls_context *tls_ctx, - int flags, long *timeo) -{ - struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - - if (tls_is_partially_sent_record(tls_ctx) || - !list_empty(&ctx->tx_list)) - return tls_tx_records(sk, flags); - else - return tls_ctx->push_pending_record(sk, flags); -} - static void tls_write_space(struct sock *sk) { struct tls_context *ctx = tls_get_ctx(sk); - struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx); /* If in_tcp_sendpages call lower protocol write space handler * to ensure we wake up any waiting operations there. For example @@ -234,12 +221,12 @@ static void tls_write_space(struct sock *sk) return; } - /* Schedule the transmission if tx list is ready */ - if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) { - /* Schedule the transmission */ - if (!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask)) - schedule_delayed_work(&tx_ctx->tx_work.work, 0); - } +#ifdef CONFIG_TLS_DEVICE + if (ctx->tx_conf == TLS_HW) + tls_device_write_space(sk, ctx); + else +#endif + tls_sw_write_space(sk, ctx); ctx->sk_write_space(sk); } @@ -264,8 +251,10 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) lock_sock(sk); sk_proto_close = ctx->sk_proto_close; - if ((ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) || - (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE)) { + if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) + goto skip_tx_cleanup; + + if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) { free_ctx = true; goto skip_tx_cleanup; } @@ -368,6 +357,30 @@ static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval, rc = -EFAULT; break; } + case TLS_CIPHER_AES_GCM_256: { + struct tls12_crypto_info_aes_gcm_256 * + crypto_info_aes_gcm_256 = + container_of(crypto_info, + struct tls12_crypto_info_aes_gcm_256, + info); + + if (len != sizeof(*crypto_info_aes_gcm_256)) { + rc = -EINVAL; + goto out; + } + lock_sock(sk); + memcpy(crypto_info_aes_gcm_256->iv, + ctx->tx.iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE, + TLS_CIPHER_AES_GCM_256_IV_SIZE); + memcpy(crypto_info_aes_gcm_256->rec_seq, ctx->tx.rec_seq, + TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE); + release_sock(sk); + if (copy_to_user(optval, + crypto_info_aes_gcm_256, + sizeof(*crypto_info_aes_gcm_256))) + rc = -EFAULT; + break; + } default: rc = -EINVAL; } @@ -407,7 +420,9 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, unsigned int optlen, int tx) { struct tls_crypto_info *crypto_info; + struct tls_crypto_info *alt_crypto_info; struct tls_context *ctx = tls_get_ctx(sk); + size_t optsize; int rc = 0; int conf; @@ -416,10 +431,13 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, goto out; } - if (tx) + if (tx) { crypto_info = &ctx->crypto_send.info; - else + alt_crypto_info = &ctx->crypto_recv.info; + } else { crypto_info = &ctx->crypto_recv.info; + alt_crypto_info = &ctx->crypto_send.info; + } /* Currently we don't support set crypto info more than one time */ if (TLS_CRYPTO_INFO_READY(crypto_info)) { @@ -434,14 +452,28 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, } /* check version */ - if (crypto_info->version != TLS_1_2_VERSION) { + if (crypto_info->version != TLS_1_2_VERSION && + crypto_info->version != TLS_1_3_VERSION) { rc = -ENOTSUPP; goto err_crypto_info; } + /* Ensure that TLS version and ciphers are same in both directions */ + if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) { + if (alt_crypto_info->version != crypto_info->version || + alt_crypto_info->cipher_type != crypto_info->cipher_type) { + rc = -EINVAL; + goto err_crypto_info; + } + } + switch (crypto_info->cipher_type) { - case TLS_CIPHER_AES_GCM_128: { - if (optlen != sizeof(struct tls12_crypto_info_aes_gcm_128)) { + case TLS_CIPHER_AES_GCM_128: + case TLS_CIPHER_AES_GCM_256: { + optsize = crypto_info->cipher_type == TLS_CIPHER_AES_GCM_128 ? + sizeof(struct tls12_crypto_info_aes_gcm_128) : + sizeof(struct tls12_crypto_info_aes_gcm_256); + if (optlen != optsize) { rc = -EINVAL; goto err_crypto_info; } @@ -551,6 +583,43 @@ static struct tls_context *create_ctx(struct sock *sk) return ctx; } +static void tls_build_proto(struct sock *sk) +{ + int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; + + /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */ + if (ip_ver == TLSV6 && + unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) { + mutex_lock(&tcpv6_prot_mutex); + if (likely(sk->sk_prot != saved_tcpv6_prot)) { + build_protos(tls_prots[TLSV6], sk->sk_prot); + smp_store_release(&saved_tcpv6_prot, sk->sk_prot); + } + mutex_unlock(&tcpv6_prot_mutex); + } + + if (ip_ver == TLSV4 && + unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) { + mutex_lock(&tcpv4_prot_mutex); + if (likely(sk->sk_prot != saved_tcpv4_prot)) { + build_protos(tls_prots[TLSV4], sk->sk_prot); + smp_store_release(&saved_tcpv4_prot, sk->sk_prot); + } + mutex_unlock(&tcpv4_prot_mutex); + } +} + +static void tls_hw_sk_destruct(struct sock *sk) +{ + struct tls_context *ctx = tls_get_ctx(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + + ctx->sk_destruct(sk); + /* Free ctx */ + kfree(ctx); + icsk->icsk_ulp_data = NULL; +} + static int tls_hw_prot(struct sock *sk) { struct tls_context *ctx; @@ -564,12 +633,17 @@ static int tls_hw_prot(struct sock *sk) if (!ctx) goto out; + spin_unlock_bh(&device_spinlock); + tls_build_proto(sk); ctx->hash = sk->sk_prot->hash; ctx->unhash = sk->sk_prot->unhash; ctx->sk_proto_close = sk->sk_prot->close; + ctx->sk_destruct = sk->sk_destruct; + sk->sk_destruct = tls_hw_sk_destruct; ctx->rx_conf = TLS_HW_RECORD; ctx->tx_conf = TLS_HW_RECORD; update_sk_prot(sk, ctx); + spin_lock_bh(&device_spinlock); rc = 1; break; } @@ -668,7 +742,6 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], static int tls_init(struct sock *sk) { - int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; struct tls_context *ctx; int rc = 0; @@ -691,27 +764,7 @@ static int tls_init(struct sock *sk) goto out; } - /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */ - if (ip_ver == TLSV6 && - unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) { - mutex_lock(&tcpv6_prot_mutex); - if (likely(sk->sk_prot != saved_tcpv6_prot)) { - build_protos(tls_prots[TLSV6], sk->sk_prot); - smp_store_release(&saved_tcpv6_prot, sk->sk_prot); - } - mutex_unlock(&tcpv6_prot_mutex); - } - - if (ip_ver == TLSV4 && - unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) { - mutex_lock(&tcpv4_prot_mutex); - if (likely(sk->sk_prot != saved_tcpv4_prot)) { - build_protos(tls_prots[TLSV4], sk->sk_prot); - smp_store_release(&saved_tcpv4_prot, sk->sk_prot); - } - mutex_unlock(&tcpv4_prot_mutex); - } - + tls_build_proto(sk); ctx->tx_conf = TLS_BASE; ctx->rx_conf = TLS_BASE; update_sk_prot(sk, ctx); diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 11cdc8f7db63..425351ac2a9b 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -120,12 +120,42 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len) return __skb_nsg(skb, offset, len, 0); } +static int padding_length(struct tls_sw_context_rx *ctx, + struct tls_context *tls_ctx, struct sk_buff *skb) +{ + struct strp_msg *rxm = strp_msg(skb); + int sub = 0; + + /* Determine zero-padding length */ + if (tls_ctx->prot_info.version == TLS_1_3_VERSION) { + char content_type = 0; + int err; + int back = 17; + + while (content_type == 0) { + if (back > rxm->full_len) + return -EBADMSG; + err = skb_copy_bits(skb, + rxm->offset + rxm->full_len - back, + &content_type, 1); + if (content_type) + break; + sub++; + back++; + } + ctx->control = content_type; + } + return sub; +} + static void tls_decrypt_done(struct crypto_async_request *req, int err) { struct aead_request *aead_req = (struct aead_request *)req; struct scatterlist *sgout = aead_req->dst; + struct scatterlist *sgin = aead_req->src; struct tls_sw_context_rx *ctx; struct tls_context *tls_ctx; + struct tls_prot_info *prot; struct scatterlist *sg; struct sk_buff *skb; unsigned int pages; @@ -134,12 +164,17 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) skb = (struct sk_buff *)req->data; tls_ctx = tls_get_ctx(skb->sk); ctx = tls_sw_ctx_rx(tls_ctx); - pending = atomic_dec_return(&ctx->decrypt_pending); + prot = &tls_ctx->prot_info; /* Propagate if there was an err */ if (err) { ctx->async_wait.err = err; tls_err_abort(skb->sk, err); + } else { + struct strp_msg *rxm = strp_msg(skb); + rxm->full_len -= padding_length(ctx, tls_ctx, skb); + rxm->offset += prot->prepend_size; + rxm->full_len -= prot->overhead_size; } /* After using skb->sk to propagate sk through crypto async callback @@ -147,18 +182,21 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) */ skb->sk = NULL; - /* Release the skb, pages and memory allocated for crypto req */ - kfree_skb(skb); - /* Skip the first S/G entry as it points to AAD */ - for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) { - if (!sg) - break; - put_page(sg_page(sg)); + /* Free the destination pages if skb was not decrypted inplace */ + if (sgout != sgin) { + /* Skip the first S/G entry as it points to AAD */ + for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) { + if (!sg) + break; + put_page(sg_page(sg)); + } } kfree(aead_req); + pending = atomic_dec_return(&ctx->decrypt_pending); + if (!pending && READ_ONCE(ctx->async_notify)) complete(&ctx->async_wait.completion); } @@ -173,13 +211,14 @@ static int tls_do_decryption(struct sock *sk, bool async) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); int ret; aead_request_set_tfm(aead_req, ctx->aead_recv); - aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); + aead_request_set_ad(aead_req, prot->aad_size); aead_request_set_crypt(aead_req, sgin, sgout, - data_len + tls_ctx->rx.tag_size, + data_len + prot->tag_size, (u8 *)iv_recv); if (async) { @@ -217,12 +256,13 @@ static int tls_do_decryption(struct sock *sk, static void tls_trim_both_msgs(struct sock *sk, int target_size) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec = ctx->open_rec; sk_msg_trim(sk, &rec->msg_plaintext, target_size); if (target_size > 0) - target_size += tls_ctx->tx.overhead_size; + target_size += prot->overhead_size; sk_msg_trim(sk, &rec->msg_encrypted, target_size); } @@ -239,6 +279,7 @@ static int tls_alloc_encrypted_msg(struct sock *sk, int len) static int tls_clone_plaintext_msg(struct sock *sk, int required) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec = ctx->open_rec; struct sk_msg *msg_pl = &rec->msg_plaintext; @@ -254,7 +295,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required) /* Skip initial bytes in msg_en's data to be able to use * same offset of both plain and encrypted data. */ - skip = tls_ctx->tx.prepend_size + msg_pl->sg.size; + skip = prot->prepend_size + msg_pl->sg.size; return sk_msg_clone(sk, msg_pl, msg_en, skip, len); } @@ -262,6 +303,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required) static struct tls_rec *tls_get_rec(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct sk_msg *msg_pl, *msg_en; struct tls_rec *rec; @@ -280,13 +322,11 @@ static struct tls_rec *tls_get_rec(struct sock *sk) sk_msg_init(msg_en); sg_init_table(rec->sg_aead_in, 2); - sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, - sizeof(rec->aad_space)); + sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size); sg_unmark_end(&rec->sg_aead_in[1]); sg_init_table(rec->sg_aead_out, 2); - sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, - sizeof(rec->aad_space)); + sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size); sg_unmark_end(&rec->sg_aead_out[1]); return rec; @@ -375,6 +415,7 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err) struct aead_request *aead_req = (struct aead_request *)req; struct sock *sk = req->data; struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct scatterlist *sge; struct sk_msg *msg_en; @@ -386,8 +427,8 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err) msg_en = &rec->msg_encrypted; sge = sk_msg_elem(msg_en, msg_en->sg.curr); - sge->offset -= tls_ctx->tx.prepend_size; - sge->length += tls_ctx->tx.prepend_size; + sge->offset -= prot->prepend_size; + sge->length += prot->prepend_size; /* Check if error is previously set on socket */ if (err || sk->sk_err) { @@ -434,21 +475,26 @@ static int tls_do_encryption(struct sock *sk, struct aead_request *aead_req, size_t data_len, u32 start) { + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_rec *rec = ctx->open_rec; struct sk_msg *msg_en = &rec->msg_encrypted; struct scatterlist *sge = sk_msg_elem(msg_en, start); int rc; - sge->offset += tls_ctx->tx.prepend_size; - sge->length -= tls_ctx->tx.prepend_size; + memcpy(rec->iv_data, tls_ctx->tx.iv, sizeof(rec->iv_data)); + xor_iv_with_seq(prot->version, rec->iv_data, + tls_ctx->tx.rec_seq); + + sge->offset += prot->prepend_size; + sge->length -= prot->prepend_size; msg_en->sg.curr = start; aead_request_set_tfm(aead_req, ctx->aead_send); - aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); + aead_request_set_ad(aead_req, prot->aad_size); aead_request_set_crypt(aead_req, rec->sg_aead_in, rec->sg_aead_out, - data_len, tls_ctx->tx.iv); + data_len, rec->iv_data); aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, tls_encrypt_done, sk); @@ -460,8 +506,8 @@ static int tls_do_encryption(struct sock *sk, rc = crypto_aead_encrypt(aead_req); if (!rc || rc != -EINPROGRESS) { atomic_dec(&ctx->encrypt_pending); - sge->offset -= tls_ctx->tx.prepend_size; - sge->length += tls_ctx->tx.prepend_size; + sge->offset -= prot->prepend_size; + sge->length += prot->prepend_size; } if (!rc) { @@ -473,7 +519,7 @@ static int tls_do_encryption(struct sock *sk, /* Unhook the record from context if encryption is not failure */ ctx->open_rec = NULL; - tls_advance_record_sn(sk, &tls_ctx->tx); + tls_advance_record_sn(sk, &tls_ctx->tx, prot->version); return rc; } @@ -599,6 +645,7 @@ static int tls_push_record(struct sock *sk, int flags, unsigned char record_type) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec = ctx->open_rec, *tmp = NULL; u32 i, split_point, uninitialized_var(orig_end); @@ -617,12 +664,12 @@ static int tls_push_record(struct sock *sk, int flags, split = split_point && split_point < msg_pl->sg.size; if (split) { rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en, - split_point, tls_ctx->tx.overhead_size, + split_point, prot->overhead_size, &orig_end); if (rc < 0) return rc; sk_msg_trim(sk, msg_en, msg_pl->sg.size + - tls_ctx->tx.overhead_size); + prot->overhead_size); } rec->tx_flags = flags; @@ -630,7 +677,17 @@ static int tls_push_record(struct sock *sk, int flags, i = msg_pl->sg.end; sk_msg_iter_var_prev(i); - sg_mark_end(sk_msg_elem(msg_pl, i)); + + rec->content_type = record_type; + if (prot->version == TLS_1_3_VERSION) { + /* Add content type to end of message. No padding added */ + sg_set_buf(&rec->sg_content_type, &rec->content_type, 1); + sg_mark_end(&rec->sg_content_type); + sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1, + &rec->sg_content_type); + } else { + sg_mark_end(sk_msg_elem(msg_pl, i)); + } i = msg_pl->sg.start; sg_chain(rec->sg_aead_in, 2, rec->inplace_crypto ? @@ -643,18 +700,20 @@ static int tls_push_record(struct sock *sk, int flags, i = msg_en->sg.start; sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]); - tls_make_aad(rec->aad_space, msg_pl->sg.size, - tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size, - record_type); + tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size, + tls_ctx->tx.rec_seq, prot->rec_seq_size, + record_type, prot->version); tls_fill_prepend(tls_ctx, page_address(sg_page(&msg_en->sg.data[i])) + - msg_en->sg.data[i].offset, msg_pl->sg.size, - record_type); + msg_en->sg.data[i].offset, + msg_pl->sg.size + prot->tail_size, + record_type, prot->version); tls_ctx->pending_open_record_frags = false; - rc = tls_do_encryption(sk, tls_ctx, ctx, req, msg_pl->sg.size, i); + rc = tls_do_encryption(sk, tls_ctx, ctx, req, + msg_pl->sg.size + prot->tail_size, i); if (rc < 0) { if (rc != -EINPROGRESS) { tls_err_abort(sk, EBADMSG); @@ -663,12 +722,12 @@ static int tls_push_record(struct sock *sk, int flags, tls_merge_open_record(sk, rec, tmp, orig_end); } } + ctx->async_capable = 1; return rc; } else if (split) { msg_pl = &tmp->msg_plaintext; msg_en = &tmp->msg_encrypted; - sk_msg_trim(sk, msg_en, msg_pl->sg.size + - tls_ctx->tx.overhead_size); + sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size); tls_ctx->pending_open_record_frags = true; ctx->open_rec = tmp; } @@ -803,9 +862,9 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) { long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - struct crypto_tfm *tfm = crypto_aead_tfm(ctx->aead_send); - bool async_capable = tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC; + bool async_capable = ctx->async_capable; unsigned char record_type = TLS_RECORD_TYPE_DATA; bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); bool eor = !(msg->msg_flags & MSG_MORE); @@ -870,7 +929,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) } required_size = msg_pl->sg.size + try_to_copy + - tls_ctx->tx.overhead_size; + prot->overhead_size; if (!sk_stream_memory_free(sk)) goto wait_for_sndbuf; @@ -939,8 +998,8 @@ fallback_to_reg_send: */ try_to_copy -= required_size - msg_pl->sg.size; full_record = true; - sk_msg_trim(sk, msg_en, msg_pl->sg.size + - tls_ctx->tx.overhead_size); + sk_msg_trim(sk, msg_en, + msg_pl->sg.size + prot->overhead_size); } if (try_to_copy) { @@ -1020,12 +1079,13 @@ send_end: return copied ? copied : ret; } -int tls_sw_do_sendpage(struct sock *sk, struct page *page, - int offset, size_t size, int flags) +static int tls_sw_do_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags) { long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; unsigned char record_type = TLS_RECORD_TYPE_DATA; struct sk_msg *msg_pl; struct tls_rec *rec; @@ -1075,8 +1135,7 @@ int tls_sw_do_sendpage(struct sock *sk, struct page *page, full_record = true; } - required_size = msg_pl->sg.size + copy + - tls_ctx->tx.overhead_size; + required_size = msg_pl->sg.size + copy + prot->overhead_size; if (!sk_stream_memory_free(sk)) goto wait_for_sndbuf; @@ -1143,16 +1202,6 @@ sendpage_end: return copied ? copied : ret; } -int tls_sw_sendpage_locked(struct sock *sk, struct page *page, - int offset, size_t size, int flags) -{ - if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | - MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY)) - return -ENOTSUPP; - - return tls_sw_do_sendpage(sk, page, offset, size, flags); -} - int tls_sw_sendpage(struct sock *sk, struct page *page, int offset, size_t size, int flags) { @@ -1281,10 +1330,11 @@ out: static int decrypt_internal(struct sock *sk, struct sk_buff *skb, struct iov_iter *out_iov, struct scatterlist *out_sg, - int *chunk, bool *zc) + int *chunk, bool *zc, bool async) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct strp_msg *rxm = strp_msg(skb); int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0; struct aead_request *aead_req; @@ -1292,15 +1342,16 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, u8 *aad, *iv, *mem = NULL; struct scatterlist *sgin = NULL; struct scatterlist *sgout = NULL; - const int data_len = rxm->full_len - tls_ctx->rx.overhead_size; + const int data_len = rxm->full_len - prot->overhead_size + + prot->tail_size; if (*zc && (out_iov || out_sg)) { if (out_iov) n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1; else n_sgout = sg_nents(out_sg); - n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size, - rxm->full_len - tls_ctx->rx.prepend_size); + n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size, + rxm->full_len - prot->prepend_size); } else { n_sgout = 0; *zc = false; @@ -1317,7 +1368,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); mem_size = aead_size + (nsg * sizeof(struct scatterlist)); - mem_size = mem_size + TLS_AAD_SPACE_SIZE; + mem_size = mem_size + prot->aad_size; mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv); /* Allocate a single block of memory which contains @@ -1333,29 +1384,35 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, sgin = (struct scatterlist *)(mem + aead_size); sgout = sgin + n_sgin; aad = (u8 *)(sgout + n_sgout); - iv = aad + TLS_AAD_SPACE_SIZE; + iv = aad + prot->aad_size; /* Prepare IV */ err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, - tls_ctx->rx.iv_size); + prot->iv_size); if (err < 0) { kfree(mem); return err; } - memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); + if (prot->version == TLS_1_3_VERSION) + memcpy(iv, tls_ctx->rx.iv, crypto_aead_ivsize(ctx->aead_recv)); + else + memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); + + xor_iv_with_seq(prot->version, iv, tls_ctx->rx.rec_seq); /* Prepare AAD */ - tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size, - tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size, - ctx->control); + tls_make_aad(aad, rxm->full_len - prot->overhead_size + + prot->tail_size, + tls_ctx->rx.rec_seq, prot->rec_seq_size, + ctx->control, prot->version); /* Prepare sgin */ sg_init_table(sgin, n_sgin); - sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE); + sg_set_buf(&sgin[0], aad, prot->aad_size); err = skb_to_sgvec(skb, &sgin[1], - rxm->offset + tls_ctx->rx.prepend_size, - rxm->full_len - tls_ctx->rx.prepend_size); + rxm->offset + prot->prepend_size, + rxm->full_len - prot->prepend_size); if (err < 0) { kfree(mem); return err; @@ -1364,7 +1421,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, if (n_sgout) { if (out_iov) { sg_init_table(sgout, n_sgout); - sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE); + sg_set_buf(&sgout[0], aad, prot->aad_size); *chunk = 0; err = tls_setup_from_iter(sk, out_iov, data_len, @@ -1381,13 +1438,13 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, fallback_to_reg_recv: sgout = sgin; pages = 0; - *chunk = 0; + *chunk = data_len; *zc = false; } /* Prepare and submit AEAD request */ err = tls_do_decryption(sk, skb, sgin, sgout, iv, - data_len, aead_req, *zc); + data_len, aead_req, async); if (err == -EINPROGRESS) return err; @@ -1400,36 +1457,45 @@ fallback_to_reg_recv: } static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, - struct iov_iter *dest, int *chunk, bool *zc) + struct iov_iter *dest, int *chunk, bool *zc, + bool async) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; + int version = prot->version; struct strp_msg *rxm = strp_msg(skb); int err = 0; -#ifdef CONFIG_TLS_DEVICE - err = tls_device_decrypted(sk, skb); - if (err < 0) - return err; -#endif if (!ctx->decrypted) { - err = decrypt_internal(sk, skb, dest, NULL, chunk, zc); - if (err < 0) { - if (err == -EINPROGRESS) - tls_advance_record_sn(sk, &tls_ctx->rx); - +#ifdef CONFIG_TLS_DEVICE + err = tls_device_decrypted(sk, skb); + if (err < 0) return err; +#endif + /* Still not decrypted after tls_device */ + if (!ctx->decrypted) { + err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, + async); + if (err < 0) { + if (err == -EINPROGRESS) + tls_advance_record_sn(sk, &tls_ctx->rx, + version); + + return err; + } } + + rxm->full_len -= padding_length(ctx, tls_ctx, skb); + rxm->offset += prot->prepend_size; + rxm->full_len -= prot->overhead_size; + tls_advance_record_sn(sk, &tls_ctx->rx, version); + ctx->decrypted = true; + ctx->saved_data_ready(sk); } else { *zc = false; } - rxm->offset += tls_ctx->rx.prepend_size; - rxm->full_len -= tls_ctx->rx.overhead_size; - tls_advance_record_sn(sk, &tls_ctx->rx); - ctx->decrypted = true; - ctx->saved_data_ready(sk); - return err; } @@ -1439,7 +1505,7 @@ int decrypt_skb(struct sock *sk, struct sk_buff *skb, bool zc = true; int chunk; - return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc); + return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false); } static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, @@ -1466,6 +1532,115 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, return true; } +/* This function traverses the rx_list in tls receive context to copies the + * decrypted records into the buffer provided by caller zero copy is not + * true. Further, the records are removed from the rx_list if it is not a peek + * case and the record has been consumed completely. + */ +static int process_rx_list(struct tls_sw_context_rx *ctx, + struct msghdr *msg, + u8 *control, + bool *cmsg, + size_t skip, + size_t len, + bool zc, + bool is_peek) +{ + struct sk_buff *skb = skb_peek(&ctx->rx_list); + u8 ctrl = *control; + u8 msgc = *cmsg; + struct tls_msg *tlm; + ssize_t copied = 0; + + /* Set the record type in 'control' if caller didn't pass it */ + if (!ctrl && skb) { + tlm = tls_msg(skb); + ctrl = tlm->control; + } + + while (skip && skb) { + struct strp_msg *rxm = strp_msg(skb); + tlm = tls_msg(skb); + + /* Cannot process a record of different type */ + if (ctrl != tlm->control) + return 0; + + if (skip < rxm->full_len) + break; + + skip = skip - rxm->full_len; + skb = skb_peek_next(skb, &ctx->rx_list); + } + + while (len && skb) { + struct sk_buff *next_skb; + struct strp_msg *rxm = strp_msg(skb); + int chunk = min_t(unsigned int, rxm->full_len - skip, len); + + tlm = tls_msg(skb); + + /* Cannot process a record of different type */ + if (ctrl != tlm->control) + return 0; + + /* Set record type if not already done. For a non-data record, + * do not proceed if record type could not be copied. + */ + if (!msgc) { + int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, + sizeof(ctrl), &ctrl); + msgc = true; + if (ctrl != TLS_RECORD_TYPE_DATA) { + if (cerr || msg->msg_flags & MSG_CTRUNC) + return -EIO; + + *cmsg = msgc; + } + } + + if (!zc || (rxm->full_len - skip) > len) { + int err = skb_copy_datagram_msg(skb, rxm->offset + skip, + msg, chunk); + if (err < 0) + return err; + } + + len = len - chunk; + copied = copied + chunk; + + /* Consume the data from record if it is non-peek case*/ + if (!is_peek) { + rxm->offset = rxm->offset + chunk; + rxm->full_len = rxm->full_len - chunk; + + /* Return if there is unconsumed data in the record */ + if (rxm->full_len - skip) + break; + } + + /* The remaining skip-bytes must lie in 1st record in rx_list. + * So from the 2nd record, 'skip' should be 0. + */ + skip = 0; + + if (msg) + msg->msg_flags |= MSG_EOR; + + next_skb = skb_peek_next(skb, &ctx->rx_list); + + if (!is_peek) { + skb_unlink(skb, &ctx->rx_list); + kfree_skb(skb); + } + + skb = next_skb; + } + + *control = ctrl; + return copied; +} + int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, @@ -1475,15 +1650,19 @@ int tls_sw_recvmsg(struct sock *sk, { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct sk_psock *psock; - unsigned char control; + unsigned char control = 0; + ssize_t decrypted = 0; struct strp_msg *rxm; + struct tls_msg *tlm; struct sk_buff *skb; ssize_t copied = 0; bool cmsg = false; int target, err = 0; long timeo; bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); + bool is_peek = flags & MSG_PEEK; int num_async = 0; flags |= nonblock; @@ -1494,12 +1673,31 @@ int tls_sw_recvmsg(struct sock *sk, psock = sk_psock_get(sk); lock_sock(sk); - target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); - timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + /* Process pending decrypted records. It must be non-zero-copy */ + err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false, + is_peek); + if (err < 0) { + tls_err_abort(sk, err); + goto end; + } else { + copied = err; + } + + len = len - copied; + if (len) { + target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + } else { + goto recv_end; + } + do { + bool retain_skb = false; bool zc = false; - bool async = false; + int to_decrypt; int chunk = 0; + bool async_capable; + bool async = false; skb = tls_wait_data(sk, psock, flags, timeo, &err); if (!skb) { @@ -1508,97 +1706,125 @@ int tls_sw_recvmsg(struct sock *sk, msg, len, flags); if (ret > 0) { - copied += ret; + decrypted += ret; len -= ret; continue; } } goto recv_end; + } else { + tlm = tls_msg(skb); + if (prot->version == TLS_1_3_VERSION) + tlm->control = 0; + else + tlm->control = ctx->control; } rxm = strp_msg(skb); + to_decrypt = rxm->full_len - prot->overhead_size; + + if (to_decrypt <= len && !is_kvec && !is_peek && + ctx->control == TLS_RECORD_TYPE_DATA && + prot->version != TLS_1_3_VERSION) + zc = true; + + /* Do not use async mode if record is non-data */ + if (ctx->control == TLS_RECORD_TYPE_DATA) + async_capable = ctx->async_capable; + else + async_capable = false; + + err = decrypt_skb_update(sk, skb, &msg->msg_iter, + &chunk, &zc, async_capable); + if (err < 0 && err != -EINPROGRESS) { + tls_err_abort(sk, EBADMSG); + goto recv_end; + } + + if (err == -EINPROGRESS) { + async = true; + num_async++; + } else if (prot->version == TLS_1_3_VERSION) { + tlm->control = ctx->control; + } + + /* If the type of records being processed is not known yet, + * set it to record type just dequeued. If it is already known, + * but does not match the record type just dequeued, go to end. + * We always get record type here since for tls1.2, record type + * is known just after record is dequeued from stream parser. + * For tls1.3, we disable async. + */ + + if (!control) + control = tlm->control; + else if (control != tlm->control) + goto recv_end; + if (!cmsg) { int cerr; cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, - sizeof(ctx->control), &ctx->control); + sizeof(control), &control); cmsg = true; - control = ctx->control; - if (ctx->control != TLS_RECORD_TYPE_DATA) { + if (control != TLS_RECORD_TYPE_DATA) { if (cerr || msg->msg_flags & MSG_CTRUNC) { err = -EIO; goto recv_end; } } - } else if (control != ctx->control) { - goto recv_end; } - if (!ctx->decrypted) { - int to_copy = rxm->full_len - tls_ctx->rx.overhead_size; + if (async) + goto pick_next_record; - if (!is_kvec && to_copy <= len && - likely(!(flags & MSG_PEEK))) - zc = true; + if (!zc) { + if (rxm->full_len > len) { + retain_skb = true; + chunk = len; + } else { + chunk = rxm->full_len; + } - err = decrypt_skb_update(sk, skb, &msg->msg_iter, - &chunk, &zc); - if (err < 0 && err != -EINPROGRESS) { - tls_err_abort(sk, EBADMSG); + err = skb_copy_datagram_msg(skb, rxm->offset, + msg, chunk); + if (err < 0) goto recv_end; - } - if (err == -EINPROGRESS) { - async = true; - num_async++; - goto pick_next_record; + if (!is_peek) { + rxm->offset = rxm->offset + chunk; + rxm->full_len = rxm->full_len - chunk; } - - ctx->decrypted = true; } - if (!zc) { - chunk = min_t(unsigned int, rxm->full_len, len); +pick_next_record: + if (chunk > len) + chunk = len; - err = skb_copy_datagram_msg(skb, rxm->offset, msg, - chunk); - if (err < 0) - goto recv_end; + decrypted += chunk; + len -= chunk; + + /* For async or peek case, queue the current skb */ + if (async || is_peek || retain_skb) { + skb_queue_tail(&ctx->rx_list, skb); + skb = NULL; } -pick_next_record: - copied += chunk; - len -= chunk; - if (likely(!(flags & MSG_PEEK))) { - u8 control = ctx->control; - - /* For async, drop current skb reference */ - if (async) - skb = NULL; - - if (tls_sw_advance_skb(sk, skb, chunk)) { - /* Return full control message to - * userspace before trying to parse - * another message type - */ - msg->msg_flags |= MSG_EOR; - if (control != TLS_RECORD_TYPE_DATA) - goto recv_end; - } else { - break; - } - } else { - /* MSG_PEEK right now cannot look beyond current skb - * from strparser, meaning we cannot advance skb here - * and thus unpause strparser since we'd loose original - * one. + if (tls_sw_advance_skb(sk, skb, chunk)) { + /* Return full control message to + * userspace before trying to parse + * another message type */ + msg->msg_flags |= MSG_EOR; + if (ctx->control != TLS_RECORD_TYPE_DATA) + goto recv_end; + } else { break; } /* If we have a new message from strparser, continue now. */ - if (copied >= target && !ctx->recv_pkt) + if (decrypted >= target && !ctx->recv_pkt) break; } while (len); @@ -1612,13 +1838,31 @@ recv_end: /* one of async decrypt failed */ tls_err_abort(sk, err); copied = 0; + decrypted = 0; + goto end; } } else { reinit_completion(&ctx->async_wait.completion); } WRITE_ONCE(ctx->async_notify, false); + + /* Drain records from the rx_list & copy if required */ + if (is_peek || is_kvec) + err = process_rx_list(ctx, msg, &control, &cmsg, copied, + decrypted, false, is_peek); + else + err = process_rx_list(ctx, msg, &control, &cmsg, 0, + decrypted, true, is_peek); + if (err < 0) { + tls_err_abort(sk, err); + copied = 0; + goto end; + } } + copied += decrypted; + +end: release_sock(sk); if (psock) sk_psock_put(sk, psock); @@ -1648,14 +1892,14 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, if (!skb) goto splice_read_end; - /* splice does not support reading control messages */ - if (ctx->control != TLS_RECORD_TYPE_DATA) { - err = -ENOTSUPP; - goto splice_read_end; - } - if (!ctx->decrypted) { - err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc); + err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false); + + /* splice does not support reading control messages */ + if (ctx->control != TLS_RECORD_TYPE_DATA) { + err = -ENOTSUPP; + goto splice_read_end; + } if (err < 0) { tls_err_abort(sk, EBADMSG); @@ -1698,6 +1942,7 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) { struct tls_context *tls_ctx = tls_get_ctx(strp->sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; struct strp_msg *rxm = strp_msg(skb); size_t cipher_overhead; @@ -1705,17 +1950,17 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) int ret; /* Verify that we have a full TLS header, or wait for more data */ - if (rxm->offset + tls_ctx->rx.prepend_size > skb->len) + if (rxm->offset + prot->prepend_size > skb->len) return 0; /* Sanity-check size of on-stack buffer. */ - if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) { + if (WARN_ON(prot->prepend_size > sizeof(header))) { ret = -EINVAL; goto read_failure; } /* Linearize header to local buffer */ - ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size); + ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size); if (ret < 0) goto read_failure; @@ -1724,9 +1969,12 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) data_len = ((header[4] & 0xFF) | (header[3] << 8)); - cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size; + cipher_overhead = prot->tag_size; + if (prot->version != TLS_1_3_VERSION) + cipher_overhead += prot->iv_size; - if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) { + if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead + + prot->tail_size) { ret = -EMSGSIZE; goto read_failure; } @@ -1735,12 +1983,12 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) goto read_failure; } - if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) || - header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) { + /* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */ + if (header[1] != TLS_1_2_VERSION_MINOR || + header[2] != TLS_1_2_VERSION_MAJOR) { ret = -EINVAL; goto read_failure; } - #ifdef CONFIG_TLS_DEVICE handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset, *(u64*)tls_ctx->rx.rec_seq); @@ -1792,7 +2040,9 @@ void tls_sw_free_resources_tx(struct sock *sk) if (atomic_read(&ctx->encrypt_pending)) crypto_wait_req(-EINPROGRESS, &ctx->async_wait); + release_sock(sk); cancel_delayed_work_sync(&ctx->tx_work.work); + lock_sock(sk); /* Tx whatever records we can transmit and abandon the rest */ tls_tx_records(sk, -1); @@ -1842,6 +2092,7 @@ void tls_sw_release_resources_rx(struct sock *sk) if (ctx->aead_recv) { kfree_skb(ctx->recv_pkt); ctx->recv_pkt = NULL; + skb_queue_purge(&ctx->rx_list); crypto_free_aead(ctx->aead_recv); strp_stop(&ctx->strp); write_lock_bh(&sk->sk_callback_lock); @@ -1881,17 +2132,35 @@ static void tx_work_handler(struct work_struct *work) release_sock(sk); } +void tls_sw_write_space(struct sock *sk, struct tls_context *ctx) +{ + struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx); + + /* Schedule the transmission if tx list is ready */ + if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) { + /* Schedule the transmission */ + if (!test_and_set_bit(BIT_TX_SCHEDULED, + &tx_ctx->tx_bitmask)) + schedule_delayed_work(&tx_ctx->tx_work.work, 0); + } +} + int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) { + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_crypto_info *crypto_info; struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; + struct tls12_crypto_info_aes_gcm_256 *gcm_256_info; struct tls_sw_context_tx *sw_ctx_tx = NULL; struct tls_sw_context_rx *sw_ctx_rx = NULL; struct cipher_context *cctx; struct crypto_aead **aead; struct strp_callbacks cb; u16 nonce_size, tag_size, iv_size, rec_seq_size; - char *iv, *rec_seq; + struct crypto_tfm *tfm; + char *iv, *rec_seq, *key, *salt; + size_t keysize; int rc = 0; if (!ctx) { @@ -1937,6 +2206,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) crypto_init_wait(&sw_ctx_rx->async_wait); crypto_info = &ctx->crypto_recv.info; cctx = &ctx->rx; + skb_queue_head_init(&sw_ctx_rx->rx_list); aead = &sw_ctx_rx->aead_recv; } @@ -1951,6 +2221,24 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq; gcm_128_info = (struct tls12_crypto_info_aes_gcm_128 *)crypto_info; + keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE; + key = gcm_128_info->key; + salt = gcm_128_info->salt; + break; + } + case TLS_CIPHER_AES_GCM_256: { + nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE; + tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE; + iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE; + iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv; + rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE; + rec_seq = + ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq; + gcm_256_info = + (struct tls12_crypto_info_aes_gcm_256 *)crypto_info; + keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE; + key = gcm_256_info->key; + salt = gcm_256_info->salt; break; } default: @@ -1964,19 +2252,32 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) goto free_priv; } - cctx->prepend_size = TLS_HEADER_SIZE + nonce_size; - cctx->tag_size = tag_size; - cctx->overhead_size = cctx->prepend_size + cctx->tag_size; - cctx->iv_size = iv_size; + if (crypto_info->version == TLS_1_3_VERSION) { + nonce_size = 0; + prot->aad_size = TLS_HEADER_SIZE; + prot->tail_size = 1; + } else { + prot->aad_size = TLS_AAD_SPACE_SIZE; + prot->tail_size = 0; + } + + prot->version = crypto_info->version; + prot->cipher_type = crypto_info->cipher_type; + prot->prepend_size = TLS_HEADER_SIZE + nonce_size; + prot->tag_size = tag_size; + prot->overhead_size = prot->prepend_size + + prot->tag_size + prot->tail_size; + prot->iv_size = iv_size; cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, GFP_KERNEL); if (!cctx->iv) { rc = -ENOMEM; goto free_priv; } - memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); + /* Note: 128 & 256 bit salt are the same size */ + memcpy(cctx->iv, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); - cctx->rec_seq_size = rec_seq_size; + prot->rec_seq_size = rec_seq_size; cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); if (!cctx->rec_seq) { rc = -ENOMEM; @@ -1994,16 +2295,24 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) ctx->push_pending_record = tls_sw_push_pending_record; - rc = crypto_aead_setkey(*aead, gcm_128_info->key, - TLS_CIPHER_AES_GCM_128_KEY_SIZE); + rc = crypto_aead_setkey(*aead, key, keysize); + if (rc) goto free_aead; - rc = crypto_aead_setauthsize(*aead, cctx->tag_size); + rc = crypto_aead_setauthsize(*aead, prot->tag_size); if (rc) goto free_aead; if (sw_ctx_rx) { + tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv); + + if (crypto_info->version == TLS_1_3_VERSION) + sw_ctx_rx->async_capable = false; + else + sw_ctx_rx->async_capable = + tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC; + /* Set up strparser */ memset(&cb, 0, sizeof(cb)); cb.rcv_msg = tls_queue; diff --git a/net/unix/Kconfig b/net/unix/Kconfig index 8b31ab85d050..3b9e450656a4 100644 --- a/net/unix/Kconfig +++ b/net/unix/Kconfig @@ -19,6 +19,11 @@ config UNIX Say Y unless you know what you are doing. +config UNIX_SCM + bool + depends on UNIX + default y + config UNIX_DIAG tristate "UNIX: socket monitoring interface" depends on UNIX diff --git a/net/unix/Makefile b/net/unix/Makefile index ffd0a275c3a7..54e58cc4f945 100644 --- a/net/unix/Makefile +++ b/net/unix/Makefile @@ -10,3 +10,5 @@ unix-$(CONFIG_SYSCTL) += sysctl_net_unix.o obj-$(CONFIG_UNIX_DIAG) += unix_diag.o unix_diag-y := diag.o + +obj-$(CONFIG_UNIX_SCM) += scm.o diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c index 74d1eed7cbd4..ddb838a1b74c 100644 --- a/net/unix/af_unix.c +++ b/net/unix/af_unix.c @@ -119,6 +119,8 @@ #include <linux/freezer.h> #include <linux/file.h> +#include "scm.h" + struct hlist_head unix_socket_table[2 * UNIX_HASH_SIZE]; EXPORT_SYMBOL_GPL(unix_socket_table); DEFINE_SPINLOCK(unix_table_lock); @@ -890,7 +892,7 @@ retry: addr->hash ^= sk->sk_type; __unix_remove_socket(sk); - u->addr = addr; + smp_store_release(&u->addr, addr); __unix_insert_socket(&unix_socket_table[addr->hash], sk); spin_unlock(&unix_table_lock); err = 0; @@ -1060,7 +1062,7 @@ static int unix_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) err = 0; __unix_remove_socket(sk); - u->addr = addr; + smp_store_release(&u->addr, addr); __unix_insert_socket(list, sk); out_unlock: @@ -1331,15 +1333,29 @@ restart: RCU_INIT_POINTER(newsk->sk_wq, &newu->peer_wq); otheru = unix_sk(other); - /* copy address information from listening to new sock*/ - if (otheru->addr) { - refcount_inc(&otheru->addr->refcnt); - newu->addr = otheru->addr; - } + /* copy address information from listening to new sock + * + * The contents of *(otheru->addr) and otheru->path + * are seen fully set up here, since we have found + * otheru in hash under unix_table_lock. Insertion + * into the hash chain we'd found it in had been done + * in an earlier critical area protected by unix_table_lock, + * the same one where we'd set *(otheru->addr) contents, + * as well as otheru->path and otheru->addr itself. + * + * Using smp_store_release() here to set newu->addr + * is enough to make those stores, as well as stores + * to newu->path visible to anyone who gets newu->addr + * by smp_load_acquire(). IOW, the same warranties + * as for unix_sock instances bound in unix_bind() or + * in unix_autobind(). + */ if (otheru->path.dentry) { path_get(&otheru->path); newu->path = otheru->path; } + refcount_inc(&otheru->addr->refcnt); + smp_store_release(&newu->addr, otheru->addr); /* Set credentials */ copy_peercred(sk, other); @@ -1453,7 +1469,7 @@ out: static int unix_getname(struct socket *sock, struct sockaddr *uaddr, int peer) { struct sock *sk = sock->sk; - struct unix_sock *u; + struct unix_address *addr; DECLARE_SOCKADDR(struct sockaddr_un *, sunaddr, uaddr); int err = 0; @@ -1468,85 +1484,20 @@ static int unix_getname(struct socket *sock, struct sockaddr *uaddr, int peer) sock_hold(sk); } - u = unix_sk(sk); - unix_state_lock(sk); - if (!u->addr) { + addr = smp_load_acquire(&unix_sk(sk)->addr); + if (!addr) { sunaddr->sun_family = AF_UNIX; sunaddr->sun_path[0] = 0; err = sizeof(short); } else { - struct unix_address *addr = u->addr; - err = addr->len; memcpy(sunaddr, addr->name, addr->len); } - unix_state_unlock(sk); sock_put(sk); out: return err; } -static void unix_detach_fds(struct scm_cookie *scm, struct sk_buff *skb) -{ - int i; - - scm->fp = UNIXCB(skb).fp; - UNIXCB(skb).fp = NULL; - - for (i = scm->fp->count-1; i >= 0; i--) - unix_notinflight(scm->fp->user, scm->fp->fp[i]); -} - -static void unix_destruct_scm(struct sk_buff *skb) -{ - struct scm_cookie scm; - memset(&scm, 0, sizeof(scm)); - scm.pid = UNIXCB(skb).pid; - if (UNIXCB(skb).fp) - unix_detach_fds(&scm, skb); - - /* Alas, it calls VFS */ - /* So fscking what? fput() had been SMP-safe since the last Summer */ - scm_destroy(&scm); - sock_wfree(skb); -} - -/* - * The "user->unix_inflight" variable is protected by the garbage - * collection lock, and we just read it locklessly here. If you go - * over the limit, there might be a tiny race in actually noticing - * it across threads. Tough. - */ -static inline bool too_many_unix_fds(struct task_struct *p) -{ - struct user_struct *user = current_user(); - - if (unlikely(user->unix_inflight > task_rlimit(p, RLIMIT_NOFILE))) - return !capable(CAP_SYS_RESOURCE) && !capable(CAP_SYS_ADMIN); - return false; -} - -static int unix_attach_fds(struct scm_cookie *scm, struct sk_buff *skb) -{ - int i; - - if (too_many_unix_fds(current)) - return -ETOOMANYREFS; - - /* - * Need to duplicate file references for the sake of garbage - * collection. Otherwise a socket in the fps might become a - * candidate for GC while the skb is not yet queued. - */ - UNIXCB(skb).fp = scm_fp_dup(scm->fp); - if (!UNIXCB(skb).fp) - return -ENOMEM; - - for (i = scm->fp->count - 1; i >= 0; i--) - unix_inflight(scm->fp->user, scm->fp->fp[i]); - return 0; -} - static int unix_scm_to_skb(struct scm_cookie *scm, struct sk_buff *skb, bool send_fds) { int err = 0; @@ -2073,11 +2024,11 @@ static int unix_seqpacket_recvmsg(struct socket *sock, struct msghdr *msg, static void unix_copy_addr(struct msghdr *msg, struct sock *sk) { - struct unix_sock *u = unix_sk(sk); + struct unix_address *addr = smp_load_acquire(&unix_sk(sk)->addr); - if (u->addr) { - msg->msg_namelen = u->addr->len; - memcpy(msg->msg_name, u->addr->name, u->addr->len); + if (addr) { + msg->msg_namelen = addr->len; + memcpy(msg->msg_name, addr->name, addr->len); } } @@ -2581,15 +2532,14 @@ static int unix_open_file(struct sock *sk) if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) return -EPERM; - unix_state_lock(sk); + if (!smp_load_acquire(&unix_sk(sk)->addr)) + return -ENOENT; + path = unix_sk(sk)->path; - if (!path.dentry) { - unix_state_unlock(sk); + if (!path.dentry) return -ENOENT; - } path_get(&path); - unix_state_unlock(sk); fd = get_unused_fd_flags(O_CLOEXEC); if (fd < 0) @@ -2830,7 +2780,7 @@ static int unix_seq_show(struct seq_file *seq, void *v) (s->sk_state == TCP_ESTABLISHED ? SS_CONNECTING : SS_DISCONNECTING), sock_i_ino(s)); - if (u->addr) { + if (u->addr) { // under unix_table_lock here int i, len; seq_putc(seq, ' '); diff --git a/net/unix/diag.c b/net/unix/diag.c index 384c84e83462..3183d9b8ab33 100644 --- a/net/unix/diag.c +++ b/net/unix/diag.c @@ -10,7 +10,8 @@ static int sk_diag_dump_name(struct sock *sk, struct sk_buff *nlskb) { - struct unix_address *addr = unix_sk(sk)->addr; + /* might or might not have unix_table_lock */ + struct unix_address *addr = smp_load_acquire(&unix_sk(sk)->addr); if (!addr) return 0; diff --git a/net/unix/garbage.c b/net/unix/garbage.c index c36757e72844..8bbe1b8e4ff7 100644 --- a/net/unix/garbage.c +++ b/net/unix/garbage.c @@ -86,77 +86,13 @@ #include <net/scm.h> #include <net/tcp_states.h> +#include "scm.h" + /* Internal data structures and random procedures: */ -static LIST_HEAD(gc_inflight_list); static LIST_HEAD(gc_candidates); -static DEFINE_SPINLOCK(unix_gc_lock); static DECLARE_WAIT_QUEUE_HEAD(unix_gc_wait); -unsigned int unix_tot_inflight; - -struct sock *unix_get_socket(struct file *filp) -{ - struct sock *u_sock = NULL; - struct inode *inode = file_inode(filp); - - /* Socket ? */ - if (S_ISSOCK(inode->i_mode) && !(filp->f_mode & FMODE_PATH)) { - struct socket *sock = SOCKET_I(inode); - struct sock *s = sock->sk; - - /* PF_UNIX ? */ - if (s && sock->ops && sock->ops->family == PF_UNIX) - u_sock = s; - } - return u_sock; -} - -/* Keep the number of times in flight count for the file - * descriptor if it is for an AF_UNIX socket. - */ - -void unix_inflight(struct user_struct *user, struct file *fp) -{ - struct sock *s = unix_get_socket(fp); - - spin_lock(&unix_gc_lock); - - if (s) { - struct unix_sock *u = unix_sk(s); - - if (atomic_long_inc_return(&u->inflight) == 1) { - BUG_ON(!list_empty(&u->link)); - list_add_tail(&u->link, &gc_inflight_list); - } else { - BUG_ON(list_empty(&u->link)); - } - unix_tot_inflight++; - } - user->unix_inflight++; - spin_unlock(&unix_gc_lock); -} - -void unix_notinflight(struct user_struct *user, struct file *fp) -{ - struct sock *s = unix_get_socket(fp); - - spin_lock(&unix_gc_lock); - - if (s) { - struct unix_sock *u = unix_sk(s); - - BUG_ON(!atomic_long_read(&u->inflight)); - BUG_ON(list_empty(&u->link)); - - if (atomic_long_dec_and_test(&u->inflight)) - list_del_init(&u->link); - unix_tot_inflight--; - } - user->unix_inflight--; - spin_unlock(&unix_gc_lock); -} - static void scan_inflight(struct sock *x, void (*func)(struct unix_sock *), struct sk_buff_head *hitlist) { diff --git a/net/unix/scm.c b/net/unix/scm.c new file mode 100644 index 000000000000..8c40f2b32392 --- /dev/null +++ b/net/unix/scm.c @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: GPL-2.0 +#include <linux/module.h> +#include <linux/kernel.h> +#include <linux/string.h> +#include <linux/socket.h> +#include <linux/net.h> +#include <linux/fs.h> +#include <net/af_unix.h> +#include <net/scm.h> +#include <linux/init.h> + +#include "scm.h" + +unsigned int unix_tot_inflight; +EXPORT_SYMBOL(unix_tot_inflight); + +LIST_HEAD(gc_inflight_list); +EXPORT_SYMBOL(gc_inflight_list); + +DEFINE_SPINLOCK(unix_gc_lock); +EXPORT_SYMBOL(unix_gc_lock); + +struct sock *unix_get_socket(struct file *filp) +{ + struct sock *u_sock = NULL; + struct inode *inode = file_inode(filp); + + /* Socket ? */ + if (S_ISSOCK(inode->i_mode) && !(filp->f_mode & FMODE_PATH)) { + struct socket *sock = SOCKET_I(inode); + struct sock *s = sock->sk; + + /* PF_UNIX ? */ + if (s && sock->ops && sock->ops->family == PF_UNIX) + u_sock = s; + } else { + /* Could be an io_uring instance */ + u_sock = io_uring_get_socket(filp); + } + return u_sock; +} +EXPORT_SYMBOL(unix_get_socket); + +/* Keep the number of times in flight count for the file + * descriptor if it is for an AF_UNIX socket. + */ +void unix_inflight(struct user_struct *user, struct file *fp) +{ + struct sock *s = unix_get_socket(fp); + + spin_lock(&unix_gc_lock); + + if (s) { + struct unix_sock *u = unix_sk(s); + + if (atomic_long_inc_return(&u->inflight) == 1) { + BUG_ON(!list_empty(&u->link)); + list_add_tail(&u->link, &gc_inflight_list); + } else { + BUG_ON(list_empty(&u->link)); + } + unix_tot_inflight++; + } + user->unix_inflight++; + spin_unlock(&unix_gc_lock); +} + +void unix_notinflight(struct user_struct *user, struct file *fp) +{ + struct sock *s = unix_get_socket(fp); + + spin_lock(&unix_gc_lock); + + if (s) { + struct unix_sock *u = unix_sk(s); + + BUG_ON(!atomic_long_read(&u->inflight)); + BUG_ON(list_empty(&u->link)); + + if (atomic_long_dec_and_test(&u->inflight)) + list_del_init(&u->link); + unix_tot_inflight--; + } + user->unix_inflight--; + spin_unlock(&unix_gc_lock); +} + +/* + * The "user->unix_inflight" variable is protected by the garbage + * collection lock, and we just read it locklessly here. If you go + * over the limit, there might be a tiny race in actually noticing + * it across threads. Tough. + */ +static inline bool too_many_unix_fds(struct task_struct *p) +{ + struct user_struct *user = current_user(); + + if (unlikely(user->unix_inflight > task_rlimit(p, RLIMIT_NOFILE))) + return !capable(CAP_SYS_RESOURCE) && !capable(CAP_SYS_ADMIN); + return false; +} + +int unix_attach_fds(struct scm_cookie *scm, struct sk_buff *skb) +{ + int i; + + if (too_many_unix_fds(current)) + return -ETOOMANYREFS; + + /* + * Need to duplicate file references for the sake of garbage + * collection. Otherwise a socket in the fps might become a + * candidate for GC while the skb is not yet queued. + */ + UNIXCB(skb).fp = scm_fp_dup(scm->fp); + if (!UNIXCB(skb).fp) + return -ENOMEM; + + for (i = scm->fp->count - 1; i >= 0; i--) + unix_inflight(scm->fp->user, scm->fp->fp[i]); + return 0; +} +EXPORT_SYMBOL(unix_attach_fds); + +void unix_detach_fds(struct scm_cookie *scm, struct sk_buff *skb) +{ + int i; + + scm->fp = UNIXCB(skb).fp; + UNIXCB(skb).fp = NULL; + + for (i = scm->fp->count-1; i >= 0; i--) + unix_notinflight(scm->fp->user, scm->fp->fp[i]); +} +EXPORT_SYMBOL(unix_detach_fds); + +void unix_destruct_scm(struct sk_buff *skb) +{ + struct scm_cookie scm; + + memset(&scm, 0, sizeof(scm)); + scm.pid = UNIXCB(skb).pid; + if (UNIXCB(skb).fp) + unix_detach_fds(&scm, skb); + + /* Alas, it calls VFS */ + /* So fscking what? fput() had been SMP-safe since the last Summer */ + scm_destroy(&scm); + sock_wfree(skb); +} +EXPORT_SYMBOL(unix_destruct_scm); diff --git a/net/unix/scm.h b/net/unix/scm.h new file mode 100644 index 000000000000..5a255a477f16 --- /dev/null +++ b/net/unix/scm.h @@ -0,0 +1,10 @@ +#ifndef NET_UNIX_SCM_H +#define NET_UNIX_SCM_H + +extern struct list_head gc_inflight_list; +extern spinlock_t unix_gc_lock; + +int unix_attach_fds(struct scm_cookie *scm, struct sk_buff *skb); +void unix_detach_fds(struct scm_cookie *scm, struct sk_buff *skb); + +#endif diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 43a1dec08825..d892000770cf 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -505,7 +505,7 @@ out: static int __vsock_bind_stream(struct vsock_sock *vsk, struct sockaddr_vm *addr) { - static u32 port = 0; + static u32 port; struct sockaddr_vm new_addr; if (!port) @@ -1439,7 +1439,7 @@ static int vsock_stream_setsockopt(struct socket *sock, break; case SO_VM_SOCKETS_CONNECT_TIMEOUT: { - struct timeval tv; + struct __kernel_old_timeval tv; COPY_IN(tv); if (tv.tv_sec >= 0 && tv.tv_usec < USEC_PER_SEC && tv.tv_sec < (MAX_SCHEDULE_TIMEOUT / HZ - 1)) { @@ -1517,7 +1517,7 @@ static int vsock_stream_getsockopt(struct socket *sock, break; case SO_VM_SOCKETS_CONNECT_TIMEOUT: { - struct timeval tv; + struct __kernel_old_timeval tv; tv.tv_sec = vsk->connect_timeout / HZ; tv.tv_usec = (vsk->connect_timeout - diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c index 5d3cce9e8744..15eb5d3d4750 100644 --- a/net/vmw_vsock/virtio_transport.c +++ b/net/vmw_vsock/virtio_transport.c @@ -75,6 +75,9 @@ static u32 virtio_transport_get_local_cid(void) { struct virtio_vsock *vsock = virtio_vsock_get(); + if (!vsock) + return VMADDR_CID_ANY; + return vsock->guest_cid; } @@ -584,10 +587,6 @@ static int virtio_vsock_probe(struct virtio_device *vdev) virtio_vsock_update_guest_cid(vsock); - ret = vsock_core_init(&virtio_transport.transport); - if (ret < 0) - goto out_vqs; - vsock->rx_buf_nr = 0; vsock->rx_buf_max_nr = 0; atomic_set(&vsock->queued_replies, 0); @@ -618,8 +617,6 @@ static int virtio_vsock_probe(struct virtio_device *vdev) mutex_unlock(&the_virtio_vsock_mutex); return 0; -out_vqs: - vsock->vdev->config->del_vqs(vsock->vdev); out: kfree(vsock); mutex_unlock(&the_virtio_vsock_mutex); @@ -637,6 +634,9 @@ static void virtio_vsock_remove(struct virtio_device *vdev) flush_work(&vsock->event_work); flush_work(&vsock->send_pkt_work); + /* Reset all connected sockets when the device disappear */ + vsock_for_each_connected_socket(virtio_vsock_reset_sock); + vdev->config->reset(vdev); mutex_lock(&vsock->rx_lock); @@ -669,7 +669,6 @@ static void virtio_vsock_remove(struct virtio_device *vdev) mutex_lock(&the_virtio_vsock_mutex); the_virtio_vsock = NULL; - vsock_core_exit(); mutex_unlock(&the_virtio_vsock_mutex); vdev->config->del_vqs(vdev); @@ -702,14 +701,28 @@ static int __init virtio_vsock_init(void) virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0); if (!virtio_vsock_workqueue) return -ENOMEM; + ret = register_virtio_driver(&virtio_vsock_driver); if (ret) - destroy_workqueue(virtio_vsock_workqueue); + goto out_wq; + + ret = vsock_core_init(&virtio_transport.transport); + if (ret) + goto out_vdr; + + return 0; + +out_vdr: + unregister_virtio_driver(&virtio_vsock_driver); +out_wq: + destroy_workqueue(virtio_vsock_workqueue); return ret; + } static void __exit virtio_vsock_exit(void) { + vsock_core_exit(); unregister_virtio_driver(&virtio_vsock_driver); destroy_workqueue(virtio_vsock_workqueue); } diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index 3ae3a33da70b..602715fc9a75 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -662,6 +662,8 @@ static int virtio_transport_reset(struct vsock_sock *vsk, */ static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) { + const struct virtio_transport *t; + struct virtio_vsock_pkt *reply; struct virtio_vsock_pkt_info info = { .op = VIRTIO_VSOCK_OP_RST, .type = le16_to_cpu(pkt->hdr.type), @@ -672,15 +674,21 @@ static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) return 0; - pkt = virtio_transport_alloc_pkt(&info, 0, - le64_to_cpu(pkt->hdr.dst_cid), - le32_to_cpu(pkt->hdr.dst_port), - le64_to_cpu(pkt->hdr.src_cid), - le32_to_cpu(pkt->hdr.src_port)); - if (!pkt) + reply = virtio_transport_alloc_pkt(&info, 0, + le64_to_cpu(pkt->hdr.dst_cid), + le32_to_cpu(pkt->hdr.dst_port), + le64_to_cpu(pkt->hdr.src_cid), + le32_to_cpu(pkt->hdr.src_port)); + if (!reply) return -ENOMEM; - return virtio_transport_get_ops()->send_pkt(pkt); + t = virtio_transport_get_ops(); + if (!t) { + virtio_transport_free_pkt(reply); + return -ENOTCONN; + } + + return t->send_pkt(reply); } static void virtio_transport_wait_close(struct sock *sk, long timeout) diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c index c361ce782412..c3d5ab01fba7 100644 --- a/net/vmw_vsock/vmci_transport.c +++ b/net/vmw_vsock/vmci_transport.c @@ -1651,6 +1651,10 @@ static void vmci_transport_cleanup(struct work_struct *work) static void vmci_transport_destruct(struct vsock_sock *vsk) { + /* transport can be NULL if we hit a failure at init() time */ + if (!vmci_trans(vsk)) + return; + /* Ensure that the detach callback doesn't use the sk/vsk * we are about to destruct. */ diff --git a/net/wireless/ap.c b/net/wireless/ap.c index 882d97bdc6bf..550ac9d827fe 100644 --- a/net/wireless/ap.c +++ b/net/wireless/ap.c @@ -41,6 +41,8 @@ int __cfg80211_stop_ap(struct cfg80211_registered_device *rdev, cfg80211_sched_dfs_chan_update(rdev); } + schedule_work(&cfg80211_disconnect_work); + return err; } diff --git a/net/wireless/core.c b/net/wireless/core.c index 623dfe5e211c..b36ad8efb5e5 100644 --- a/net/wireless/core.c +++ b/net/wireless/core.c @@ -1068,6 +1068,8 @@ static void __cfg80211_unregister_wdev(struct wireless_dev *wdev, bool sync) ASSERT_RTNL(); + flush_work(&wdev->pmsr_free_wk); + nl80211_notify_iface(rdev, wdev, NL80211_CMD_DEL_INTERFACE); list_del_rcu(&wdev->list); diff --git a/net/wireless/core.h b/net/wireless/core.h index c5d6f3418601..84d36ca7a7ab 100644 --- a/net/wireless/core.h +++ b/net/wireless/core.h @@ -3,7 +3,7 @@ * Wireless configuration interface internals. * * Copyright 2006-2010 Johannes Berg <johannes@sipsolutions.net> - * Copyright (C) 2018 Intel Corporation + * Copyright (C) 2018-2019 Intel Corporation */ #ifndef __NET_WIRELESS_CORE_H #define __NET_WIRELESS_CORE_H @@ -182,12 +182,23 @@ static inline struct cfg80211_internal_bss *bss_from_pub(struct cfg80211_bss *pu static inline void cfg80211_hold_bss(struct cfg80211_internal_bss *bss) { atomic_inc(&bss->hold); + if (bss->pub.transmitted_bss) { + bss = container_of(bss->pub.transmitted_bss, + struct cfg80211_internal_bss, pub); + atomic_inc(&bss->hold); + } } static inline void cfg80211_unhold_bss(struct cfg80211_internal_bss *bss) { int r = atomic_dec_return(&bss->hold); WARN_ON(r < 0); + if (bss->pub.transmitted_bss) { + bss = container_of(bss->pub.transmitted_bss, + struct cfg80211_internal_bss, pub); + r = atomic_dec_return(&bss->hold); + WARN_ON(r < 0); + } } @@ -445,6 +456,8 @@ void cfg80211_process_wdev_events(struct wireless_dev *wdev); bool cfg80211_does_bw_fit_range(const struct ieee80211_freq_range *freq_range, u32 center_freq_khz, u32 bw_khz); +extern struct work_struct cfg80211_disconnect_work; + /** * cfg80211_chandef_dfs_usable - checks if chandef is DFS usable * @wiphy: the wiphy to validate against diff --git a/net/wireless/mlme.c b/net/wireless/mlme.c index 1615e503f8e3..f9462010575f 100644 --- a/net/wireless/mlme.c +++ b/net/wireless/mlme.c @@ -21,7 +21,8 @@ void cfg80211_rx_assoc_resp(struct net_device *dev, struct cfg80211_bss *bss, - const u8 *buf, size_t len, int uapsd_queues) + const u8 *buf, size_t len, int uapsd_queues, + const u8 *req_ies, size_t req_ies_len) { struct wireless_dev *wdev = dev->ieee80211_ptr; struct wiphy *wiphy = wdev->wiphy; @@ -33,6 +34,8 @@ void cfg80211_rx_assoc_resp(struct net_device *dev, struct cfg80211_bss *bss, cr.status = (int)le16_to_cpu(mgmt->u.assoc_resp.status_code); cr.bssid = mgmt->bssid; cr.bss = bss; + cr.req_ie = req_ies; + cr.req_ie_len = req_ies_len; cr.resp_ie = mgmt->u.assoc_resp.variable; cr.resp_ie_len = len - offsetof(struct ieee80211_mgmt, u.assoc_resp.variable); @@ -52,7 +55,8 @@ void cfg80211_rx_assoc_resp(struct net_device *dev, struct cfg80211_bss *bss, return; } - nl80211_send_rx_assoc(rdev, dev, buf, len, GFP_KERNEL, uapsd_queues); + nl80211_send_rx_assoc(rdev, dev, buf, len, GFP_KERNEL, uapsd_queues, + req_ies, req_ies_len); /* update current_bss etc., consumes the bss reference */ __cfg80211_connect_result(dev, &cr, cr.status == WLAN_STATUS_SUCCESS); } diff --git a/net/wireless/nl80211.c b/net/wireless/nl80211.c index 5e49492d5911..25a9e3b5c154 100644 --- a/net/wireless/nl80211.c +++ b/net/wireless/nl80211.c @@ -4,7 +4,7 @@ * Copyright 2006-2010 Johannes Berg <johannes@sipsolutions.net> * Copyright 2013-2014 Intel Mobile Communications GmbH * Copyright 2015-2017 Intel Deutschland GmbH - * Copyright (C) 2018 Intel Corporation + * Copyright (C) 2018-2019 Intel Corporation */ #include <linux/if.h> @@ -203,29 +203,17 @@ cfg80211_get_dev_from_info(struct net *netns, struct genl_info *info) static int validate_ie_attr(const struct nlattr *attr, struct netlink_ext_ack *extack) { - const u8 *pos; - int len; + const u8 *data = nla_data(attr); + unsigned int len = nla_len(attr); + const struct element *elem; - pos = nla_data(attr); - len = nla_len(attr); - - while (len) { - u8 elemlen; - - if (len < 2) - goto error; - len -= 2; - - elemlen = pos[1]; - if (elemlen > len) - goto error; - - len -= elemlen; - pos += 2 + elemlen; + for_each_element(elem, data, len) { + /* nothing */ } - return 0; -error: + if (for_each_element_completed(elem, data, len)) + return 0; + NL_SET_ERR_MSG_ATTR(extack, attr, "malformed information elements"); return -EINVAL; } @@ -250,7 +238,7 @@ nl80211_pmsr_ftm_req_attr_policy[NL80211_PMSR_FTM_REQ_ATTR_MAX + 1] = { [NL80211_PMSR_FTM_REQ_ATTR_BURST_DURATION] = NLA_POLICY_MAX(NLA_U8, 15), [NL80211_PMSR_FTM_REQ_ATTR_FTMS_PER_BURST] = - NLA_POLICY_MAX(NLA_U8, 15), + NLA_POLICY_MAX(NLA_U8, 31), [NL80211_PMSR_FTM_REQ_ATTR_NUM_FTMR_RETRIES] = { .type = NLA_U8 }, [NL80211_PMSR_FTM_REQ_ATTR_REQUEST_LCI] = { .type = NLA_FLAG }, [NL80211_PMSR_FTM_REQ_ATTR_REQUEST_CIVICLOC] = { .type = NLA_FLAG }, @@ -259,15 +247,13 @@ nl80211_pmsr_ftm_req_attr_policy[NL80211_PMSR_FTM_REQ_ATTR_MAX + 1] = { static const struct nla_policy nl80211_pmsr_req_data_policy[NL80211_PMSR_TYPE_MAX + 1] = { [NL80211_PMSR_TYPE_FTM] = - NLA_POLICY_NESTED(NL80211_PMSR_FTM_REQ_ATTR_MAX, - nl80211_pmsr_ftm_req_attr_policy), + NLA_POLICY_NESTED(nl80211_pmsr_ftm_req_attr_policy), }; static const struct nla_policy nl80211_pmsr_req_attr_policy[NL80211_PMSR_REQ_ATTR_MAX + 1] = { [NL80211_PMSR_REQ_ATTR_DATA] = - NLA_POLICY_NESTED(NL80211_PMSR_TYPE_MAX, - nl80211_pmsr_req_data_policy), + NLA_POLICY_NESTED(nl80211_pmsr_req_data_policy), [NL80211_PMSR_REQ_ATTR_GET_AP_TSF] = { .type = NLA_FLAG }, }; @@ -280,8 +266,7 @@ nl80211_psmr_peer_attr_policy[NL80211_PMSR_PEER_ATTR_MAX + 1] = { */ [NL80211_PMSR_PEER_ATTR_CHAN] = { .type = NLA_NESTED }, [NL80211_PMSR_PEER_ATTR_REQ] = - NLA_POLICY_NESTED(NL80211_PMSR_REQ_ATTR_MAX, - nl80211_pmsr_req_attr_policy), + NLA_POLICY_NESTED(nl80211_pmsr_req_attr_policy), [NL80211_PMSR_PEER_ATTR_RESP] = { .type = NLA_REJECT }, }; @@ -292,8 +277,7 @@ nl80211_pmsr_attr_policy[NL80211_PMSR_ATTR_MAX + 1] = { [NL80211_PMSR_ATTR_RANDOMIZE_MAC_ADDR] = { .type = NLA_REJECT }, [NL80211_PMSR_ATTR_TYPE_CAPA] = { .type = NLA_REJECT }, [NL80211_PMSR_ATTR_PEERS] = - NLA_POLICY_NESTED_ARRAY(NL80211_PMSR_PEER_ATTR_MAX, - nl80211_psmr_peer_attr_policy), + NLA_POLICY_NESTED_ARRAY(nl80211_psmr_peer_attr_policy), }; const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { @@ -555,8 +539,8 @@ const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { }, [NL80211_ATTR_TIMEOUT] = NLA_POLICY_MIN(NLA_U32, 1), [NL80211_ATTR_PEER_MEASUREMENTS] = - NLA_POLICY_NESTED(NL80211_PMSR_FTM_REQ_ATTR_MAX, - nl80211_pmsr_attr_policy), + NLA_POLICY_NESTED(nl80211_pmsr_attr_policy), + [NL80211_ATTR_AIRTIME_WEIGHT] = NLA_POLICY_MIN(NLA_U16, 1), }; /* policy for the key attributes */ @@ -2278,6 +2262,15 @@ static int nl80211_send_wiphy(struct cfg80211_registered_device *rdev, if (nl80211_send_pmsr_capa(rdev, msg)) goto nla_put_failure; + state->split_start++; + break; + case 15: + if (rdev->wiphy.akm_suites && + nla_put(msg, NL80211_ATTR_AKM_SUITES, + sizeof(u32) * rdev->wiphy.n_akm_suites, + rdev->wiphy.akm_suites)) + goto nla_put_failure; + /* done */ state->split_start = 0; break; @@ -4540,6 +4533,9 @@ static int nl80211_start_ap(struct sk_buff *skb, struct genl_info *info) nl80211_calculate_ap_params(¶ms); + if (info->attrs[NL80211_ATTR_EXTERNAL_AUTH_SUPPORT]) + params.flags |= AP_SETTINGS_EXTERNAL_AUTH_SUPPORT; + wdev_lock(wdev); err = rdev_start_ap(rdev, dev, ¶ms); if (!err) { @@ -4851,6 +4847,11 @@ static int nl80211_send_station(struct sk_buff *msg, u32 cmd, u32 portid, PUT_SINFO(PLID, plid, u16); PUT_SINFO(PLINK_STATE, plink_state, u8); PUT_SINFO_U64(RX_DURATION, rx_duration); + PUT_SINFO_U64(TX_DURATION, tx_duration); + + if (wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_AIRTIME_FAIRNESS)) + PUT_SINFO(AIRTIME_WEIGHT, airtime_weight, u16); switch (rdev->wiphy.signal_type) { case CFG80211_SIGNAL_TYPE_MBM: @@ -5470,6 +5471,15 @@ static int nl80211_set_station(struct sk_buff *skb, struct genl_info *info) nla_get_u8(info->attrs[NL80211_ATTR_OPMODE_NOTIF]); } + if (info->attrs[NL80211_ATTR_AIRTIME_WEIGHT]) + params.airtime_weight = + nla_get_u16(info->attrs[NL80211_ATTR_AIRTIME_WEIGHT]); + + if (params.airtime_weight && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_AIRTIME_FAIRNESS)) + return -EOPNOTSUPP; + /* Include parameters for TDLS peer (will check later) */ err = nl80211_set_station_tdls(info, ¶ms); if (err) @@ -5598,6 +5608,15 @@ static int nl80211_new_station(struct sk_buff *skb, struct genl_info *info) params.plink_action = nla_get_u8(info->attrs[NL80211_ATTR_STA_PLINK_ACTION]); + if (info->attrs[NL80211_ATTR_AIRTIME_WEIGHT]) + params.airtime_weight = + nla_get_u16(info->attrs[NL80211_ATTR_AIRTIME_WEIGHT]); + + if (params.airtime_weight && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_AIRTIME_FAIRNESS)) + return -EOPNOTSUPP; + err = nl80211_parse_sta_channel_info(info, ¶ms); if (err) return err; @@ -5803,7 +5822,13 @@ static int nl80211_send_mpath(struct sk_buff *msg, u32 portid, u32 seq, pinfo->discovery_timeout)) || ((pinfo->filled & MPATH_INFO_DISCOVERY_RETRIES) && nla_put_u8(msg, NL80211_MPATH_INFO_DISCOVERY_RETRIES, - pinfo->discovery_retries))) + pinfo->discovery_retries)) || + ((pinfo->filled & MPATH_INFO_HOP_COUNT) && + nla_put_u8(msg, NL80211_MPATH_INFO_HOP_COUNT, + pinfo->hop_count)) || + ((pinfo->filled & MPATH_INFO_PATH_CHANGE) && + nla_put_u32(msg, NL80211_MPATH_INFO_PATH_CHANGE, + pinfo->path_change_count))) goto nla_put_failure; nla_nest_end(msg, pinfoattr); @@ -9281,6 +9306,7 @@ struct sk_buff *__cfg80211_alloc_event_skb(struct wiphy *wiphy, struct wireless_dev *wdev, enum nl80211_commands cmd, enum nl80211_attrs attr, + unsigned int portid, int vendor_event_idx, int approxlen, gfp_t gfp) { @@ -9304,7 +9330,7 @@ struct sk_buff *__cfg80211_alloc_event_skb(struct wiphy *wiphy, return NULL; } - return __cfg80211_alloc_vendor_skb(rdev, wdev, approxlen, 0, 0, + return __cfg80211_alloc_vendor_skb(rdev, wdev, approxlen, portid, 0, cmd, attr, info, gfp); } EXPORT_SYMBOL(__cfg80211_alloc_event_skb); @@ -9313,6 +9339,7 @@ void __cfg80211_send_event_skb(struct sk_buff *skb, gfp_t gfp) { struct cfg80211_registered_device *rdev = ((void **)skb->cb)[0]; void *hdr = ((void **)skb->cb)[1]; + struct nlmsghdr *nlhdr = nlmsg_hdr(skb); struct nlattr *data = ((void **)skb->cb)[2]; enum nl80211_multicast_groups mcgrp = NL80211_MCGRP_TESTMODE; @@ -9322,11 +9349,16 @@ void __cfg80211_send_event_skb(struct sk_buff *skb, gfp_t gfp) nla_nest_end(skb, data); genlmsg_end(skb, hdr); - if (data->nla_type == NL80211_ATTR_VENDOR_DATA) - mcgrp = NL80211_MCGRP_VENDOR; + if (nlhdr->nlmsg_pid) { + genlmsg_unicast(wiphy_net(&rdev->wiphy), skb, + nlhdr->nlmsg_pid); + } else { + if (data->nla_type == NL80211_ATTR_VENDOR_DATA) + mcgrp = NL80211_MCGRP_VENDOR; - genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), skb, 0, - mcgrp, gfp); + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), + skb, 0, mcgrp, gfp); + } } EXPORT_SYMBOL(__cfg80211_send_event_skb); @@ -9857,7 +9889,10 @@ static int nl80211_setdel_pmksa(struct sk_buff *skb, struct genl_info *info) } if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_STATION && - dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_CLIENT) + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_CLIENT && + !(dev->ieee80211_ptr->iftype == NL80211_IFTYPE_AP && + wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_AP_PMKSA_CACHING))) return -EOPNOTSUPP; switch (info->genlhdr->cmd) { @@ -12708,6 +12743,17 @@ int cfg80211_vendor_cmd_reply(struct sk_buff *skb) } EXPORT_SYMBOL_GPL(cfg80211_vendor_cmd_reply); +unsigned int cfg80211_vendor_cmd_get_sender(struct wiphy *wiphy) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + if (WARN_ON(!rdev->cur_cmd_info)) + return 0; + + return rdev->cur_cmd_info->snd_portid; +} +EXPORT_SYMBOL_GPL(cfg80211_vendor_cmd_get_sender); + static int nl80211_set_qos_map(struct sk_buff *skb, struct genl_info *info) { @@ -13047,7 +13093,9 @@ static int nl80211_external_auth(struct sk_buff *skb, struct genl_info *info) if (!rdev->ops->external_auth) return -EOPNOTSUPP; - if (!info->attrs[NL80211_ATTR_SSID]) + if (!info->attrs[NL80211_ATTR_SSID] && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_AP && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) return -EINVAL; if (!info->attrs[NL80211_ATTR_BSSID]) @@ -13058,18 +13106,24 @@ static int nl80211_external_auth(struct sk_buff *skb, struct genl_info *info) memset(¶ms, 0, sizeof(params)); - params.ssid.ssid_len = nla_len(info->attrs[NL80211_ATTR_SSID]); - if (params.ssid.ssid_len == 0 || - params.ssid.ssid_len > IEEE80211_MAX_SSID_LEN) - return -EINVAL; - memcpy(params.ssid.ssid, nla_data(info->attrs[NL80211_ATTR_SSID]), - params.ssid.ssid_len); + if (info->attrs[NL80211_ATTR_SSID]) { + params.ssid.ssid_len = nla_len(info->attrs[NL80211_ATTR_SSID]); + if (params.ssid.ssid_len == 0 || + params.ssid.ssid_len > IEEE80211_MAX_SSID_LEN) + return -EINVAL; + memcpy(params.ssid.ssid, + nla_data(info->attrs[NL80211_ATTR_SSID]), + params.ssid.ssid_len); + } memcpy(params.bssid, nla_data(info->attrs[NL80211_ATTR_BSSID]), ETH_ALEN); params.status = nla_get_u16(info->attrs[NL80211_ATTR_STATUS_CODE]); + if (info->attrs[NL80211_ATTR_PMKID]) + params.pmkid = nla_data(info->attrs[NL80211_ATTR_PMKID]); + return rdev_external_auth(rdev, dev, ¶ms); } @@ -14455,12 +14509,13 @@ static void nl80211_send_mlme_event(struct cfg80211_registered_device *rdev, struct net_device *netdev, const u8 *buf, size_t len, enum nl80211_commands cmd, gfp_t gfp, - int uapsd_queues) + int uapsd_queues, const u8 *req_ies, + size_t req_ies_len) { struct sk_buff *msg; void *hdr; - msg = nlmsg_new(100 + len, gfp); + msg = nlmsg_new(100 + len + req_ies_len, gfp); if (!msg) return; @@ -14472,7 +14527,9 @@ static void nl80211_send_mlme_event(struct cfg80211_registered_device *rdev, if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || - nla_put(msg, NL80211_ATTR_FRAME, len, buf)) + nla_put(msg, NL80211_ATTR_FRAME, len, buf) || + (req_ies && + nla_put(msg, NL80211_ATTR_REQ_IE, req_ies_len, req_ies))) goto nla_put_failure; if (uapsd_queues >= 0) { @@ -14503,15 +14560,17 @@ void nl80211_send_rx_auth(struct cfg80211_registered_device *rdev, size_t len, gfp_t gfp) { nl80211_send_mlme_event(rdev, netdev, buf, len, - NL80211_CMD_AUTHENTICATE, gfp, -1); + NL80211_CMD_AUTHENTICATE, gfp, -1, NULL, 0); } void nl80211_send_rx_assoc(struct cfg80211_registered_device *rdev, struct net_device *netdev, const u8 *buf, - size_t len, gfp_t gfp, int uapsd_queues) + size_t len, gfp_t gfp, int uapsd_queues, + const u8 *req_ies, size_t req_ies_len) { nl80211_send_mlme_event(rdev, netdev, buf, len, - NL80211_CMD_ASSOCIATE, gfp, uapsd_queues); + NL80211_CMD_ASSOCIATE, gfp, uapsd_queues, + req_ies, req_ies_len); } void nl80211_send_deauth(struct cfg80211_registered_device *rdev, @@ -14519,7 +14578,7 @@ void nl80211_send_deauth(struct cfg80211_registered_device *rdev, size_t len, gfp_t gfp) { nl80211_send_mlme_event(rdev, netdev, buf, len, - NL80211_CMD_DEAUTHENTICATE, gfp, -1); + NL80211_CMD_DEAUTHENTICATE, gfp, -1, NULL, 0); } void nl80211_send_disassoc(struct cfg80211_registered_device *rdev, @@ -14527,7 +14586,7 @@ void nl80211_send_disassoc(struct cfg80211_registered_device *rdev, size_t len, gfp_t gfp) { nl80211_send_mlme_event(rdev, netdev, buf, len, - NL80211_CMD_DISASSOCIATE, gfp, -1); + NL80211_CMD_DISASSOCIATE, gfp, -1, NULL, 0); } void cfg80211_rx_unprot_mlme_mgmt(struct net_device *dev, const u8 *buf, @@ -14548,7 +14607,8 @@ void cfg80211_rx_unprot_mlme_mgmt(struct net_device *dev, const u8 *buf, cmd = NL80211_CMD_UNPROT_DISASSOCIATE; trace_cfg80211_rx_unprot_mlme_mgmt(dev, buf, len); - nl80211_send_mlme_event(rdev, dev, buf, len, cmd, GFP_ATOMIC, -1); + nl80211_send_mlme_event(rdev, dev, buf, len, cmd, GFP_ATOMIC, -1, + NULL, 0); } EXPORT_SYMBOL(cfg80211_rx_unprot_mlme_mgmt); diff --git a/net/wireless/nl80211.h b/net/wireless/nl80211.h index 531c82dcba6b..a41e94a49a89 100644 --- a/net/wireless/nl80211.h +++ b/net/wireless/nl80211.h @@ -67,7 +67,8 @@ void nl80211_send_rx_auth(struct cfg80211_registered_device *rdev, void nl80211_send_rx_assoc(struct cfg80211_registered_device *rdev, struct net_device *netdev, const u8 *buf, size_t len, gfp_t gfp, - int uapsd_queues); + int uapsd_queues, + const u8 *req_ies, size_t req_ies_len); void nl80211_send_deauth(struct cfg80211_registered_device *rdev, struct net_device *netdev, const u8 *buf, size_t len, gfp_t gfp); diff --git a/net/wireless/pmsr.c b/net/wireless/pmsr.c index de9286703280..5e2ab01d325c 100644 --- a/net/wireless/pmsr.c +++ b/net/wireless/pmsr.c @@ -256,9 +256,8 @@ int nl80211_pmsr_start(struct sk_buff *skb, struct genl_info *info) if (err) goto out_err; } else { - memcpy(req->mac_addr, nla_data(info->attrs[NL80211_ATTR_MAC]), - ETH_ALEN); - memset(req->mac_addr_mask, 0xff, ETH_ALEN); + memcpy(req->mac_addr, wdev_address(wdev), ETH_ALEN); + eth_broadcast_addr(req->mac_addr_mask); } idx = 0; @@ -272,6 +271,7 @@ int nl80211_pmsr_start(struct sk_buff *skb, struct genl_info *info) req->n_peers = count; req->cookie = cfg80211_assign_cookie(rdev); + req->nl_portid = info->snd_portid; err = rdev_start_pmsr(rdev, wdev, req); if (err) @@ -530,14 +530,14 @@ free: } EXPORT_SYMBOL_GPL(cfg80211_pmsr_report); -void cfg80211_pmsr_free_wk(struct work_struct *work) +static void cfg80211_pmsr_process_abort(struct wireless_dev *wdev) { - struct wireless_dev *wdev = container_of(work, struct wireless_dev, - pmsr_free_wk); struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); struct cfg80211_pmsr_request *req, *tmp; LIST_HEAD(free_list); + lockdep_assert_held(&wdev->mtx); + spin_lock_bh(&wdev->pmsr_lock); list_for_each_entry_safe(req, tmp, &wdev->pmsr_list, list) { if (req->nl_portid) @@ -547,14 +547,22 @@ void cfg80211_pmsr_free_wk(struct work_struct *work) spin_unlock_bh(&wdev->pmsr_lock); list_for_each_entry_safe(req, tmp, &free_list, list) { - wdev_lock(wdev); rdev_abort_pmsr(rdev, wdev, req); - wdev_unlock(wdev); kfree(req); } } +void cfg80211_pmsr_free_wk(struct work_struct *work) +{ + struct wireless_dev *wdev = container_of(work, struct wireless_dev, + pmsr_free_wk); + + wdev_lock(wdev); + cfg80211_pmsr_process_abort(wdev); + wdev_unlock(wdev); +} + void cfg80211_pmsr_wdev_down(struct wireless_dev *wdev) { struct cfg80211_pmsr_request *req; @@ -568,8 +576,8 @@ void cfg80211_pmsr_wdev_down(struct wireless_dev *wdev) spin_unlock_bh(&wdev->pmsr_lock); if (found) - schedule_work(&wdev->pmsr_free_wk); - flush_work(&wdev->pmsr_free_wk); + cfg80211_pmsr_process_abort(wdev); + WARN_ON(!list_empty(&wdev->pmsr_list)); } diff --git a/net/wireless/reg.c b/net/wireless/reg.c index ecfb1a06dbb2..2f1bf91eb226 100644 --- a/net/wireless/reg.c +++ b/net/wireless/reg.c @@ -5,7 +5,7 @@ * Copyright 2008-2011 Luis R. Rodriguez <mcgrof@qca.qualcomm.com> * Copyright 2013-2014 Intel Mobile Communications GmbH * Copyright 2017 Intel Deutschland GmbH - * Copyright (C) 2018 Intel Corporation + * Copyright (C) 2018 - 2019 Intel Corporation * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -131,7 +131,8 @@ static spinlock_t reg_indoor_lock; /* Used to track the userspace process controlling the indoor setting */ static u32 reg_is_indoor_portid; -static void restore_regulatory_settings(bool reset_user); +static void restore_regulatory_settings(bool reset_user, bool cached); +static void print_regdomain(const struct ieee80211_regdomain *rd); static const struct ieee80211_regdomain *get_cfg80211_regdom(void) { @@ -263,6 +264,7 @@ static const struct ieee80211_regdomain *cfg80211_world_regdom = static char *ieee80211_regdom = "00"; static char user_alpha2[2]; +static const struct ieee80211_regdomain *cfg80211_user_regdom; module_param(ieee80211_regdom, charp, 0444); MODULE_PARM_DESC(ieee80211_regdom, "IEEE 802.11 regulatory domain code"); @@ -445,6 +447,15 @@ reg_copy_regd(const struct ieee80211_regdomain *src_regd) return regd; } +static void cfg80211_save_user_regdom(const struct ieee80211_regdomain *rd) +{ + ASSERT_RTNL(); + + if (!IS_ERR(cfg80211_user_regdom)) + kfree(cfg80211_user_regdom); + cfg80211_user_regdom = reg_copy_regd(rd); +} + struct reg_regdb_apply_request { struct list_head list; const struct ieee80211_regdomain *regdom; @@ -510,7 +521,7 @@ static void crda_timeout_work(struct work_struct *work) pr_debug("Timeout while waiting for CRDA to reply, restoring regulatory settings\n"); rtnl_lock(); reg_crda_timeouts++; - restore_regulatory_settings(true); + restore_regulatory_settings(true, false); rtnl_unlock(); } @@ -1024,8 +1035,13 @@ static void regdb_fw_cb(const struct firmware *fw, void *context) } rtnl_lock(); - if (WARN_ON(regdb && !IS_ERR(regdb))) { - /* just restore and free new db */ + if (regdb && !IS_ERR(regdb)) { + /* negative case - a bug + * positive case - can happen due to race in case of multiple cb's in + * queue, due to usage of asynchronous callback + * + * Either case, just restore and free new db. + */ } else if (set_error) { regdb = ERR_PTR(set_error); } else if (fw) { @@ -1039,7 +1055,7 @@ static void regdb_fw_cb(const struct firmware *fw, void *context) } if (restore) - restore_regulatory_settings(true); + restore_regulatory_settings(true, false); rtnl_unlock(); @@ -1255,7 +1271,7 @@ static bool is_valid_rd(const struct ieee80211_regdomain *rd) * definitions (the "2.4 GHz band", the "5 GHz band" and the "60GHz band"), * however it is safe for now to assume that a frequency rule should not be * part of a frequency's band if the start freq or end freq are off by more - * than 2 GHz for the 2.4 and 5 GHz bands, and by more than 10 GHz for the + * than 2 GHz for the 2.4 and 5 GHz bands, and by more than 20 GHz for the * 60 GHz band. * This resolution can be lowered and should be considered as we add * regulatory rule support for other "bands". @@ -1270,7 +1286,7 @@ static bool freq_in_rule_band(const struct ieee80211_freq_range *freq_range, * with the Channel starting frequency above 45 GHz. */ u32 limit = freq_khz > 45 * ONE_GHZ_IN_KHZ ? - 10 * ONE_GHZ_IN_KHZ : 2 * ONE_GHZ_IN_KHZ; + 20 * ONE_GHZ_IN_KHZ : 2 * ONE_GHZ_IN_KHZ; if (abs(freq_khz - freq_range->start_freq_khz) <= limit) return true; if (abs(freq_khz - freq_range->end_freq_khz) <= limit) @@ -2724,9 +2740,7 @@ static void notify_self_managed_wiphys(struct regulatory_request *request) list_for_each_entry(rdev, &cfg80211_rdev_list, list) { wiphy = &rdev->wiphy; if (wiphy->regulatory_flags & REGULATORY_WIPHY_SELF_MANAGED && - request->initiator == NL80211_REGDOM_SET_BY_USER && - request->user_reg_hint_type == - NL80211_USER_REG_HINT_CELL_BASE) + request->initiator == NL80211_REGDOM_SET_BY_USER) reg_call_notifier(wiphy, request); } } @@ -3114,7 +3128,7 @@ static void restore_custom_reg_settings(struct wiphy *wiphy) * keep their own regulatory domain on wiphy->regd so that does does * not need to be remembered. */ -static void restore_regulatory_settings(bool reset_user) +static void restore_regulatory_settings(bool reset_user, bool cached) { char alpha2[2]; char world_alpha2[2]; @@ -3173,15 +3187,41 @@ static void restore_regulatory_settings(bool reset_user) restore_custom_reg_settings(&rdev->wiphy); } - regulatory_hint_core(world_alpha2); + if (cached && (!is_an_alpha2(alpha2) || + !IS_ERR_OR_NULL(cfg80211_user_regdom))) { + reset_regdomains(false, cfg80211_world_regdom); + update_all_wiphy_regulatory(NL80211_REGDOM_SET_BY_CORE); + print_regdomain(get_cfg80211_regdom()); + nl80211_send_reg_change_event(&core_request_world); + reg_set_request_processed(); - /* - * This restores the ieee80211_regdom module parameter - * preference or the last user requested regulatory - * settings, user regulatory settings takes precedence. - */ - if (is_an_alpha2(alpha2)) - regulatory_hint_user(alpha2, NL80211_USER_REG_HINT_USER); + if (is_an_alpha2(alpha2) && + !regulatory_hint_user(alpha2, NL80211_USER_REG_HINT_USER)) { + struct regulatory_request *ureq; + + spin_lock(®_requests_lock); + ureq = list_last_entry(®_requests_list, + struct regulatory_request, + list); + list_del(&ureq->list); + spin_unlock(®_requests_lock); + + notify_self_managed_wiphys(ureq); + reg_update_last_request(ureq); + set_regdom(reg_copy_regd(cfg80211_user_regdom), + REGD_SOURCE_CACHED); + } + } else { + regulatory_hint_core(world_alpha2); + + /* + * This restores the ieee80211_regdom module parameter + * preference or the last user requested regulatory + * settings, user regulatory settings takes precedence. + */ + if (is_an_alpha2(alpha2)) + regulatory_hint_user(alpha2, NL80211_USER_REG_HINT_USER); + } spin_lock(®_requests_lock); list_splice_tail_init(&tmp_reg_req_list, ®_requests_list); @@ -3241,7 +3281,7 @@ void regulatory_hint_disconnect(void) } pr_debug("All devices are disconnected, going to restore regulatory settings\n"); - restore_regulatory_settings(false); + restore_regulatory_settings(false, true); } static bool freq_is_chan_12_13_14(u32 freq) @@ -3558,6 +3598,9 @@ int set_regdom(const struct ieee80211_regdomain *rd, bool user_reset = false; int r; + if (IS_ERR_OR_NULL(rd)) + return -ENODATA; + if (!reg_is_valid_request(rd->alpha2)) { kfree(rd); return -EINVAL; @@ -3574,6 +3617,7 @@ int set_regdom(const struct ieee80211_regdomain *rd, r = reg_set_rd_core(rd); break; case NL80211_REGDOM_SET_BY_USER: + cfg80211_save_user_regdom(rd); r = reg_set_rd_user(rd, lr); user_reset = true; break; @@ -3596,7 +3640,7 @@ int set_regdom(const struct ieee80211_regdomain *rd, break; default: /* Back to world regulatory in case of errors */ - restore_regulatory_settings(user_reset); + restore_regulatory_settings(user_reset, false); } kfree(rd); @@ -3932,6 +3976,8 @@ void regulatory_exit(void) if (!IS_ERR_OR_NULL(regdb)) kfree(regdb); + if (!IS_ERR_OR_NULL(cfg80211_user_regdom)) + kfree(cfg80211_user_regdom); free_regdb_keyring(); } diff --git a/net/wireless/reg.h b/net/wireless/reg.h index 9ceeb5f3a7cb..504133d76de4 100644 --- a/net/wireless/reg.h +++ b/net/wireless/reg.h @@ -5,6 +5,7 @@ /* * Copyright 2008-2011 Luis R. Rodriguez <mcgrof@qca.qualcomm.com> + * Copyright (C) 2019 Intel Corporation * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -22,6 +23,7 @@ enum ieee80211_regd_source { REGD_SOURCE_INTERNAL_DB, REGD_SOURCE_CRDA, + REGD_SOURCE_CACHED, }; extern const struct ieee80211_regdomain __rcu *cfg80211_regdomain; diff --git a/net/wireless/scan.c b/net/wireless/scan.c index 5123667f4569..287518c6caa4 100644 --- a/net/wireless/scan.c +++ b/net/wireless/scan.c @@ -5,6 +5,7 @@ * Copyright 2008 Johannes Berg <johannes@sipsolutions.net> * Copyright 2013-2014 Intel Mobile Communications GmbH * Copyright 2016 Intel Deutschland GmbH + * Copyright (C) 2018-2019 Intel Corporation */ #include <linux/kernel.h> #include <linux/slab.h> @@ -109,6 +110,12 @@ static inline void bss_ref_get(struct cfg80211_registered_device *rdev, pub); bss->refcount++; } + if (bss->pub.transmitted_bss) { + bss = container_of(bss->pub.transmitted_bss, + struct cfg80211_internal_bss, + pub); + bss->refcount++; + } } static inline void bss_ref_put(struct cfg80211_registered_device *rdev, @@ -125,6 +132,18 @@ static inline void bss_ref_put(struct cfg80211_registered_device *rdev, if (hbss->refcount == 0) bss_free(hbss); } + + if (bss->pub.transmitted_bss) { + struct cfg80211_internal_bss *tbss; + + tbss = container_of(bss->pub.transmitted_bss, + struct cfg80211_internal_bss, + pub); + tbss->refcount--; + if (tbss->refcount == 0) + bss_free(tbss); + } + bss->refcount--; if (bss->refcount == 0) bss_free(bss); @@ -150,6 +169,7 @@ static bool __cfg80211_unlink_bss(struct cfg80211_registered_device *rdev, } list_del_init(&bss->list); + list_del_init(&bss->pub.nontrans_list); rb_erase(&bss->rbn, &rdev->bss_tree); rdev->bss_entries--; WARN_ONCE((rdev->bss_entries == 0) ^ list_empty(&rdev->bss_list), @@ -159,6 +179,162 @@ static bool __cfg80211_unlink_bss(struct cfg80211_registered_device *rdev, return true; } +static size_t cfg80211_gen_new_ie(const u8 *ie, size_t ielen, + const u8 *subelement, size_t subie_len, + u8 *new_ie, gfp_t gfp) +{ + u8 *pos, *tmp; + const u8 *tmp_old, *tmp_new; + u8 *sub_copy; + + /* copy subelement as we need to change its content to + * mark an ie after it is processed. + */ + sub_copy = kmalloc(subie_len, gfp); + if (!sub_copy) + return 0; + memcpy(sub_copy, subelement, subie_len); + + pos = &new_ie[0]; + + /* set new ssid */ + tmp_new = cfg80211_find_ie(WLAN_EID_SSID, sub_copy, subie_len); + if (tmp_new) { + memcpy(pos, tmp_new, tmp_new[1] + 2); + pos += (tmp_new[1] + 2); + } + + /* go through IEs in ie (skip SSID) and subelement, + * merge them into new_ie + */ + tmp_old = cfg80211_find_ie(WLAN_EID_SSID, ie, ielen); + tmp_old = (tmp_old) ? tmp_old + tmp_old[1] + 2 : ie; + + while (tmp_old + tmp_old[1] + 2 - ie <= ielen) { + if (tmp_old[0] == 0) { + tmp_old++; + continue; + } + + if (tmp_old[0] == WLAN_EID_EXTENSION) + tmp = (u8 *)cfg80211_find_ext_ie(tmp_old[2], sub_copy, + subie_len); + else + tmp = (u8 *)cfg80211_find_ie(tmp_old[0], sub_copy, + subie_len); + + if (!tmp) { + /* ie in old ie but not in subelement */ + if (tmp_old[0] != WLAN_EID_MULTIPLE_BSSID) { + memcpy(pos, tmp_old, tmp_old[1] + 2); + pos += tmp_old[1] + 2; + } + } else { + /* ie in transmitting ie also in subelement, + * copy from subelement and flag the ie in subelement + * as copied (by setting eid field to WLAN_EID_SSID, + * which is skipped anyway). + * For vendor ie, compare OUI + type + subType to + * determine if they are the same ie. + */ + if (tmp_old[0] == WLAN_EID_VENDOR_SPECIFIC) { + if (!memcmp(tmp_old + 2, tmp + 2, 5)) { + /* same vendor ie, copy from + * subelement + */ + memcpy(pos, tmp, tmp[1] + 2); + pos += tmp[1] + 2; + tmp[0] = WLAN_EID_SSID; + } else { + memcpy(pos, tmp_old, tmp_old[1] + 2); + pos += tmp_old[1] + 2; + } + } else { + /* copy ie from subelement into new ie */ + memcpy(pos, tmp, tmp[1] + 2); + pos += tmp[1] + 2; + tmp[0] = WLAN_EID_SSID; + } + } + + if (tmp_old + tmp_old[1] + 2 - ie == ielen) + break; + + tmp_old += tmp_old[1] + 2; + } + + /* go through subelement again to check if there is any ie not + * copied to new ie, skip ssid, capability, bssid-index ie + */ + tmp_new = sub_copy; + while (tmp_new + tmp_new[1] + 2 - sub_copy <= subie_len) { + if (!(tmp_new[0] == WLAN_EID_NON_TX_BSSID_CAP || + tmp_new[0] == WLAN_EID_SSID || + tmp_new[0] == WLAN_EID_MULTI_BSSID_IDX)) { + memcpy(pos, tmp_new, tmp_new[1] + 2); + pos += tmp_new[1] + 2; + } + if (tmp_new + tmp_new[1] + 2 - sub_copy == subie_len) + break; + tmp_new += tmp_new[1] + 2; + } + + kfree(sub_copy); + return pos - new_ie; +} + +static bool is_bss(struct cfg80211_bss *a, const u8 *bssid, + const u8 *ssid, size_t ssid_len) +{ + const struct cfg80211_bss_ies *ies; + const u8 *ssidie; + + if (bssid && !ether_addr_equal(a->bssid, bssid)) + return false; + + if (!ssid) + return true; + + ies = rcu_access_pointer(a->ies); + if (!ies) + return false; + ssidie = cfg80211_find_ie(WLAN_EID_SSID, ies->data, ies->len); + if (!ssidie) + return false; + if (ssidie[1] != ssid_len) + return false; + return memcmp(ssidie + 2, ssid, ssid_len) == 0; +} + +static int +cfg80211_add_nontrans_list(struct cfg80211_bss *trans_bss, + struct cfg80211_bss *nontrans_bss) +{ + const u8 *ssid; + size_t ssid_len; + struct cfg80211_bss *bss = NULL; + + rcu_read_lock(); + ssid = ieee80211_bss_get_ie(nontrans_bss, WLAN_EID_SSID); + if (!ssid) { + rcu_read_unlock(); + return -EINVAL; + } + ssid_len = ssid[1]; + ssid = ssid + 2; + rcu_read_unlock(); + + /* check if nontrans_bss is in the list */ + list_for_each_entry(bss, &trans_bss->nontrans_list, nontrans_list) { + if (is_bss(bss, nontrans_bss->bssid, ssid, ssid_len)) + return 0; + } + + /* add to the list */ + list_add_tail(&nontrans_bss->nontrans_list, &trans_bss->nontrans_list); + return 0; +} + static void __cfg80211_bss_expire(struct cfg80211_registered_device *rdev, unsigned long expire_time) { @@ -480,73 +656,43 @@ void cfg80211_bss_expire(struct cfg80211_registered_device *rdev) __cfg80211_bss_expire(rdev, jiffies - IEEE80211_SCAN_RESULT_EXPIRE); } -const u8 *cfg80211_find_ie_match(u8 eid, const u8 *ies, int len, - const u8 *match, int match_len, - int match_offset) +const struct element * +cfg80211_find_elem_match(u8 eid, const u8 *ies, unsigned int len, + const u8 *match, unsigned int match_len, + unsigned int match_offset) { - /* match_offset can't be smaller than 2, unless match_len is - * zero, in which case match_offset must be zero as well. - */ - if (WARN_ON((match_len && match_offset < 2) || - (!match_len && match_offset))) - return NULL; + const struct element *elem; - while (len >= 2 && len >= ies[1] + 2) { - if ((ies[0] == eid) && - (ies[1] + 2 >= match_offset + match_len) && - !memcmp(ies + match_offset, match, match_len)) - return ies; - - len -= ies[1] + 2; - ies += ies[1] + 2; + for_each_element_id(elem, eid, ies, len) { + if (elem->datalen >= match_offset + match_len && + !memcmp(elem->data + match_offset, match, match_len)) + return elem; } return NULL; } -EXPORT_SYMBOL(cfg80211_find_ie_match); +EXPORT_SYMBOL(cfg80211_find_elem_match); -const u8 *cfg80211_find_vendor_ie(unsigned int oui, int oui_type, - const u8 *ies, int len) +const struct element *cfg80211_find_vendor_elem(unsigned int oui, int oui_type, + const u8 *ies, + unsigned int len) { - const u8 *ie; + const struct element *elem; u8 match[] = { oui >> 16, oui >> 8, oui, oui_type }; int match_len = (oui_type < 0) ? 3 : sizeof(match); if (WARN_ON(oui_type > 0xff)) return NULL; - ie = cfg80211_find_ie_match(WLAN_EID_VENDOR_SPECIFIC, ies, len, - match, match_len, 2); + elem = cfg80211_find_elem_match(WLAN_EID_VENDOR_SPECIFIC, ies, len, + match, match_len, 0); - if (ie && (ie[1] < 4)) + if (!elem || elem->datalen < 4) return NULL; - return ie; -} -EXPORT_SYMBOL(cfg80211_find_vendor_ie); - -static bool is_bss(struct cfg80211_bss *a, const u8 *bssid, - const u8 *ssid, size_t ssid_len) -{ - const struct cfg80211_bss_ies *ies; - const u8 *ssidie; - - if (bssid && !ether_addr_equal(a->bssid, bssid)) - return false; - - if (!ssid) - return true; - - ies = rcu_access_pointer(a->ies); - if (!ies) - return false; - ssidie = cfg80211_find_ie(WLAN_EID_SSID, ies->data, ies->len); - if (!ssidie) - return false; - if (ssidie[1] != ssid_len) - return false; - return memcmp(ssidie + 2, ssid, ssid_len) == 0; + return elem; } +EXPORT_SYMBOL(cfg80211_find_vendor_elem); /** * enum bss_compare_mode - BSS compare mode @@ -882,6 +1028,12 @@ static bool cfg80211_combine_bsses(struct cfg80211_registered_device *rdev, return true; } +struct cfg80211_non_tx_bss { + struct cfg80211_bss *tx_bss; + u8 max_bssid_indicator; + u8 bssid_index; +}; + /* Returned bss is reference counted and must be cleaned up appropriately. */ static struct cfg80211_internal_bss * cfg80211_bss_update(struct cfg80211_registered_device *rdev, @@ -985,6 +1137,8 @@ cfg80211_bss_update(struct cfg80211_registered_device *rdev, memcpy(found->pub.chain_signal, tmp->pub.chain_signal, IEEE80211_MAX_CHAINS); ether_addr_copy(found->parent_bssid, tmp->parent_bssid); + found->pub.max_bssid_indicator = tmp->pub.max_bssid_indicator; + found->pub.bssid_index = tmp->pub.bssid_index; } else { struct cfg80211_internal_bss *new; struct cfg80211_internal_bss *hidden; @@ -1009,6 +1163,7 @@ cfg80211_bss_update(struct cfg80211_registered_device *rdev, memcpy(new, tmp, sizeof(*new)); new->refcount = 1; INIT_LIST_HEAD(&new->hidden_list); + INIT_LIST_HEAD(&new->pub.nontrans_list); if (rcu_access_pointer(tmp->pub.proberesp_ies)) { hidden = rb_find_bss(rdev, tmp, BSS_CMP_HIDE_ZLEN); @@ -1042,6 +1197,17 @@ cfg80211_bss_update(struct cfg80211_registered_device *rdev, goto drop; } + /* This must be before the call to bss_ref_get */ + if (tmp->pub.transmitted_bss) { + struct cfg80211_internal_bss *pbss = + container_of(tmp->pub.transmitted_bss, + struct cfg80211_internal_bss, + pub); + + new->pub.transmitted_bss = tmp->pub.transmitted_bss; + bss_ref_get(rdev, pbss); + } + list_add_tail(&new->list, &rdev->bss_list); rdev->bss_entries++; rb_insert_bss(rdev, new); @@ -1130,14 +1296,16 @@ cfg80211_get_bss_channel(struct wiphy *wiphy, const u8 *ie, size_t ielen, } /* Returned bss is reference counted and must be cleaned up appropriately. */ -struct cfg80211_bss * -cfg80211_inform_bss_data(struct wiphy *wiphy, - struct cfg80211_inform_bss *data, - enum cfg80211_bss_frame_type ftype, - const u8 *bssid, u64 tsf, u16 capability, - u16 beacon_interval, const u8 *ie, size_t ielen, - gfp_t gfp) +static struct cfg80211_bss * +cfg80211_inform_single_bss_data(struct wiphy *wiphy, + struct cfg80211_inform_bss *data, + enum cfg80211_bss_frame_type ftype, + const u8 *bssid, u64 tsf, u16 capability, + u16 beacon_interval, const u8 *ie, size_t ielen, + struct cfg80211_non_tx_bss *non_tx_data, + gfp_t gfp) { + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); struct cfg80211_bss_ies *ies; struct ieee80211_channel *channel; struct cfg80211_internal_bss tmp = {}, *res; @@ -1163,6 +1331,11 @@ cfg80211_inform_bss_data(struct wiphy *wiphy, tmp.pub.beacon_interval = beacon_interval; tmp.pub.capability = capability; tmp.ts_boottime = data->boottime_ns; + if (non_tx_data) { + tmp.pub.transmitted_bss = non_tx_data->tx_bss; + tmp.pub.bssid_index = non_tx_data->bssid_index; + tmp.pub.max_bssid_indicator = non_tx_data->max_bssid_indicator; + } /* * If we do not know here whether the IEs are from a Beacon or Probe @@ -1209,19 +1382,247 @@ cfg80211_inform_bss_data(struct wiphy *wiphy, regulatory_hint_found_beacon(wiphy, channel, gfp); } + if (non_tx_data && non_tx_data->tx_bss) { + /* this is a nontransmitting bss, we need to add it to + * transmitting bss' list if it is not there + */ + if (cfg80211_add_nontrans_list(non_tx_data->tx_bss, + &res->pub)) { + if (__cfg80211_unlink_bss(rdev, res)) + rdev->bss_generation++; + } + } + trace_cfg80211_return_bss(&res->pub); /* cfg80211_bss_update gives us a referenced result */ return &res->pub; } -EXPORT_SYMBOL(cfg80211_inform_bss_data); -/* cfg80211_inform_bss_width_frame helper */ +static void cfg80211_parse_mbssid_data(struct wiphy *wiphy, + struct cfg80211_inform_bss *data, + enum cfg80211_bss_frame_type ftype, + const u8 *bssid, u64 tsf, + u16 beacon_interval, const u8 *ie, + size_t ielen, + struct cfg80211_non_tx_bss *non_tx_data, + gfp_t gfp) +{ + const u8 *mbssid_index_ie; + const struct element *elem, *sub; + size_t new_ie_len; + u8 new_bssid[ETH_ALEN]; + u8 *new_ie; + u16 capability; + struct cfg80211_bss *bss; + + if (!non_tx_data) + return; + if (!cfg80211_find_ie(WLAN_EID_MULTIPLE_BSSID, ie, ielen)) + return; + if (!wiphy->support_mbssid) + return; + if (wiphy->support_only_he_mbssid && + !cfg80211_find_ext_ie(WLAN_EID_EXT_HE_CAPABILITY, ie, ielen)) + return; + + new_ie = kmalloc(IEEE80211_MAX_DATA_LEN, gfp); + if (!new_ie) + return; + + for_each_element_id(elem, WLAN_EID_MULTIPLE_BSSID, ie, ielen) { + if (elem->datalen < 4) + continue; + for_each_element(sub, elem->data + 1, elem->datalen - 1) { + if (sub->id != 0 || sub->datalen < 4) { + /* not a valid BSS profile */ + continue; + } + + if (sub->data[0] != WLAN_EID_NON_TX_BSSID_CAP || + sub->data[1] != 2) { + /* The first element within the Nontransmitted + * BSSID Profile is not the Nontransmitted + * BSSID Capability element. + */ + continue; + } + + /* found a Nontransmitted BSSID Profile */ + mbssid_index_ie = cfg80211_find_ie + (WLAN_EID_MULTI_BSSID_IDX, + sub->data, sub->datalen); + if (!mbssid_index_ie || mbssid_index_ie[1] < 1 || + mbssid_index_ie[2] == 0) { + /* No valid Multiple BSSID-Index element */ + continue; + } + + non_tx_data->bssid_index = mbssid_index_ie[2]; + non_tx_data->max_bssid_indicator = elem->data[0]; + + cfg80211_gen_new_bssid(bssid, + non_tx_data->max_bssid_indicator, + non_tx_data->bssid_index, + new_bssid); + memset(new_ie, 0, IEEE80211_MAX_DATA_LEN); + new_ie_len = cfg80211_gen_new_ie(ie, ielen, sub->data, + sub->datalen, new_ie, + gfp); + if (!new_ie_len) + continue; + + capability = get_unaligned_le16(sub->data + 2); + bss = cfg80211_inform_single_bss_data(wiphy, data, + ftype, + new_bssid, tsf, + capability, + beacon_interval, + new_ie, + new_ie_len, + non_tx_data, + gfp); + if (!bss) + break; + cfg80211_put_bss(wiphy, bss); + } + } + + kfree(new_ie); +} + struct cfg80211_bss * -cfg80211_inform_bss_frame_data(struct wiphy *wiphy, - struct cfg80211_inform_bss *data, - struct ieee80211_mgmt *mgmt, size_t len, - gfp_t gfp) +cfg80211_inform_bss_data(struct wiphy *wiphy, + struct cfg80211_inform_bss *data, + enum cfg80211_bss_frame_type ftype, + const u8 *bssid, u64 tsf, u16 capability, + u16 beacon_interval, const u8 *ie, size_t ielen, + gfp_t gfp) +{ + struct cfg80211_bss *res; + struct cfg80211_non_tx_bss non_tx_data; + + res = cfg80211_inform_single_bss_data(wiphy, data, ftype, bssid, tsf, + capability, beacon_interval, ie, + ielen, NULL, gfp); + non_tx_data.tx_bss = res; + cfg80211_parse_mbssid_data(wiphy, data, ftype, bssid, tsf, + beacon_interval, ie, ielen, &non_tx_data, + gfp); + return res; +} +EXPORT_SYMBOL(cfg80211_inform_bss_data); + +static void +cfg80211_parse_mbssid_frame_data(struct wiphy *wiphy, + struct cfg80211_inform_bss *data, + struct ieee80211_mgmt *mgmt, size_t len, + struct cfg80211_non_tx_bss *non_tx_data, + gfp_t gfp) +{ + enum cfg80211_bss_frame_type ftype; + const u8 *ie = mgmt->u.probe_resp.variable; + size_t ielen = len - offsetof(struct ieee80211_mgmt, + u.probe_resp.variable); + ftype = ieee80211_is_beacon(mgmt->frame_control) ? + CFG80211_BSS_FTYPE_BEACON : CFG80211_BSS_FTYPE_PRESP; + + cfg80211_parse_mbssid_data(wiphy, data, ftype, mgmt->bssid, + le64_to_cpu(mgmt->u.probe_resp.timestamp), + le16_to_cpu(mgmt->u.probe_resp.beacon_int), + ie, ielen, non_tx_data, gfp); +} + +static void +cfg80211_update_notlisted_nontrans(struct wiphy *wiphy, + struct cfg80211_bss *nontrans_bss, + struct ieee80211_mgmt *mgmt, size_t len, + gfp_t gfp) +{ + u8 *ie, *new_ie, *pos; + const u8 *nontrans_ssid, *trans_ssid, *mbssid; + size_t ielen = len - offsetof(struct ieee80211_mgmt, + u.probe_resp.variable); + size_t new_ie_len; + struct cfg80211_bss_ies *new_ies; + const struct cfg80211_bss_ies *old; + u8 cpy_len; + + ie = mgmt->u.probe_resp.variable; + + new_ie_len = ielen; + trans_ssid = cfg80211_find_ie(WLAN_EID_SSID, ie, ielen); + if (!trans_ssid) + return; + new_ie_len -= trans_ssid[1]; + mbssid = cfg80211_find_ie(WLAN_EID_MULTIPLE_BSSID, ie, ielen); + if (!mbssid) + return; + new_ie_len -= mbssid[1]; + rcu_read_lock(); + nontrans_ssid = ieee80211_bss_get_ie(nontrans_bss, WLAN_EID_SSID); + if (!nontrans_ssid) { + rcu_read_unlock(); + return; + } + new_ie_len += nontrans_ssid[1]; + rcu_read_unlock(); + + /* generate new ie for nontrans BSS + * 1. replace SSID with nontrans BSS' SSID + * 2. skip MBSSID IE + */ + new_ie = kzalloc(new_ie_len, gfp); + if (!new_ie) + return; + new_ies = kzalloc(sizeof(*new_ies) + new_ie_len, gfp); + if (!new_ies) + goto out_free; + + pos = new_ie; + + /* copy the nontransmitted SSID */ + cpy_len = nontrans_ssid[1] + 2; + memcpy(pos, nontrans_ssid, cpy_len); + pos += cpy_len; + /* copy the IEs between SSID and MBSSID */ + cpy_len = trans_ssid[1] + 2; + memcpy(pos, (trans_ssid + cpy_len), (mbssid - (trans_ssid + cpy_len))); + pos += (mbssid - (trans_ssid + cpy_len)); + /* copy the IEs after MBSSID */ + cpy_len = mbssid[1] + 2; + memcpy(pos, mbssid + cpy_len, ((ie + ielen) - (mbssid + cpy_len))); + + /* update ie */ + new_ies->len = new_ie_len; + new_ies->tsf = le64_to_cpu(mgmt->u.probe_resp.timestamp); + new_ies->from_beacon = ieee80211_is_beacon(mgmt->frame_control); + memcpy(new_ies->data, new_ie, new_ie_len); + if (ieee80211_is_probe_resp(mgmt->frame_control)) { + old = rcu_access_pointer(nontrans_bss->proberesp_ies); + rcu_assign_pointer(nontrans_bss->proberesp_ies, new_ies); + rcu_assign_pointer(nontrans_bss->ies, new_ies); + if (old) + kfree_rcu((struct cfg80211_bss_ies *)old, rcu_head); + } else { + old = rcu_access_pointer(nontrans_bss->beacon_ies); + rcu_assign_pointer(nontrans_bss->beacon_ies, new_ies); + rcu_assign_pointer(nontrans_bss->ies, new_ies); + if (old) + kfree_rcu((struct cfg80211_bss_ies *)old, rcu_head); + } + +out_free: + kfree(new_ie); +} + +/* cfg80211_inform_bss_width_frame helper */ +static struct cfg80211_bss * +cfg80211_inform_single_bss_frame_data(struct wiphy *wiphy, + struct cfg80211_inform_bss *data, + struct ieee80211_mgmt *mgmt, size_t len, + struct cfg80211_non_tx_bss *non_tx_data, + gfp_t gfp) { struct cfg80211_internal_bss tmp = {}, *res; struct cfg80211_bss_ies *ies; @@ -1279,6 +1680,11 @@ cfg80211_inform_bss_frame_data(struct wiphy *wiphy, tmp.pub.chains = data->chains; memcpy(tmp.pub.chain_signal, data->chain_signal, IEEE80211_MAX_CHAINS); ether_addr_copy(tmp.parent_bssid, data->parent_bssid); + if (non_tx_data) { + tmp.pub.transmitted_bss = non_tx_data->tx_bss; + tmp.pub.bssid_index = non_tx_data->bssid_index; + tmp.pub.max_bssid_indicator = non_tx_data->max_bssid_indicator; + } signal_valid = abs(data->chan->center_freq - channel->center_freq) <= wiphy->max_adj_channel_rssi_comp; @@ -1300,6 +1706,53 @@ cfg80211_inform_bss_frame_data(struct wiphy *wiphy, /* cfg80211_bss_update gives us a referenced result */ return &res->pub; } + +struct cfg80211_bss * +cfg80211_inform_bss_frame_data(struct wiphy *wiphy, + struct cfg80211_inform_bss *data, + struct ieee80211_mgmt *mgmt, size_t len, + gfp_t gfp) +{ + struct cfg80211_bss *res, *tmp_bss; + const u8 *ie = mgmt->u.probe_resp.variable; + const struct cfg80211_bss_ies *ies1, *ies2; + size_t ielen = len - offsetof(struct ieee80211_mgmt, + u.probe_resp.variable); + struct cfg80211_non_tx_bss non_tx_data; + + res = cfg80211_inform_single_bss_frame_data(wiphy, data, mgmt, + len, NULL, gfp); + if (!res || !wiphy->support_mbssid || + !cfg80211_find_ie(WLAN_EID_MULTIPLE_BSSID, ie, ielen)) + return res; + if (wiphy->support_only_he_mbssid && + !cfg80211_find_ext_ie(WLAN_EID_EXT_HE_CAPABILITY, ie, ielen)) + return res; + + non_tx_data.tx_bss = res; + /* process each non-transmitting bss */ + cfg80211_parse_mbssid_frame_data(wiphy, data, mgmt, len, + &non_tx_data, gfp); + + /* check if the res has other nontransmitting bss which is not + * in MBSSID IE + */ + ies1 = rcu_access_pointer(res->ies); + + /* go through nontrans_list, if the timestamp of the BSS is + * earlier than the timestamp of the transmitting BSS then + * update it + */ + list_for_each_entry(tmp_bss, &res->nontrans_list, + nontrans_list) { + ies2 = rcu_access_pointer(tmp_bss->ies); + if (ies2->tsf < ies1->tsf) + cfg80211_update_notlisted_nontrans(wiphy, tmp_bss, + mgmt, len, gfp); + } + + return res; +} EXPORT_SYMBOL(cfg80211_inform_bss_frame_data); void cfg80211_ref_bss(struct wiphy *wiphy, struct cfg80211_bss *pub) @@ -1337,7 +1790,8 @@ EXPORT_SYMBOL(cfg80211_put_bss); void cfg80211_unlink_bss(struct wiphy *wiphy, struct cfg80211_bss *pub) { struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); - struct cfg80211_internal_bss *bss; + struct cfg80211_internal_bss *bss, *tmp1; + struct cfg80211_bss *nontrans_bss, *tmp; if (WARN_ON(!pub)) return; @@ -1345,10 +1799,21 @@ void cfg80211_unlink_bss(struct wiphy *wiphy, struct cfg80211_bss *pub) bss = container_of(pub, struct cfg80211_internal_bss, pub); spin_lock_bh(&rdev->bss_lock); - if (!list_empty(&bss->list)) { - if (__cfg80211_unlink_bss(rdev, bss)) + if (list_empty(&bss->list)) + goto out; + + list_for_each_entry_safe(nontrans_bss, tmp, + &pub->nontrans_list, + nontrans_list) { + tmp1 = container_of(nontrans_bss, + struct cfg80211_internal_bss, pub); + if (__cfg80211_unlink_bss(rdev, tmp1)) rdev->bss_generation++; } + + if (__cfg80211_unlink_bss(rdev, bss)) + rdev->bss_generation++; +out: spin_unlock_bh(&rdev->bss_lock); } EXPORT_SYMBOL(cfg80211_unlink_bss); diff --git a/net/wireless/sme.c b/net/wireless/sme.c index f741d8376a46..7d34cb884840 100644 --- a/net/wireless/sme.c +++ b/net/wireless/sme.c @@ -667,7 +667,7 @@ static void disconnect_work(struct work_struct *work) rtnl_unlock(); } -static DECLARE_WORK(cfg80211_disconnect_work, disconnect_work); +DECLARE_WORK(cfg80211_disconnect_work, disconnect_work); /* diff --git a/net/wireless/util.c b/net/wireless/util.c index cd48cdd582c0..e4b8db5e81ec 100644 --- a/net/wireless/util.c +++ b/net/wireless/util.c @@ -5,7 +5,7 @@ * Copyright 2007-2009 Johannes Berg <johannes@sipsolutions.net> * Copyright 2013-2014 Intel Mobile Communications GmbH * Copyright 2017 Intel Deutschland GmbH - * Copyright (C) 2018 Intel Corporation + * Copyright (C) 2018-2019 Intel Corporation */ #include <linux/export.h> #include <linux/bitops.h> @@ -19,6 +19,7 @@ #include <linux/mpls.h> #include <linux/gcd.h> #include <linux/bitfield.h> +#include <linux/nospec.h> #include "core.h" #include "rdev-ops.h" @@ -715,20 +716,25 @@ unsigned int cfg80211_classify8021d(struct sk_buff *skb, { unsigned int dscp; unsigned char vlan_priority; + unsigned int ret; /* skb->priority values from 256->263 are magic values to * directly indicate a specific 802.1d priority. This is used * to allow 802.1d priority to be passed directly in from VLAN * tags, etc. */ - if (skb->priority >= 256 && skb->priority <= 263) - return skb->priority - 256; + if (skb->priority >= 256 && skb->priority <= 263) { + ret = skb->priority - 256; + goto out; + } if (skb_vlan_tag_present(skb)) { vlan_priority = (skb_vlan_tag_get(skb) & VLAN_PRIO_MASK) >> VLAN_PRIO_SHIFT; - if (vlan_priority > 0) - return vlan_priority; + if (vlan_priority > 0) { + ret = vlan_priority; + goto out; + } } switch (skb->protocol) { @@ -747,8 +753,9 @@ unsigned int cfg80211_classify8021d(struct sk_buff *skb, if (!mpls) return 0; - return (ntohl(mpls->entry) & MPLS_LS_TC_MASK) + ret = (ntohl(mpls->entry) & MPLS_LS_TC_MASK) >> MPLS_LS_TC_SHIFT; + goto out; } case htons(ETH_P_80221): /* 802.21 is always network control traffic */ @@ -761,22 +768,28 @@ unsigned int cfg80211_classify8021d(struct sk_buff *skb, unsigned int i, tmp_dscp = dscp >> 2; for (i = 0; i < qos_map->num_des; i++) { - if (tmp_dscp == qos_map->dscp_exception[i].dscp) - return qos_map->dscp_exception[i].up; + if (tmp_dscp == qos_map->dscp_exception[i].dscp) { + ret = qos_map->dscp_exception[i].up; + goto out; + } } for (i = 0; i < 8; i++) { if (tmp_dscp >= qos_map->up[i].low && - tmp_dscp <= qos_map->up[i].high) - return i; + tmp_dscp <= qos_map->up[i].high) { + ret = i; + goto out; + } } } - return dscp >> 5; + ret = dscp >> 5; +out: + return array_index_nospec(ret, IEEE80211_NUM_TIDS); } EXPORT_SYMBOL(cfg80211_classify8021d); -const u8 *ieee80211_bss_get_ie(struct cfg80211_bss *bss, u8 ie) +const struct element *ieee80211_bss_get_elem(struct cfg80211_bss *bss, u8 id) { const struct cfg80211_bss_ies *ies; @@ -784,9 +797,9 @@ const u8 *ieee80211_bss_get_ie(struct cfg80211_bss *bss, u8 ie) if (!ies) return NULL; - return cfg80211_find_ie(ie, ies->data, ies->len); + return cfg80211_find_elem(id, ies->data, ies->len); } -EXPORT_SYMBOL(ieee80211_bss_get_ie); +EXPORT_SYMBOL(ieee80211_bss_get_elem); void cfg80211_upload_connect_keys(struct wireless_dev *wdev) { diff --git a/net/wireless/wext-compat.c b/net/wireless/wext-compat.c index 06943d9c9835..d522787c7354 100644 --- a/net/wireless/wext-compat.c +++ b/net/wireless/wext-compat.c @@ -1337,6 +1337,7 @@ static struct iw_statistics *cfg80211_wireless_stats(struct net_device *dev) wstats.qual.qual = sig + 110; break; } + /* fall through */ case CFG80211_SIGNAL_TYPE_UNSPEC: if (sinfo.filled & BIT_ULL(NL80211_STA_INFO_SIGNAL)) { wstats.qual.updated |= IW_QUAL_LEVEL_UPDATED; @@ -1345,6 +1346,7 @@ static struct iw_statistics *cfg80211_wireless_stats(struct net_device *dev) wstats.qual.qual = sinfo.signal; break; } + /* fall through */ default: wstats.qual.updated |= IW_QUAL_LEVEL_INVALID; wstats.qual.updated |= IW_QUAL_QUAL_INVALID; diff --git a/net/x25/af_x25.c b/net/x25/af_x25.c index 5121729b8b63..20a511398389 100644 --- a/net/x25/af_x25.c +++ b/net/x25/af_x25.c @@ -352,17 +352,15 @@ static unsigned int x25_new_lci(struct x25_neigh *nb) unsigned int lci = 1; struct sock *sk; - read_lock_bh(&x25_list_lock); - - while ((sk = __x25_find_socket(lci, nb)) != NULL) { + while ((sk = x25_find_socket(lci, nb)) != NULL) { sock_put(sk); if (++lci == 4096) { lci = 0; break; } + cond_resched(); } - read_unlock_bh(&x25_list_lock); return lci; } @@ -681,8 +679,7 @@ static int x25_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) struct sockaddr_x25 *addr = (struct sockaddr_x25 *)uaddr; int len, i, rc = 0; - if (!sock_flag(sk, SOCK_ZAPPED) || - addr_len != sizeof(struct sockaddr_x25) || + if (addr_len != sizeof(struct sockaddr_x25) || addr->sx25_family != AF_X25) { rc = -EINVAL; goto out; @@ -701,9 +698,13 @@ static int x25_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) } lock_sock(sk); - x25_sk(sk)->source_addr = addr->sx25_addr; - x25_insert_socket(sk); - sock_reset_flag(sk, SOCK_ZAPPED); + if (sock_flag(sk, SOCK_ZAPPED)) { + x25_sk(sk)->source_addr = addr->sx25_addr; + x25_insert_socket(sk); + sock_reset_flag(sk, SOCK_ZAPPED); + } else { + rc = -EINVAL; + } release_sock(sk); SOCK_DEBUG(sk, "x25_bind: socket is bound\n"); out: @@ -819,8 +820,13 @@ static int x25_connect(struct socket *sock, struct sockaddr *uaddr, sock->state = SS_CONNECTED; rc = 0; out_put_neigh: - if (rc) + if (rc) { + read_lock_bh(&x25_list_lock); x25_neigh_put(x25->neighbour); + x25->neighbour = NULL; + read_unlock_bh(&x25_list_lock); + x25->state = X25_STATE_0; + } out_put_route: x25_route_put(rt); out: diff --git a/net/xdp/Kconfig b/net/xdp/Kconfig index 90e4a7152854..0255b33cff4b 100644 --- a/net/xdp/Kconfig +++ b/net/xdp/Kconfig @@ -5,3 +5,11 @@ config XDP_SOCKETS help XDP sockets allows a channel between XDP programs and userspace applications. + +config XDP_SOCKETS_DIAG + tristate "XDP sockets: monitoring interface" + depends on XDP_SOCKETS + default n + help + Support for PF_XDP sockets monitoring interface used by the ss tool. + If unsure, say Y. diff --git a/net/xdp/Makefile b/net/xdp/Makefile index 04f073146256..59dbfdf93dca 100644 --- a/net/xdp/Makefile +++ b/net/xdp/Makefile @@ -1 +1,2 @@ obj-$(CONFIG_XDP_SOCKETS) += xsk.o xdp_umem.o xsk_queue.o +obj-$(CONFIG_XDP_SOCKETS_DIAG) += xsk_diag.o diff --git a/net/xdp/xdp_umem.c b/net/xdp/xdp_umem.c index d4de871e7d4d..77520eacee8f 100644 --- a/net/xdp/xdp_umem.c +++ b/net/xdp/xdp_umem.c @@ -13,12 +13,15 @@ #include <linux/mm.h> #include <linux/netdevice.h> #include <linux/rtnetlink.h> +#include <linux/idr.h> #include "xdp_umem.h" #include "xsk_queue.h" #define XDP_UMEM_MIN_CHUNK_SIZE 2048 +static DEFINE_IDA(umem_ida); + void xdp_add_sk_umem(struct xdp_umem *umem, struct xdp_sock *xs) { unsigned long flags; @@ -67,6 +70,7 @@ struct xdp_umem *xdp_get_umem_from_qid(struct net_device *dev, return NULL; } +EXPORT_SYMBOL(xdp_get_umem_from_qid); static void xdp_clear_umem_at_qid(struct net_device *dev, u16 queue_id) { @@ -125,9 +129,10 @@ int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev, return 0; err_unreg_umem: - xdp_clear_umem_at_qid(dev, queue_id); if (!force_zc) err = 0; /* fallback to copy mode */ + if (err) + xdp_clear_umem_at_qid(dev, queue_id); out_rtnl_unlock: rtnl_unlock(); return err; @@ -193,6 +198,8 @@ static void xdp_umem_release(struct xdp_umem *umem) xdp_umem_clear_dev(umem); + ida_simple_remove(&umem_ida, umem->id); + if (umem->fq) { xskq_destroy(umem->fq); umem->fq = NULL; @@ -259,10 +266,10 @@ static int xdp_umem_pin_pages(struct xdp_umem *umem) if (!umem->pgs) return -ENOMEM; - down_write(¤t->mm->mmap_sem); - npgs = get_user_pages(umem->address, umem->npgs, - gup_flags, &umem->pgs[0], NULL); - up_write(¤t->mm->mmap_sem); + down_read(¤t->mm->mmap_sem); + npgs = get_user_pages_longterm(umem->address, umem->npgs, + gup_flags, &umem->pgs[0], NULL); + up_read(¤t->mm->mmap_sem); if (npgs != umem->npgs) { if (npgs >= 0) { @@ -399,8 +406,16 @@ struct xdp_umem *xdp_umem_create(struct xdp_umem_reg *mr) if (!umem) return ERR_PTR(-ENOMEM); + err = ida_simple_get(&umem_ida, 0, 0, GFP_KERNEL); + if (err < 0) { + kfree(umem); + return ERR_PTR(err); + } + umem->id = err; + err = xdp_umem_reg(umem, mr); if (err) { + ida_simple_remove(&umem_ida, umem->id); kfree(umem); return ERR_PTR(err); } diff --git a/net/xdp/xsk.c b/net/xdp/xsk.c index a03268454a27..a14e8864e4fa 100644 --- a/net/xdp/xsk.c +++ b/net/xdp/xsk.c @@ -27,14 +27,10 @@ #include "xsk_queue.h" #include "xdp_umem.h" +#include "xsk.h" #define TX_BATCH_SIZE 16 -static struct xdp_sock *xdp_sk(struct sock *sk) -{ - return (struct xdp_sock *)sk; -} - bool xsk_is_setup_for_bpf_map(struct xdp_sock *xs) { return READ_ONCE(xs->rx) && READ_ONCE(xs->umem) && @@ -350,6 +346,10 @@ static int xsk_release(struct socket *sock) net = sock_net(sk); + mutex_lock(&net->xdp.lock); + sk_del_node_init_rcu(sk); + mutex_unlock(&net->xdp.lock); + local_bh_disable(); sock_prot_inuse_add(net, sk->sk_prot, -1); local_bh_enable(); @@ -366,7 +366,6 @@ static int xsk_release(struct socket *sock) xskq_destroy(xs->rx); xskq_destroy(xs->tx); - xdp_put_umem(xs->umem); sock_orphan(sk); sock->sk = NULL; @@ -408,6 +407,10 @@ static int xsk_bind(struct socket *sock, struct sockaddr *addr, int addr_len) if (sxdp->sxdp_family != AF_XDP) return -EINVAL; + flags = sxdp->sxdp_flags; + if (flags & ~(XDP_SHARED_UMEM | XDP_COPY | XDP_ZEROCOPY)) + return -EINVAL; + mutex_lock(&xs->mutex); if (xs->dev) { err = -EBUSY; @@ -426,7 +429,6 @@ static int xsk_bind(struct socket *sock, struct sockaddr *addr, int addr_len) } qid = sxdp->sxdp_queue_id; - flags = sxdp->sxdp_flags; if (flags & XDP_SHARED_UMEM) { struct xdp_sock *umem_xs; @@ -669,6 +671,8 @@ static int xsk_mmap(struct file *file, struct socket *sock, if (!umem) return -EINVAL; + /* Matches the smp_wmb() in XDP_UMEM_REG */ + smp_rmb(); if (offset == XDP_UMEM_PGOFF_FILL_RING) q = READ_ONCE(umem->fq); else if (offset == XDP_UMEM_PGOFF_COMPLETION_RING) @@ -678,6 +682,8 @@ static int xsk_mmap(struct file *file, struct socket *sock, if (!q) return -EINVAL; + /* Matches the smp_wmb() in xsk_init_queue */ + smp_rmb(); qpg = virt_to_head_page(q->ring); if (size > (PAGE_SIZE << compound_order(qpg))) return -EINVAL; @@ -714,6 +720,18 @@ static const struct proto_ops xsk_proto_ops = { .sendpage = sock_no_sendpage, }; +static void xsk_destruct(struct sock *sk) +{ + struct xdp_sock *xs = xdp_sk(sk); + + if (!sock_flag(sk, SOCK_DEAD)) + return; + + xdp_put_umem(xs->umem); + + sk_refcnt_debug_dec(sk); +} + static int xsk_create(struct net *net, struct socket *sock, int protocol, int kern) { @@ -740,12 +758,19 @@ static int xsk_create(struct net *net, struct socket *sock, int protocol, sk->sk_family = PF_XDP; + sk->sk_destruct = xsk_destruct; + sk_refcnt_debug_inc(sk); + sock_set_flag(sk, SOCK_RCU_FREE); xs = xdp_sk(sk); mutex_init(&xs->mutex); spin_lock_init(&xs->tx_completion_lock); + mutex_lock(&net->xdp.lock); + sk_add_node_rcu(sk, &net->xdp.list); + mutex_unlock(&net->xdp.lock); + local_bh_disable(); sock_prot_inuse_add(net, &xsk_proto, 1); local_bh_enable(); @@ -759,6 +784,23 @@ static const struct net_proto_family xsk_family_ops = { .owner = THIS_MODULE, }; +static int __net_init xsk_net_init(struct net *net) +{ + mutex_init(&net->xdp.lock); + INIT_HLIST_HEAD(&net->xdp.list); + return 0; +} + +static void __net_exit xsk_net_exit(struct net *net) +{ + WARN_ON_ONCE(!hlist_empty(&net->xdp.list)); +} + +static struct pernet_operations xsk_net_ops = { + .init = xsk_net_init, + .exit = xsk_net_exit, +}; + static int __init xsk_init(void) { int err; @@ -771,8 +813,13 @@ static int __init xsk_init(void) if (err) goto out_proto; + err = register_pernet_subsys(&xsk_net_ops); + if (err) + goto out_sk; return 0; +out_sk: + sock_unregister(PF_XDP); out_proto: proto_unregister(&xsk_proto); out: diff --git a/net/xdp/xsk.h b/net/xdp/xsk.h new file mode 100644 index 000000000000..ba8120610426 --- /dev/null +++ b/net/xdp/xsk.h @@ -0,0 +1,12 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright(c) 2019 Intel Corporation. */ + +#ifndef XSK_H_ +#define XSK_H_ + +static inline struct xdp_sock *xdp_sk(struct sock *sk) +{ + return (struct xdp_sock *)sk; +} + +#endif /* XSK_H_ */ diff --git a/net/xdp/xsk_diag.c b/net/xdp/xsk_diag.c new file mode 100644 index 000000000000..d5e06c8e0cbf --- /dev/null +++ b/net/xdp/xsk_diag.c @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: GPL-2.0 +/* XDP sockets monitoring support + * + * Copyright(c) 2019 Intel Corporation. + * + * Author: Björn Töpel <bjorn.topel@intel.com> + */ + +#include <linux/module.h> +#include <net/xdp_sock.h> +#include <linux/xdp_diag.h> +#include <linux/sock_diag.h> + +#include "xsk_queue.h" +#include "xsk.h" + +static int xsk_diag_put_info(const struct xdp_sock *xs, struct sk_buff *nlskb) +{ + struct xdp_diag_info di = {}; + + di.ifindex = xs->dev ? xs->dev->ifindex : 0; + di.queue_id = xs->queue_id; + return nla_put(nlskb, XDP_DIAG_INFO, sizeof(di), &di); +} + +static int xsk_diag_put_ring(const struct xsk_queue *queue, int nl_type, + struct sk_buff *nlskb) +{ + struct xdp_diag_ring dr = {}; + + dr.entries = queue->nentries; + return nla_put(nlskb, nl_type, sizeof(dr), &dr); +} + +static int xsk_diag_put_rings_cfg(const struct xdp_sock *xs, + struct sk_buff *nlskb) +{ + int err = 0; + + if (xs->rx) + err = xsk_diag_put_ring(xs->rx, XDP_DIAG_RX_RING, nlskb); + if (!err && xs->tx) + err = xsk_diag_put_ring(xs->tx, XDP_DIAG_TX_RING, nlskb); + return err; +} + +static int xsk_diag_put_umem(const struct xdp_sock *xs, struct sk_buff *nlskb) +{ + struct xdp_umem *umem = xs->umem; + struct xdp_diag_umem du = {}; + int err; + + if (!umem) + return 0; + + du.id = umem->id; + du.size = umem->size; + du.num_pages = umem->npgs; + du.chunk_size = (__u32)(~umem->chunk_mask + 1); + du.headroom = umem->headroom; + du.ifindex = umem->dev ? umem->dev->ifindex : 0; + du.queue_id = umem->queue_id; + du.flags = 0; + if (umem->zc) + du.flags |= XDP_DU_F_ZEROCOPY; + du.refs = refcount_read(&umem->users); + + err = nla_put(nlskb, XDP_DIAG_UMEM, sizeof(du), &du); + + if (!err && umem->fq) + err = xsk_diag_put_ring(umem->fq, XDP_DIAG_UMEM_FILL_RING, nlskb); + if (!err && umem->cq) { + err = xsk_diag_put_ring(umem->cq, XDP_DIAG_UMEM_COMPLETION_RING, + nlskb); + } + return err; +} + +static int xsk_diag_fill(struct sock *sk, struct sk_buff *nlskb, + struct xdp_diag_req *req, + struct user_namespace *user_ns, + u32 portid, u32 seq, u32 flags, int sk_ino) +{ + struct xdp_sock *xs = xdp_sk(sk); + struct xdp_diag_msg *msg; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(nlskb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*msg), + flags); + if (!nlh) + return -EMSGSIZE; + + msg = nlmsg_data(nlh); + memset(msg, 0, sizeof(*msg)); + msg->xdiag_family = AF_XDP; + msg->xdiag_type = sk->sk_type; + msg->xdiag_ino = sk_ino; + sock_diag_save_cookie(sk, msg->xdiag_cookie); + + if ((req->xdiag_show & XDP_SHOW_INFO) && xsk_diag_put_info(xs, nlskb)) + goto out_nlmsg_trim; + + if ((req->xdiag_show & XDP_SHOW_INFO) && + nla_put_u32(nlskb, XDP_DIAG_UID, + from_kuid_munged(user_ns, sock_i_uid(sk)))) + goto out_nlmsg_trim; + + if ((req->xdiag_show & XDP_SHOW_RING_CFG) && + xsk_diag_put_rings_cfg(xs, nlskb)) + goto out_nlmsg_trim; + + if ((req->xdiag_show & XDP_SHOW_UMEM) && + xsk_diag_put_umem(xs, nlskb)) + goto out_nlmsg_trim; + + if ((req->xdiag_show & XDP_SHOW_MEMINFO) && + sock_diag_put_meminfo(sk, nlskb, XDP_DIAG_MEMINFO)) + goto out_nlmsg_trim; + + nlmsg_end(nlskb, nlh); + return 0; + +out_nlmsg_trim: + nlmsg_cancel(nlskb, nlh); + return -EMSGSIZE; +} + +static int xsk_diag_dump(struct sk_buff *nlskb, struct netlink_callback *cb) +{ + struct xdp_diag_req *req = nlmsg_data(cb->nlh); + struct net *net = sock_net(nlskb->sk); + int num = 0, s_num = cb->args[0]; + struct sock *sk; + + mutex_lock(&net->xdp.lock); + + sk_for_each(sk, &net->xdp.list) { + if (!net_eq(sock_net(sk), net)) + continue; + if (num++ < s_num) + continue; + + if (xsk_diag_fill(sk, nlskb, req, + sk_user_ns(NETLINK_CB(cb->skb).sk), + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + sock_i_ino(sk)) < 0) { + num--; + break; + } + } + + mutex_unlock(&net->xdp.lock); + cb->args[0] = num; + return nlskb->len; +} + +static int xsk_diag_handler_dump(struct sk_buff *nlskb, struct nlmsghdr *hdr) +{ + struct netlink_dump_control c = { .dump = xsk_diag_dump }; + int hdrlen = sizeof(struct xdp_diag_req); + struct net *net = sock_net(nlskb->sk); + + if (nlmsg_len(hdr) < hdrlen) + return -EINVAL; + + if (!(hdr->nlmsg_flags & NLM_F_DUMP)) + return -EOPNOTSUPP; + + return netlink_dump_start(net->diag_nlsk, nlskb, hdr, &c); +} + +static const struct sock_diag_handler xsk_diag_handler = { + .family = AF_XDP, + .dump = xsk_diag_handler_dump, +}; + +static int __init xsk_diag_init(void) +{ + return sock_diag_register(&xsk_diag_handler); +} + +static void __exit xsk_diag_exit(void) +{ + sock_diag_unregister(&xsk_diag_handler); +} + +module_init(xsk_diag_init); +module_exit(xsk_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, AF_XDP); diff --git a/net/xdp/xsk_queue.h b/net/xdp/xsk_queue.h index bcb5cbb40419..610c0bdc0c2b 100644 --- a/net/xdp/xsk_queue.h +++ b/net/xdp/xsk_queue.h @@ -174,8 +174,8 @@ static inline bool xskq_is_valid_desc(struct xsk_queue *q, struct xdp_desc *d) if (!xskq_is_valid_addr(q, d->addr)) return false; - if (((d->addr + d->len) & q->chunk_mask) != - (d->addr & q->chunk_mask)) { + if (((d->addr + d->len) & q->chunk_mask) != (d->addr & q->chunk_mask) || + d->options) { q->invalid_descs++; return false; } diff --git a/net/xfrm/xfrm_interface.c b/net/xfrm/xfrm_interface.c index 6be8c7df15bb..dbb3c1945b5c 100644 --- a/net/xfrm/xfrm_interface.c +++ b/net/xfrm/xfrm_interface.c @@ -76,10 +76,10 @@ static struct xfrm_if *xfrmi_decode_session(struct sk_buff *skb) int ifindex; struct xfrm_if *xi; - if (!skb->dev) + if (!secpath_exists(skb) || !skb->dev) return NULL; - xfrmn = net_generic(dev_net(skb->dev), xfrmi_net_id); + xfrmn = net_generic(xs_net(xfrm_input_state(skb)), xfrmi_net_id); ifindex = skb->dev->ifindex; for_each_xfrmi_rcu(xfrmn->xfrmi[0], xi) { diff --git a/net/xfrm/xfrm_policy.c b/net/xfrm/xfrm_policy.c index 934492bad8e0..8d1a898d0ba5 100644 --- a/net/xfrm/xfrm_policy.c +++ b/net/xfrm/xfrm_policy.c @@ -680,16 +680,6 @@ static void xfrm_hash_resize(struct work_struct *work) mutex_unlock(&hash_resize_mutex); } -static void xfrm_hash_reset_inexact_table(struct net *net) -{ - struct xfrm_pol_inexact_bin *b; - - lockdep_assert_held(&net->xfrm.xfrm_policy_lock); - - list_for_each_entry(b, &net->xfrm.inexact_bins, inexact_bins) - INIT_HLIST_HEAD(&b->hhead); -} - /* Make sure *pol can be inserted into fastbin. * Useful to check that later insert requests will be sucessful * (provided xfrm_policy_lock is held throughout). @@ -833,13 +823,13 @@ static void xfrm_policy_inexact_list_reinsert(struct net *net, u16 family) { unsigned int matched_s, matched_d; - struct hlist_node *newpos = NULL; struct xfrm_policy *policy, *p; matched_s = 0; matched_d = 0; list_for_each_entry_reverse(policy, &net->xfrm.policy_all, walk.all) { + struct hlist_node *newpos = NULL; bool matches_s, matches_d; if (!policy->bydst_reinsert) @@ -849,16 +839,19 @@ static void xfrm_policy_inexact_list_reinsert(struct net *net, policy->bydst_reinsert = false; hlist_for_each_entry(p, &n->hhead, bydst) { - if (policy->priority >= p->priority) + if (policy->priority > p->priority) + newpos = &p->bydst; + else if (policy->priority == p->priority && + policy->pos > p->pos) newpos = &p->bydst; else break; } if (newpos) - hlist_add_behind(&policy->bydst, newpos); + hlist_add_behind_rcu(&policy->bydst, newpos); else - hlist_add_head(&policy->bydst, &n->hhead); + hlist_add_head_rcu(&policy->bydst, &n->hhead); /* paranoia checks follow. * Check that the reinserted policy matches at least @@ -893,12 +886,13 @@ static void xfrm_policy_inexact_node_reinsert(struct net *net, struct rb_root *new, u16 family) { - struct rb_node **p, *parent = NULL; struct xfrm_pol_inexact_node *node; + struct rb_node **p, *parent; /* we should not have another subtree here */ WARN_ON_ONCE(!RB_EMPTY_ROOT(&n->root)); - +restart: + parent = NULL; p = &new->rb_node; while (*p) { u8 prefixlen; @@ -918,12 +912,11 @@ static void xfrm_policy_inexact_node_reinsert(struct net *net, } else { struct xfrm_policy *tmp; - hlist_for_each_entry(tmp, &node->hhead, bydst) - tmp->bydst_reinsert = true; - hlist_for_each_entry(tmp, &n->hhead, bydst) + hlist_for_each_entry(tmp, &n->hhead, bydst) { tmp->bydst_reinsert = true; + hlist_del_rcu(&tmp->bydst); + } - INIT_HLIST_HEAD(&node->hhead); xfrm_policy_inexact_list_reinsert(net, node, family); if (node->prefixlen == n->prefixlen) { @@ -935,8 +928,7 @@ static void xfrm_policy_inexact_node_reinsert(struct net *net, kfree_rcu(n, rcu); n = node; n->prefixlen = prefixlen; - *p = new->rb_node; - parent = NULL; + goto restart; } } @@ -965,12 +957,11 @@ static void xfrm_policy_inexact_node_merge(struct net *net, family); } - hlist_for_each_entry(tmp, &v->hhead, bydst) - tmp->bydst_reinsert = true; - hlist_for_each_entry(tmp, &n->hhead, bydst) + hlist_for_each_entry(tmp, &v->hhead, bydst) { tmp->bydst_reinsert = true; + hlist_del_rcu(&tmp->bydst); + } - INIT_HLIST_HEAD(&n->hhead); xfrm_policy_inexact_list_reinsert(net, n, family); } @@ -1235,6 +1226,7 @@ static void xfrm_hash_rebuild(struct work_struct *work) } while (read_seqretry(&net->xfrm.policy_hthresh.lock, seq)); spin_lock_bh(&net->xfrm.xfrm_policy_lock); + write_seqcount_begin(&xfrm_policy_hash_generation); /* make sure that we can insert the indirect policies again before * we start with destructive action. @@ -1278,10 +1270,14 @@ static void xfrm_hash_rebuild(struct work_struct *work) } /* reset the bydst and inexact table in all directions */ - xfrm_hash_reset_inexact_table(net); - for (dir = 0; dir < XFRM_POLICY_MAX; dir++) { - INIT_HLIST_HEAD(&net->xfrm.policy_inexact[dir]); + struct hlist_node *n; + + hlist_for_each_entry_safe(policy, n, + &net->xfrm.policy_inexact[dir], + bydst_inexact_list) + hlist_del_init(&policy->bydst_inexact_list); + hmask = net->xfrm.policy_bydst[dir].hmask; odst = net->xfrm.policy_bydst[dir].table; for (i = hmask; i >= 0; i--) @@ -1313,6 +1309,9 @@ static void xfrm_hash_rebuild(struct work_struct *work) newpos = NULL; chain = policy_hash_bysel(net, &policy->selector, policy->family, dir); + + hlist_del_rcu(&policy->bydst); + if (!chain) { void *p = xfrm_policy_inexact_insert(policy, dir, 0); @@ -1334,6 +1333,7 @@ static void xfrm_hash_rebuild(struct work_struct *work) out_unlock: __xfrm_policy_inexact_flush(net); + write_seqcount_end(&xfrm_policy_hash_generation); spin_unlock_bh(&net->xfrm.xfrm_policy_lock); mutex_unlock(&hash_resize_mutex); @@ -2600,7 +2600,10 @@ static struct dst_entry *xfrm_bundle_create(struct xfrm_policy *policy, dst_copy_metrics(dst1, dst); if (xfrm[i]->props.mode != XFRM_MODE_TRANSPORT) { - __u32 mark = xfrm_smark_get(fl->flowi_mark, xfrm[i]); + __u32 mark = 0; + + if (xfrm[i]->props.smark.v || xfrm[i]->props.smark.m) + mark = xfrm_smark_get(fl->flowi_mark, xfrm[i]); family = xfrm[i]->props.family; dst = xfrm_dst_lookup(xfrm[i], tos, fl->flowi_oif, @@ -3311,8 +3314,10 @@ int __xfrm_policy_check(struct sock *sk, int dir, struct sk_buff *skb, if (ifcb) { xi = ifcb->decode_session(skb); - if (xi) + if (xi) { if_id = xi->p.if_id; + net = xi->net; + } } rcu_read_unlock(); diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c index 23c92891758a..1bb971f46fc6 100644 --- a/net/xfrm/xfrm_state.c +++ b/net/xfrm/xfrm_state.c @@ -432,7 +432,7 @@ void xfrm_state_free(struct xfrm_state *x) } EXPORT_SYMBOL(xfrm_state_free); -static void xfrm_state_gc_destroy(struct xfrm_state *x) +static void ___xfrm_state_destroy(struct xfrm_state *x) { tasklet_hrtimer_cancel(&x->mtimer); del_timer_sync(&x->rtimer); @@ -474,7 +474,7 @@ static void xfrm_state_gc_task(struct work_struct *work) synchronize_rcu(); hlist_for_each_entry_safe(x, tmp, &gc_list, gclist) - xfrm_state_gc_destroy(x); + ___xfrm_state_destroy(x); } static enum hrtimer_restart xfrm_timer_handler(struct hrtimer *me) @@ -598,14 +598,19 @@ struct xfrm_state *xfrm_state_alloc(struct net *net) } EXPORT_SYMBOL(xfrm_state_alloc); -void __xfrm_state_destroy(struct xfrm_state *x) +void __xfrm_state_destroy(struct xfrm_state *x, bool sync) { WARN_ON(x->km.state != XFRM_STATE_DEAD); - spin_lock_bh(&xfrm_state_gc_lock); - hlist_add_head(&x->gclist, &xfrm_state_gc_list); - spin_unlock_bh(&xfrm_state_gc_lock); - schedule_work(&xfrm_state_gc_work); + if (sync) { + synchronize_rcu(); + ___xfrm_state_destroy(x); + } else { + spin_lock_bh(&xfrm_state_gc_lock); + hlist_add_head(&x->gclist, &xfrm_state_gc_list); + spin_unlock_bh(&xfrm_state_gc_lock); + schedule_work(&xfrm_state_gc_work); + } } EXPORT_SYMBOL(__xfrm_state_destroy); @@ -708,7 +713,7 @@ xfrm_dev_state_flush_secctx_check(struct net *net, struct net_device *dev, bool } #endif -int xfrm_state_flush(struct net *net, u8 proto, bool task_valid) +int xfrm_state_flush(struct net *net, u8 proto, bool task_valid, bool sync) { int i, err = 0, cnt = 0; @@ -730,7 +735,10 @@ restart: err = xfrm_state_delete(x); xfrm_audit_state_delete(x, err ? 0 : 1, task_valid); - xfrm_state_put(x); + if (sync) + xfrm_state_put_sync(x); + else + xfrm_state_put(x); if (!err) cnt++; @@ -2215,7 +2223,7 @@ void xfrm_state_delete_tunnel(struct xfrm_state *x) if (atomic_read(&t->tunnel_users) == 2) xfrm_state_delete(t); atomic_dec(&t->tunnel_users); - xfrm_state_put(t); + xfrm_state_put_sync(t); x->tunnel = NULL; } } @@ -2375,8 +2383,8 @@ void xfrm_state_fini(struct net *net) unsigned int sz; flush_work(&net->xfrm.state_hash_work); - xfrm_state_flush(net, IPSEC_PROTO_ANY, false); flush_work(&xfrm_state_gc_work); + xfrm_state_flush(net, IPSEC_PROTO_ANY, false, true); WARN_ON(!list_empty(&net->xfrm.state_all)); diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c index 277c1c46fe94..a131f9ff979e 100644 --- a/net/xfrm/xfrm_user.c +++ b/net/xfrm/xfrm_user.c @@ -1488,10 +1488,15 @@ static int validate_tmpl(int nr, struct xfrm_user_tmpl *ut, u16 family) if (!ut[i].family) ut[i].family = family; - if ((ut[i].mode == XFRM_MODE_TRANSPORT) && - (ut[i].family != prev_family)) - return -EINVAL; - + switch (ut[i].mode) { + case XFRM_MODE_TUNNEL: + case XFRM_MODE_BEET: + break; + default: + if (ut[i].family != prev_family) + return -EINVAL; + break; + } if (ut[i].mode >= XFRM_MODE_MAX) return -EINVAL; @@ -1927,7 +1932,7 @@ static int xfrm_flush_sa(struct sk_buff *skb, struct nlmsghdr *nlh, struct xfrm_usersa_flush *p = nlmsg_data(nlh); int err; - err = xfrm_state_flush(net, p->proto, true); + err = xfrm_state_flush(net, p->proto, true, false); if (err) { if (err == -ESRCH) /* empty table */ return 0; |