summaryrefslogtreecommitdiff
path: root/net/mptcp/protocol.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/mptcp/protocol.c')
-rw-r--r--net/mptcp/protocol.c16
1 files changed, 16 insertions, 0 deletions
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index e08a25eabcd5..3f66b6a3bb28 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -201,6 +201,7 @@ static void mptcp_close(struct sock *sk, long timeout)
struct mptcp_subflow_context *subflow, *tmp;
struct mptcp_sock *msk = mptcp_sk(sk);
+ mptcp_token_destroy(msk->token);
inet_sk_state_store(sk, TCP_CLOSE);
lock_sock(sk);
@@ -281,8 +282,10 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
msk = mptcp_sk(new_mptcp_sock);
msk->remote_key = subflow->remote_key;
msk->local_key = subflow->local_key;
+ msk->token = subflow->token;
msk->subflow = NULL;
+ mptcp_token_update_accept(newsk, new_mptcp_sock);
newsk = new_mptcp_sock;
mptcp_copy_inaddrs(newsk, ssk);
list_add(&subflow->node, &msk->conn_list);
@@ -299,6 +302,10 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
return newsk;
}
+static void mptcp_destroy(struct sock *sk)
+{
+}
+
static int mptcp_get_port(struct sock *sk, unsigned short snum)
{
struct mptcp_sock *msk = mptcp_sk(sk);
@@ -331,6 +338,7 @@ void mptcp_finish_connect(struct sock *ssk)
*/
WRITE_ONCE(msk->remote_key, subflow->remote_key);
WRITE_ONCE(msk->local_key, subflow->local_key);
+ WRITE_ONCE(msk->token, subflow->token);
}
static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
@@ -349,6 +357,7 @@ static struct proto mptcp_prot = {
.close = mptcp_close,
.accept = mptcp_accept,
.shutdown = tcp_shutdown,
+ .destroy = mptcp_destroy,
.sendmsg = mptcp_sendmsg,
.recvmsg = mptcp_recvmsg,
.hash = inet_hash,
@@ -568,6 +577,12 @@ void __init mptcp_init(void)
static struct proto_ops mptcp_v6_stream_ops;
static struct proto mptcp_v6_prot;
+static void mptcp_v6_destroy(struct sock *sk)
+{
+ mptcp_destroy(sk);
+ inet6_destroy_sock(sk);
+}
+
static struct inet_protosw mptcp_v6_protosw = {
.type = SOCK_STREAM,
.protocol = IPPROTO_MPTCP,
@@ -583,6 +598,7 @@ int mptcpv6_init(void)
mptcp_v6_prot = mptcp_prot;
strcpy(mptcp_v6_prot.name, "MPTCPv6");
mptcp_v6_prot.slab = NULL;
+ mptcp_v6_prot.destroy = mptcp_v6_destroy;
mptcp_v6_prot.obj_size = sizeof(struct mptcp_sock) +
sizeof(struct ipv6_pinfo);