diff options
Diffstat (limited to 'net/vmw_vsock/hyperv_transport.c')
-rw-r--r-- | net/vmw_vsock/hyperv_transport.c | 33 |
1 files changed, 23 insertions, 10 deletions
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c index 14ed5a344cdf..5583df708b8c 100644 --- a/net/vmw_vsock/hyperv_transport.c +++ b/net/vmw_vsock/hyperv_transport.c @@ -310,11 +310,15 @@ static void hvs_close_connection(struct vmbus_channel *chan) struct sock *sk = get_per_channel_state(chan); struct vsock_sock *vsk = vsock_sk(sk); - sk->sk_state = SS_UNCONNECTED; + lock_sock(sk); + + sk->sk_state = TCP_CLOSE; sock_set_flag(sk, SOCK_DONE); vsk->peer_shutdown |= SEND_SHUTDOWN | RCV_SHUTDOWN; sk->sk_state_change(sk); + + release_sock(sk); } static void hvs_open_connection(struct vmbus_channel *chan) @@ -344,8 +348,9 @@ static void hvs_open_connection(struct vmbus_channel *chan) if (!sk) return; - if ((conn_from_host && sk->sk_state != VSOCK_SS_LISTEN) || - (!conn_from_host && sk->sk_state != SS_CONNECTING)) + lock_sock(sk); + if ((conn_from_host && sk->sk_state != TCP_LISTEN) || + (!conn_from_host && sk->sk_state != TCP_SYN_SENT)) goto out; if (conn_from_host) { @@ -357,7 +362,7 @@ static void hvs_open_connection(struct vmbus_channel *chan) if (!new) goto out; - new->sk_state = SS_CONNECTING; + new->sk_state = TCP_SYN_SENT; vnew = vsock_sk(new); hvs_new = vnew->trans; hvs_new->chan = chan; @@ -384,7 +389,7 @@ static void hvs_open_connection(struct vmbus_channel *chan) vmbus_set_chn_rescind_callback(chan, hvs_close_connection); if (conn_from_host) { - new->sk_state = SS_CONNECTED; + new->sk_state = TCP_ESTABLISHED; sk->sk_ack_backlog++; hvs_addr_init(&vnew->local_addr, if_type); @@ -395,11 +400,9 @@ static void hvs_open_connection(struct vmbus_channel *chan) vsock_insert_connected(vnew); - lock_sock(sk); vsock_enqueue_accept(sk, new); - release_sock(sk); } else { - sk->sk_state = SS_CONNECTED; + sk->sk_state = TCP_ESTABLISHED; sk->sk_socket->state = SS_CONNECTED; vsock_insert_connected(vsock_sk(sk)); @@ -410,6 +413,8 @@ static void hvs_open_connection(struct vmbus_channel *chan) out: /* Release refcnt obtained when we called vsock_find_bound_socket() */ sock_put(sk); + + release_sock(sk); } static u32 hvs_get_local_cid(void) @@ -476,13 +481,21 @@ out: static void hvs_release(struct vsock_sock *vsk) { + struct sock *sk = sk_vsock(vsk); struct hvsock *hvs = vsk->trans; - struct vmbus_channel *chan = hvs->chan; + struct vmbus_channel *chan; + + lock_sock(sk); + + sk->sk_state = SS_DISCONNECTING; + vsock_remove_sock(vsk); + + release_sock(sk); + chan = hvs->chan; if (chan) hvs_shutdown(vsk, RCV_SHUTDOWN | SEND_SHUTDOWN); - vsock_remove_sock(vsk); } static void hvs_destruct(struct vsock_sock *vsk) |