diff options
-rw-r--r-- | net/mptcp/pm_kernel.c | 82 |
1 files changed, 44 insertions, 38 deletions
diff --git a/net/mptcp/pm_kernel.c b/net/mptcp/pm_kernel.c index 117f842fe18e..55dbf89d19b8 100644 --- a/net/mptcp/pm_kernel.c +++ b/net/mptcp/pm_kernel.c @@ -268,6 +268,46 @@ __lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info) return NULL; } +static u8 mptcp_endp_get_local_id(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr) +{ + return msk->mpc_endpoint_id == addr->id ? 0 : addr->id; +} + +/* Set mpc_endpoint_id, and send MP_PRIO for ID0 if needed */ +static void mptcp_mpc_endpoint_setup(struct mptcp_sock *msk) +{ + struct mptcp_subflow_context *subflow; + struct mptcp_pm_addr_entry *entry; + struct mptcp_addr_info mpc_addr; + struct pm_nl_pernet *pernet; + bool backup = false; + + /* do lazy endpoint usage accounting for the MPC subflows */ + if (likely(msk->pm.status & BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED)) || + !msk->first) + return; + + subflow = mptcp_subflow_ctx(msk->first); + pernet = pm_nl_get_pernet_from_msk(msk); + + mptcp_local_address((struct sock_common *)msk->first, &mpc_addr); + rcu_read_lock(); + entry = __lookup_addr(pernet, &mpc_addr); + if (entry) { + __clear_bit(entry->addr.id, msk->pm.id_avail_bitmap); + msk->mpc_endpoint_id = entry->addr.id; + backup = !!(entry->flags & MPTCP_PM_ADDR_FLAG_BACKUP); + } + rcu_read_unlock(); + + /* Send MP_PRIO */ + if (backup) + mptcp_pm_send_ack(msk, subflow, true, backup); + + msk->pm.status |= BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED); +} + static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) { u8 limit_extra_subflows = mptcp_pm_get_limit_extra_subflows(msk); @@ -278,28 +318,7 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) bool signal_and_subflow = false; struct mptcp_pm_local local; - /* do lazy endpoint usage accounting for the MPC subflows */ - if (unlikely(!(msk->pm.status & BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED))) && msk->first) { - struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(msk->first); - struct mptcp_pm_addr_entry *entry; - struct mptcp_addr_info mpc_addr; - bool backup = false; - - mptcp_local_address((struct sock_common *)msk->first, &mpc_addr); - rcu_read_lock(); - entry = __lookup_addr(pernet, &mpc_addr); - if (entry) { - __clear_bit(entry->addr.id, msk->pm.id_avail_bitmap); - msk->mpc_endpoint_id = entry->addr.id; - backup = !!(entry->flags & MPTCP_PM_ADDR_FLAG_BACKUP); - } - rcu_read_unlock(); - - if (backup) - mptcp_pm_send_ack(msk, subflow, true, backup); - - msk->pm.status |= BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED); - } + mptcp_mpc_endpoint_setup(msk); pr_debug("local %d:%d signal %d:%d subflows %d:%d\n", msk->pm.local_addr_used, endp_subflow_max, @@ -396,12 +415,9 @@ fill_local_addresses_vec_fullmesh(struct mptcp_sock *msk, struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk); struct sock *sk = (struct sock *)msk; struct mptcp_pm_addr_entry *entry; - struct mptcp_addr_info mpc_addr; struct mptcp_pm_local *local; int i = 0; - mptcp_local_address((struct sock_common *)msk, &mpc_addr); - rcu_read_lock(); list_for_each_entry_rcu(entry, &pernet->endp_list, list) { bool is_id0; @@ -417,8 +433,7 @@ fill_local_addresses_vec_fullmesh(struct mptcp_sock *msk, local->flags = entry->flags; local->ifindex = entry->ifindex; - is_id0 = mptcp_addresses_equal(&local->addr, &mpc_addr, - local->addr.port); + is_id0 = local->addr.id == msk->mpc_endpoint_id; if (c_flag_case && (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)) { @@ -452,12 +467,9 @@ fill_local_addresses_vec_c_flag(struct mptcp_sock *msk, struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk); u8 endp_subflow_max = mptcp_pm_get_endp_subflow_max(msk); struct sock *sk = (struct sock *)msk; - struct mptcp_addr_info mpc_addr; struct mptcp_pm_local *local; int i = 0; - mptcp_local_address((struct sock_common *)msk, &mpc_addr); - while (msk->pm.local_addr_used < endp_subflow_max) { local = &locals[i]; @@ -469,8 +481,7 @@ fill_local_addresses_vec_c_flag(struct mptcp_sock *msk, if (!mptcp_pm_addr_families_match(sk, &local->addr, remote)) continue; - if (mptcp_addresses_equal(&local->addr, &mpc_addr, - local->addr.port)) + if (local->addr.id == msk->mpc_endpoint_id) continue; msk->pm.local_addr_used++; @@ -548,6 +559,7 @@ static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) remote = msk->pm.remote; mptcp_pm_announce_addr(msk, &remote, true); mptcp_pm_addr_send_ack(msk); + mptcp_mpc_endpoint_setup(msk); if (lookup_subflow_by_daddr(&msk->conn_list, &remote)) return; @@ -935,12 +947,6 @@ out_free: return ret; } -static u8 mptcp_endp_get_local_id(struct mptcp_sock *msk, - const struct mptcp_addr_info *addr) -{ - return msk->mpc_endpoint_id == addr->id ? 0 : addr->id; -} - static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk, const struct mptcp_addr_info *addr, bool force) |