diff options
-rw-r--r-- | drivers/vhost/Kconfig | 4 | ||||
-rw-r--r-- | drivers/vhost/Kconfig.vsock | 7 | ||||
-rw-r--r-- | drivers/vhost/Makefile | 4 | ||||
-rw-r--r-- | drivers/vhost/vsock.c | 631 | ||||
-rw-r--r-- | drivers/vhost/vsock.h | 4 | ||||
-rw-r--r-- | include/linux/virtio_vsock.h | 209 | ||||
-rw-r--r-- | include/net/af_vsock.h | 2 | ||||
-rw-r--r-- | include/uapi/linux/virtio_ids.h | 1 | ||||
-rw-r--r-- | include/uapi/linux/virtio_vsock.h | 89 | ||||
-rw-r--r-- | net/vmw_vsock/Kconfig | 18 | ||||
-rw-r--r-- | net/vmw_vsock/Makefile | 2 | ||||
-rw-r--r-- | net/vmw_vsock/af_vsock.c | 70 | ||||
-rw-r--r-- | net/vmw_vsock/virtio_transport.c | 466 | ||||
-rw-r--r-- | net/vmw_vsock/virtio_transport_common.c | 1272 |
14 files changed, 2779 insertions, 0 deletions
diff --git a/drivers/vhost/Kconfig b/drivers/vhost/Kconfig index 533eaf04f12f..81449bfc8d3b 100644 --- a/drivers/vhost/Kconfig +++ b/drivers/vhost/Kconfig @@ -47,3 +47,7 @@ config VHOST_CROSS_ENDIAN_LEGACY adds some overhead, it is disabled by default. If unsure, say "N". + +if STAGING +source "drivers/vhost/Kconfig.vsock" +endif diff --git a/drivers/vhost/Kconfig.vsock b/drivers/vhost/Kconfig.vsock new file mode 100644 index 000000000000..3491865d3eb9 --- /dev/null +++ b/drivers/vhost/Kconfig.vsock @@ -0,0 +1,7 @@ +config VHOST_VSOCK + tristate "vhost virtio-vsock driver" + depends on VSOCKETS && EVENTFD + select VIRTIO_VSOCKETS_COMMON + default n + ---help--- + Say M here to enable the vhost-vsock for virtio-vsock guests diff --git a/drivers/vhost/Makefile b/drivers/vhost/Makefile index e0441c34db1c..6b012b986b57 100644 --- a/drivers/vhost/Makefile +++ b/drivers/vhost/Makefile @@ -4,5 +4,9 @@ vhost_net-y := net.o obj-$(CONFIG_VHOST_SCSI) += vhost_scsi.o vhost_scsi-y := scsi.o +obj-$(CONFIG_VHOST_VSOCK) += vhost_vsock.o +vhost_vsock-y := vsock.o + obj-$(CONFIG_VHOST_RING) += vringh.o + obj-$(CONFIG_VHOST) += vhost.o diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c new file mode 100644 index 000000000000..65b1cf8a06cb --- /dev/null +++ b/drivers/vhost/vsock.c @@ -0,0 +1,631 @@ +/* + * vhost transport for vsock + * + * Copyright (C) 2013-2015 Red Hat, Inc. + * Author: Asias He <asias@redhat.com> + * Stefan Hajnoczi <stefanha@redhat.com> + * + * This work is licensed under the terms of the GNU GPL, version 2. + */ +#include <linux/miscdevice.h> +#include <linux/module.h> +#include <linux/mutex.h> +#include <net/sock.h> +#include <linux/virtio_vsock.h> +#include <linux/vhost.h> + +#include <net/af_vsock.h> +#include "vhost.h" +#include "vsock.h" + +#define VHOST_VSOCK_DEFAULT_HOST_CID 2 + +static int vhost_transport_socket_init(struct vsock_sock *vsk, + struct vsock_sock *psk); + +enum { + VHOST_VSOCK_FEATURES = VHOST_FEATURES, +}; + +/* Used to track all the vhost_vsock instances on the system. */ +static LIST_HEAD(vhost_vsock_list); +static DEFINE_MUTEX(vhost_vsock_mutex); + +struct vhost_vsock_virtqueue { + struct vhost_virtqueue vq; +}; + +struct vhost_vsock { + /* Vhost device */ + struct vhost_dev dev; + /* Vhost vsock virtqueue*/ + struct vhost_vsock_virtqueue vqs[VSOCK_VQ_MAX]; + /* Link to global vhost_vsock_list*/ + struct list_head list; + /* Head for pkt from host to guest */ + struct list_head send_pkt_list; + /* Work item to send pkt */ + struct vhost_work send_pkt_work; + /* Wait queue for send pkt */ + wait_queue_head_t queue_wait; + /* Used for global tx buf limitation */ + u32 total_tx_buf; + /* Guest contex id this vhost_vsock instance handles */ + u32 guest_cid; +}; + +static u32 vhost_transport_get_local_cid(void) +{ + u32 cid = VHOST_VSOCK_DEFAULT_HOST_CID; + return cid; +} + +static struct vhost_vsock *vhost_vsock_get(u32 guest_cid) +{ + struct vhost_vsock *vsock; + + mutex_lock(&vhost_vsock_mutex); + list_for_each_entry(vsock, &vhost_vsock_list, list) { + if (vsock->guest_cid == guest_cid) { + mutex_unlock(&vhost_vsock_mutex); + return vsock; + } + } + mutex_unlock(&vhost_vsock_mutex); + + return NULL; +} + +static void +vhost_transport_do_send_pkt(struct vhost_vsock *vsock, + struct vhost_virtqueue *vq) +{ + bool added = false; + + mutex_lock(&vq->mutex); + vhost_disable_notify(&vsock->dev, vq); + for (;;) { + struct virtio_vsock_pkt *pkt; + struct iov_iter iov_iter; + unsigned out, in; + struct sock *sk; + size_t nbytes; + size_t len; + int head; + + if (list_empty(&vsock->send_pkt_list)) { + vhost_enable_notify(&vsock->dev, vq); + break; + } + + head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), + &out, &in, NULL, NULL); + pr_debug("%s: head = %d\n", __func__, head); + if (head < 0) + break; + + if (head == vq->num) { + if (unlikely(vhost_enable_notify(&vsock->dev, vq))) { + vhost_disable_notify(&vsock->dev, vq); + continue; + } + break; + } + + pkt = list_first_entry(&vsock->send_pkt_list, + struct virtio_vsock_pkt, list); + list_del_init(&pkt->list); + + if (out) { + virtio_transport_free_pkt(pkt); + vq_err(vq, "Expected 0 output buffers, got %u\n", out); + break; + } + + len = iov_length(&vq->iov[out], in); + iov_iter_init(&iov_iter, READ, &vq->iov[out], in, len); + + nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter); + if (nbytes != sizeof(pkt->hdr)) { + virtio_transport_free_pkt(pkt); + vq_err(vq, "Faulted on copying pkt hdr\n"); + break; + } + + nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter); + if (nbytes != pkt->len) { + virtio_transport_free_pkt(pkt); + vq_err(vq, "Faulted on copying pkt buf\n"); + break; + } + + vhost_add_used(vq, head, pkt->len); /* TODO should this be sizeof(pkt->hdr) + pkt->len? */ + added = true; + + virtio_transport_dec_tx_pkt(pkt); + vsock->total_tx_buf -= pkt->len; + + sk = sk_vsock(pkt->trans->vsk); + /* Release refcnt taken in vhost_transport_send_pkt */ + sock_put(sk); + + virtio_transport_free_pkt(pkt); + } + if (added) + vhost_signal(&vsock->dev, vq); + mutex_unlock(&vq->mutex); + + if (added) + wake_up(&vsock->queue_wait); +} + +static void vhost_transport_send_pkt_work(struct vhost_work *work) +{ + struct vhost_virtqueue *vq; + struct vhost_vsock *vsock; + + vsock = container_of(work, struct vhost_vsock, send_pkt_work); + vq = &vsock->vqs[VSOCK_VQ_RX].vq; + + vhost_transport_do_send_pkt(vsock, vq); +} + +static int +vhost_transport_send_pkt(struct vsock_sock *vsk, + struct virtio_vsock_pkt_info *info) +{ + u32 src_cid, src_port, dst_cid, dst_port; + struct virtio_transport *trans; + struct virtio_vsock_pkt *pkt; + struct vhost_virtqueue *vq; + struct vhost_vsock *vsock; + u32 pkt_len = info->pkt_len; + DEFINE_WAIT(wait); + + src_cid = vhost_transport_get_local_cid(); + src_port = vsk->local_addr.svm_port; + if (!info->remote_cid) { + dst_cid = vsk->remote_addr.svm_cid; + dst_port = vsk->remote_addr.svm_port; + } else { + dst_cid = info->remote_cid; + dst_port = info->remote_port; + } + + /* Find the vhost_vsock according to guest context id */ + vsock = vhost_vsock_get(dst_cid); + if (!vsock) + return -ENODEV; + + trans = vsk->trans; + vq = &vsock->vqs[VSOCK_VQ_RX].vq; + + /* we can send less than pkt_len bytes */ + if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) + pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; + + /* virtio_transport_get_credit might return less than pkt_len credit */ + pkt_len = virtio_transport_get_credit(trans, pkt_len); + + /* Do not send zero length OP_RW pkt*/ + if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) + return pkt_len; + + /* Respect global tx buf limitation */ + mutex_lock(&vq->mutex); + while (pkt_len + vsock->total_tx_buf > VIRTIO_VSOCK_MAX_TX_BUF_SIZE) { + prepare_to_wait_exclusive(&vsock->queue_wait, &wait, + TASK_UNINTERRUPTIBLE); + mutex_unlock(&vq->mutex); + schedule(); + mutex_lock(&vq->mutex); + finish_wait(&vsock->queue_wait, &wait); + } + vsock->total_tx_buf += pkt_len; + mutex_unlock(&vq->mutex); + + pkt = virtio_transport_alloc_pkt(vsk, info, pkt_len, + src_cid, src_port, + dst_cid, dst_port); + if (!pkt) { + mutex_lock(&vq->mutex); + vsock->total_tx_buf -= pkt_len; + mutex_unlock(&vq->mutex); + virtio_transport_put_credit(trans, pkt_len); + return -ENOMEM; + } + + pr_debug("%s:info->pkt_len= %d\n", __func__, pkt_len); + /* Released in vhost_transport_do_send_pkt */ + sock_hold(&trans->vsk->sk); + virtio_transport_inc_tx_pkt(pkt); + + /* Queue it up in vhost work */ + mutex_lock(&vq->mutex); + list_add_tail(&pkt->list, &vsock->send_pkt_list); + vhost_work_queue(&vsock->dev, &vsock->send_pkt_work); + mutex_unlock(&vq->mutex); + + return pkt_len; +} + +static struct virtio_transport_pkt_ops vhost_ops = { + .send_pkt = vhost_transport_send_pkt, +}; + +static struct virtio_vsock_pkt * +vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq, + unsigned int out, unsigned int in) +{ + struct virtio_vsock_pkt *pkt; + struct iov_iter iov_iter; + size_t nbytes; + size_t len; + + if (in != 0) { + vq_err(vq, "Expected 0 input buffers, got %u\n", in); + return NULL; + } + + pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); + if (!pkt) + return NULL; + + len = iov_length(vq->iov, out); + iov_iter_init(&iov_iter, WRITE, vq->iov, out, len); + + nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter); + if (nbytes != sizeof(pkt->hdr)) { + vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n", + sizeof(pkt->hdr), nbytes); + kfree(pkt); + return NULL; + } + + if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_DGRAM) + pkt->len = le32_to_cpu(pkt->hdr.len) & 0XFFFF; + else if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM) + pkt->len = le32_to_cpu(pkt->hdr.len); + + /* No payload */ + if (!pkt->len) + return pkt; + + /* The pkt is too big */ + if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) { + kfree(pkt); + return NULL; + } + + pkt->buf = kmalloc(pkt->len, GFP_KERNEL); + if (!pkt->buf) { + kfree(pkt); + return NULL; + } + + nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter); + if (nbytes != pkt->len) { + vq_err(vq, "Expected %u byte payload, got %zu bytes\n", + pkt->len, nbytes); + virtio_transport_free_pkt(pkt); + return NULL; + } + + return pkt; +} + +static void vhost_vsock_handle_ctl_kick(struct vhost_work *work) +{ + struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue, + poll.work); + struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock, + dev); + + pr_debug("%s vq=%p, vsock=%p\n", __func__, vq, vsock); +} + +static void vhost_vsock_handle_tx_kick(struct vhost_work *work) +{ + struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue, + poll.work); + struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock, + dev); + struct virtio_vsock_pkt *pkt; + int head; + unsigned int out, in; + bool added = false; + u32 len; + + mutex_lock(&vq->mutex); + vhost_disable_notify(&vsock->dev, vq); + for (;;) { + head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov), + &out, &in, NULL, NULL); + if (head < 0) + break; + + if (head == vq->num) { + if (unlikely(vhost_enable_notify(&vsock->dev, vq))) { + vhost_disable_notify(&vsock->dev, vq); + continue; + } + break; + } + + pkt = vhost_vsock_alloc_pkt(vq, out, in); + if (!pkt) { + vq_err(vq, "Faulted on pkt\n"); + continue; + } + + len = pkt->len; + + /* Only accept correctly addressed packets */ + if (le32_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid && + le32_to_cpu(pkt->hdr.dst_cid) == vhost_transport_get_local_cid()) + virtio_transport_recv_pkt(pkt); + else + virtio_transport_free_pkt(pkt); + + vhost_add_used(vq, head, len); + added = true; + } + if (added) + vhost_signal(&vsock->dev, vq); + mutex_unlock(&vq->mutex); +} + +static void vhost_vsock_handle_rx_kick(struct vhost_work *work) +{ + struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue, + poll.work); + struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock, + dev); + + vhost_transport_do_send_pkt(vsock, vq); +} + +static int vhost_vsock_dev_open(struct inode *inode, struct file *file) +{ + struct vhost_virtqueue **vqs; + struct vhost_vsock *vsock; + int ret; + + vsock = kzalloc(sizeof(*vsock), GFP_KERNEL); + if (!vsock) + return -ENOMEM; + + pr_debug("%s:vsock=%p\n", __func__, vsock); + + vqs = kmalloc(VSOCK_VQ_MAX * sizeof(*vqs), GFP_KERNEL); + if (!vqs) { + ret = -ENOMEM; + goto out; + } + + vqs[VSOCK_VQ_CTRL] = &vsock->vqs[VSOCK_VQ_CTRL].vq; + vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX].vq; + vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX].vq; + vsock->vqs[VSOCK_VQ_CTRL].vq.handle_kick = vhost_vsock_handle_ctl_kick; + vsock->vqs[VSOCK_VQ_TX].vq.handle_kick = vhost_vsock_handle_tx_kick; + vsock->vqs[VSOCK_VQ_RX].vq.handle_kick = vhost_vsock_handle_rx_kick; + + vhost_dev_init(&vsock->dev, vqs, VSOCK_VQ_MAX); + + file->private_data = vsock; + init_waitqueue_head(&vsock->queue_wait); + INIT_LIST_HEAD(&vsock->send_pkt_list); + vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work); + + mutex_lock(&vhost_vsock_mutex); + list_add_tail(&vsock->list, &vhost_vsock_list); + mutex_unlock(&vhost_vsock_mutex); + return 0; + +out: + kfree(vsock); + return ret; +} + +static void vhost_vsock_flush(struct vhost_vsock *vsock) +{ + int i; + + for (i = 0; i < VSOCK_VQ_MAX; i++) + vhost_poll_flush(&vsock->vqs[i].vq.poll); + vhost_work_flush(&vsock->dev, &vsock->send_pkt_work); +} + +static int vhost_vsock_dev_release(struct inode *inode, struct file *file) +{ + struct vhost_vsock *vsock = file->private_data; + + mutex_lock(&vhost_vsock_mutex); + list_del(&vsock->list); + mutex_unlock(&vhost_vsock_mutex); + + vhost_dev_stop(&vsock->dev); + vhost_vsock_flush(vsock); + vhost_dev_cleanup(&vsock->dev, false); + kfree(vsock->dev.vqs); + kfree(vsock); + return 0; +} + +static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u32 guest_cid) +{ + struct vhost_vsock *other; + + /* Refuse reserved CIDs */ + if (guest_cid <= VMADDR_CID_HOST) { + return -EINVAL; + } + + /* Refuse if CID is already in use */ + other = vhost_vsock_get(guest_cid); + if (other && other != vsock) { + return -EADDRINUSE; + } + + mutex_lock(&vhost_vsock_mutex); + vsock->guest_cid = guest_cid; + pr_debug("%s:guest_cid=%d\n", __func__, guest_cid); + mutex_unlock(&vhost_vsock_mutex); + + return 0; +} + +static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features) +{ + struct vhost_virtqueue *vq; + int i; + + if (features & ~VHOST_VSOCK_FEATURES) + return -EOPNOTSUPP; + + mutex_lock(&vsock->dev.mutex); + if ((features & (1 << VHOST_F_LOG_ALL)) && + !vhost_log_access_ok(&vsock->dev)) { + mutex_unlock(&vsock->dev.mutex); + return -EFAULT; + } + + for (i = 0; i < VSOCK_VQ_MAX; i++) { + vq = &vsock->vqs[i].vq; + mutex_lock(&vq->mutex); + vq->acked_features = features; + mutex_unlock(&vq->mutex); + } + mutex_unlock(&vsock->dev.mutex); + return 0; +} + +static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl, + unsigned long arg) +{ + struct vhost_vsock *vsock = f->private_data; + void __user *argp = (void __user *)arg; + u64 __user *featurep = argp; + u32 __user *cidp = argp; + u32 guest_cid; + u64 features; + int r; + + switch (ioctl) { + case VHOST_VSOCK_SET_GUEST_CID: + if (get_user(guest_cid, cidp)) + return -EFAULT; + return vhost_vsock_set_cid(vsock, guest_cid); + case VHOST_GET_FEATURES: + features = VHOST_VSOCK_FEATURES; + if (copy_to_user(featurep, &features, sizeof(features))) + return -EFAULT; + return 0; + case VHOST_SET_FEATURES: + if (copy_from_user(&features, featurep, sizeof(features))) + return -EFAULT; + return vhost_vsock_set_features(vsock, features); + default: + mutex_lock(&vsock->dev.mutex); + r = vhost_dev_ioctl(&vsock->dev, ioctl, argp); + if (r == -ENOIOCTLCMD) + r = vhost_vring_ioctl(&vsock->dev, ioctl, argp); + else + vhost_vsock_flush(vsock); + mutex_unlock(&vsock->dev.mutex); + return r; + } +} + +static const struct file_operations vhost_vsock_fops = { + .owner = THIS_MODULE, + .open = vhost_vsock_dev_open, + .release = vhost_vsock_dev_release, + .llseek = noop_llseek, + .unlocked_ioctl = vhost_vsock_dev_ioctl, +}; + +static struct miscdevice vhost_vsock_misc = { + .minor = MISC_DYNAMIC_MINOR, + .name = "vhost-vsock", + .fops = &vhost_vsock_fops, +}; + +static int +vhost_transport_socket_init(struct vsock_sock *vsk, struct vsock_sock *psk) +{ + struct virtio_transport *trans; + int ret; + + ret = virtio_transport_do_socket_init(vsk, psk); + if (ret) + return ret; + + trans = vsk->trans; + trans->ops = &vhost_ops; + + return ret; +} + +static struct vsock_transport vhost_transport = { + .get_local_cid = vhost_transport_get_local_cid, + + .init = vhost_transport_socket_init, + .destruct = virtio_transport_destruct, + .release = virtio_transport_release, + .connect = virtio_transport_connect, + .shutdown = virtio_transport_shutdown, + + .dgram_enqueue = virtio_transport_dgram_enqueue, + .dgram_dequeue = virtio_transport_dgram_dequeue, + .dgram_bind = virtio_transport_dgram_bind, + .dgram_allow = virtio_transport_dgram_allow, + + .stream_enqueue = virtio_transport_stream_enqueue, + .stream_dequeue = virtio_transport_stream_dequeue, + .stream_has_data = virtio_transport_stream_has_data, + .stream_has_space = virtio_transport_stream_has_space, + .stream_rcvhiwat = virtio_transport_stream_rcvhiwat, + .stream_is_active = virtio_transport_stream_is_active, + .stream_allow = virtio_transport_stream_allow, + + .notify_poll_in = virtio_transport_notify_poll_in, + .notify_poll_out = virtio_transport_notify_poll_out, + .notify_recv_init = virtio_transport_notify_recv_init, + .notify_recv_pre_block = virtio_transport_notify_recv_pre_block, + .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue, + .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue, + .notify_send_init = virtio_transport_notify_send_init, + .notify_send_pre_block = virtio_transport_notify_send_pre_block, + .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, + .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, + + .set_buffer_size = virtio_transport_set_buffer_size, + .set_min_buffer_size = virtio_transport_set_min_buffer_size, + .set_max_buffer_size = virtio_transport_set_max_buffer_size, + .get_buffer_size = virtio_transport_get_buffer_size, + .get_min_buffer_size = virtio_transport_get_min_buffer_size, + .get_max_buffer_size = virtio_transport_get_max_buffer_size, +}; + +static int __init vhost_vsock_init(void) +{ + int ret; + + ret = vsock_core_init(&vhost_transport); + if (ret < 0) + return ret; + return misc_register(&vhost_vsock_misc); +}; + +static void __exit vhost_vsock_exit(void) +{ + misc_deregister(&vhost_vsock_misc); + vsock_core_exit(); +}; + +module_init(vhost_vsock_init); +module_exit(vhost_vsock_exit); +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Asias He"); +MODULE_DESCRIPTION("vhost transport for vsock "); diff --git a/drivers/vhost/vsock.h b/drivers/vhost/vsock.h new file mode 100644 index 000000000000..0ddb107b86ca --- /dev/null +++ b/drivers/vhost/vsock.h @@ -0,0 +1,4 @@ +#ifndef VHOST_VSOCK_H +#define VHOST_VSOCK_H +#define VHOST_VSOCK_SET_GUEST_CID _IOW(VHOST_VIRTIO, 0x60, __u32) +#endif diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h new file mode 100644 index 000000000000..a5f3ecc038f7 --- /dev/null +++ b/include/linux/virtio_vsock.h @@ -0,0 +1,209 @@ +/* + * This header, excluding the #ifdef __KERNEL__ part, is BSD licensed so + * anyone can use the definitions to implement compatible drivers/servers: + * + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of IBM nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS'' + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL IBM OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + * + * Copyright (C) Red Hat, Inc., 2013-2015 + * Copyright (C) Asias He <asias@redhat.com>, 2013 + * Copyright (C) Stefan Hajnoczi <stefanha@redhat.com>, 2015 + */ + +#ifndef _LINUX_VIRTIO_VSOCK_H +#define _LINUX_VIRTIO_VSOCK_H + +#include <uapi/linux/virtio_vsock.h> +#include <linux/socket.h> +#include <net/sock.h> + +#define VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE 128 +#define VIRTIO_VSOCK_DEFAULT_BUF_SIZE (1024 * 256) +#define VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE (1024 * 256) +#define VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE (1024 * 4) +#define VIRTIO_VSOCK_MAX_BUF_SIZE 0xFFFFFFFFUL +#define VIRTIO_VSOCK_MAX_PKT_BUF_SIZE (1024 * 64) +#define VIRTIO_VSOCK_MAX_TX_BUF_SIZE (1024 * 1024 * 16) +#define VIRTIO_VSOCK_MAX_DGRAM_SIZE (1024 * 64) + +struct vsock_transport_recv_notify_data; +struct vsock_transport_send_notify_data; +struct sockaddr_vm; +struct vsock_sock; + +enum { + VSOCK_VQ_CTRL = 0, + VSOCK_VQ_RX = 1, /* for host to guest data */ + VSOCK_VQ_TX = 2, /* for guest to host data */ + VSOCK_VQ_MAX = 3, +}; + +/* virtio transport socket state */ +struct virtio_transport { + struct virtio_transport_pkt_ops *ops; + struct vsock_sock *vsk; + + u32 buf_size; + u32 buf_size_min; + u32 buf_size_max; + + struct mutex tx_lock; + struct mutex rx_lock; + + struct list_head rx_queue; + u32 rx_bytes; + + /* Protected by trans->tx_lock */ + u32 tx_cnt; + u32 buf_alloc; + u32 peer_fwd_cnt; + u32 peer_buf_alloc; + /* Protected by trans->rx_lock */ + u32 fwd_cnt; + + /* Protected by sk_lock */ + u16 dgram_id; + struct list_head incomplete_dgrams; /* dgram fragments */ +}; + +struct virtio_vsock_pkt { + struct virtio_vsock_hdr hdr; + struct virtio_transport *trans; + struct work_struct work; + struct list_head list; + void *buf; + u32 len; + u32 off; +}; + +struct virtio_vsock_pkt_info { + u32 remote_cid, remote_port; + struct msghdr *msg; + u32 pkt_len; + u16 type; + u16 op; + u32 flags; + u16 dgram_id; + u16 dgram_len; +}; + +struct virtio_transport_pkt_ops { + int (*send_pkt)(struct vsock_sock *vsk, + struct virtio_vsock_pkt_info *info); +}; + +void virtio_vsock_dumppkt(const char *func, + const struct virtio_vsock_pkt *pkt); + +struct sock * +virtio_transport_get_pending(struct sock *listener, + struct virtio_vsock_pkt *pkt); +struct virtio_vsock_pkt * +virtio_transport_alloc_pkt(struct vsock_sock *vsk, + struct virtio_vsock_pkt_info *info, + size_t len, + u32 src_cid, + u32 src_port, + u32 dst_cid, + u32 dst_port); +ssize_t +virtio_transport_stream_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len, + int type); +int +virtio_transport_dgram_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len, int flags); + +s64 virtio_transport_stream_has_data(struct vsock_sock *vsk); +s64 virtio_transport_stream_has_space(struct vsock_sock *vsk); + +int virtio_transport_do_socket_init(struct vsock_sock *vsk, + struct vsock_sock *psk); +u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk); +u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk); +u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk); +void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val); +void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val); +void virtio_transport_set_max_buffer_size(struct vsock_sock *vs, u64 val); +int +virtio_transport_notify_poll_in(struct vsock_sock *vsk, + size_t target, + bool *data_ready_now); +int +virtio_transport_notify_poll_out(struct vsock_sock *vsk, + size_t target, + bool *space_available_now); + +int virtio_transport_notify_recv_init(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data); +int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data); +int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data); +int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, + size_t target, ssize_t copied, bool data_read, + struct vsock_transport_recv_notify_data *data); +int virtio_transport_notify_send_init(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data); +int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data); +int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data); +int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, + ssize_t written, struct vsock_transport_send_notify_data *data); + +u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk); +bool virtio_transport_stream_is_active(struct vsock_sock *vsk); +bool virtio_transport_stream_allow(u32 cid, u32 port); +int virtio_transport_dgram_bind(struct vsock_sock *vsk, + struct sockaddr_vm *addr); +bool virtio_transport_dgram_allow(u32 cid, u32 port); + +int virtio_transport_connect(struct vsock_sock *vsk); + +int virtio_transport_shutdown(struct vsock_sock *vsk, int mode); + +void virtio_transport_release(struct vsock_sock *vsk); + +ssize_t +virtio_transport_stream_enqueue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len); +int +virtio_transport_dgram_enqueue(struct vsock_sock *vsk, + struct sockaddr_vm *remote_addr, + struct msghdr *msg, + size_t len); + +void virtio_transport_destruct(struct vsock_sock *vsk); + +void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt); +void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt); +void virtio_transport_inc_tx_pkt(struct virtio_vsock_pkt *pkt); +void virtio_transport_dec_tx_pkt(struct virtio_vsock_pkt *pkt); +u32 virtio_transport_get_credit(struct virtio_transport *trans, u32 wanted); +void virtio_transport_put_credit(struct virtio_transport *trans, u32 credit); +#endif /* _LINUX_VIRTIO_VSOCK_H */ diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h index e9eb2d6791b3..a0c8fa2ababf 100644 --- a/include/net/af_vsock.h +++ b/include/net/af_vsock.h @@ -175,8 +175,10 @@ void vsock_insert_connected(struct vsock_sock *vsk); void vsock_remove_bound(struct vsock_sock *vsk); void vsock_remove_connected(struct vsock_sock *vsk); struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr); +struct sock *vsock_find_unbound_socket(struct sockaddr_vm *addr); struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, struct sockaddr_vm *dst); void vsock_for_each_connected_socket(void (*fn)(struct sock *sk)); +int vsock_bind_dgram_generic(struct vsock_sock *vsk, struct sockaddr_vm *addr); #endif /* __AF_VSOCK_H__ */ diff --git a/include/uapi/linux/virtio_ids.h b/include/uapi/linux/virtio_ids.h index 77925f587b15..16dcf5d06cd7 100644 --- a/include/uapi/linux/virtio_ids.h +++ b/include/uapi/linux/virtio_ids.h @@ -39,6 +39,7 @@ #define VIRTIO_ID_9P 9 /* 9p virtio console */ #define VIRTIO_ID_RPROC_SERIAL 11 /* virtio remoteproc serial link */ #define VIRTIO_ID_CAIF 12 /* Virtio caif */ +#define VIRTIO_ID_VSOCK 13 /* virtio vsock transport */ #define VIRTIO_ID_GPU 16 /* virtio GPU */ #define VIRTIO_ID_INPUT 18 /* virtio input */ diff --git a/include/uapi/linux/virtio_vsock.h b/include/uapi/linux/virtio_vsock.h new file mode 100644 index 000000000000..8cf9b5682628 --- /dev/null +++ b/include/uapi/linux/virtio_vsock.h @@ -0,0 +1,89 @@ +/* + * This header, excluding the #ifdef __KERNEL__ part, is BSD licensed so + * anyone can use the definitions to implement compatible drivers/servers: + * + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of IBM nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS'' + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL IBM OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + * + * Copyright (C) Red Hat, Inc., 2013-2015 + * Copyright (C) Asias He <asias@redhat.com>, 2013 + * Copyright (C) Stefan Hajnoczi <stefanha@redhat.com>, 2015 + */ + +#ifndef _UAPI_LINUX_VIRTIO_VSOCK_H +#define _UAPI_LINUX_VIRTIO_VOSCK_H + +#include <linux/types.h> +#include <linux/virtio_ids.h> +#include <linux/virtio_config.h> + +struct virtio_vsock_config { + __le32 guest_cid; + __le32 max_virtqueue_pairs; +}; + +struct virtio_vsock_hdr { + __le32 src_cid; + __le32 src_port; + __le32 dst_cid; + __le32 dst_port; + __le32 len; + __le16 type; /* enum virtio_vsock_type */ + __le16 op; /* enum virtio_vsock_op */ + __le32 flags; + __le32 buf_alloc; + __le32 fwd_cnt; +}; + +enum virtio_vsock_type { + VIRTIO_VSOCK_TYPE_STREAM = 1, + VIRTIO_VSOCK_TYPE_DGRAM = 2, +}; + +enum virtio_vsock_op { + VIRTIO_VSOCK_OP_INVALID = 0, + + /* Connect operations */ + VIRTIO_VSOCK_OP_REQUEST = 1, + VIRTIO_VSOCK_OP_RESPONSE = 2, + VIRTIO_VSOCK_OP_ACK = 3, + VIRTIO_VSOCK_OP_RST = 4, + VIRTIO_VSOCK_OP_SHUTDOWN = 5, + + /* To send payload */ + VIRTIO_VSOCK_OP_RW = 6, + + /* Tell the peer our credit info */ + VIRTIO_VSOCK_OP_CREDIT_UPDATE = 7, + /* Request the peer to send the credit info to us */ + VIRTIO_VSOCK_OP_CREDIT_REQUEST = 8, +}; + +/* VIRTIO_VSOCK_OP_SHUTDOWN flags values */ +enum virtio_vsock_shutdown { + VIRTIO_VSOCK_SHUTDOWN_RCV = 1, + VIRTIO_VSOCK_SHUTDOWN_SEND = 2, +}; + +#endif /* _UAPI_LINUX_VIRTIO_VSOCK_H */ diff --git a/net/vmw_vsock/Kconfig b/net/vmw_vsock/Kconfig index 14810abedc2e..74e0bc887a33 100644 --- a/net/vmw_vsock/Kconfig +++ b/net/vmw_vsock/Kconfig @@ -26,3 +26,21 @@ config VMWARE_VMCI_VSOCKETS To compile this driver as a module, choose M here: the module will be called vmw_vsock_vmci_transport. If unsure, say N. + +config VIRTIO_VSOCKETS + tristate "virtio transport for Virtual Sockets" + depends on VSOCKETS && VIRTIO + select VIRTIO_VSOCKETS_COMMON + help + This module implements a virtio transport for Virtual Sockets. + + Enable this transport if your Virtual Machine runs on Qemu/KVM. + + To compile this driver as a module, choose M here: the module + will be called virtio_vsock_transport. If unsure, say N. + +config VIRTIO_VSOCKETS_COMMON + tristate + ---help--- + This option is selected by any driver which needs to access + the virtio_vsock. diff --git a/net/vmw_vsock/Makefile b/net/vmw_vsock/Makefile index 2ce52d70f224..cf4c29439081 100644 --- a/net/vmw_vsock/Makefile +++ b/net/vmw_vsock/Makefile @@ -1,5 +1,7 @@ obj-$(CONFIG_VSOCKETS) += vsock.o obj-$(CONFIG_VMWARE_VMCI_VSOCKETS) += vmw_vsock_vmci_transport.o +obj-$(CONFIG_VIRTIO_VSOCKETS) += virtio_transport.o +obj-$(CONFIG_VIRTIO_VSOCKETS_COMMON) += virtio_transport_common.o vsock-y += af_vsock.o vsock_addr.o diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 7fd1220fbfa0..77247a2b670b 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -223,6 +223,17 @@ static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr) return NULL; } +static struct sock *__vsock_find_unbound_socket(struct sockaddr_vm *addr) +{ + struct vsock_sock *vsk; + + list_for_each_entry(vsk, vsock_unbound_sockets, bound_table) + if (addr->svm_port == vsk->local_addr.svm_port) + return sk_vsock(vsk); + + return NULL; +} + static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src, struct sockaddr_vm *dst) { @@ -298,6 +309,21 @@ struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr) } EXPORT_SYMBOL_GPL(vsock_find_bound_socket); +struct sock *vsock_find_unbound_socket(struct sockaddr_vm *addr) +{ + struct sock *sk; + + spin_lock_bh(&vsock_table_lock); + sk = __vsock_find_unbound_socket(addr); + if (sk) + sock_hold(sk); + + spin_unlock_bh(&vsock_table_lock); + + return sk; +} +EXPORT_SYMBOL_GPL(vsock_find_unbound_socket); + struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, struct sockaddr_vm *dst) { @@ -532,6 +558,50 @@ static int __vsock_bind_stream(struct vsock_sock *vsk, return 0; } +int vsock_bind_dgram_generic(struct vsock_sock *vsk, struct sockaddr_vm *addr) +{ + static u32 port = LAST_RESERVED_PORT + 1; + struct sockaddr_vm new_addr; + + vsock_addr_init(&new_addr, addr->svm_cid, addr->svm_port); + + if (addr->svm_port == VMADDR_PORT_ANY) { + bool found = false; + unsigned int i; + + for (i = 0; i < MAX_PORT_RETRIES; i++) { + if (port <= LAST_RESERVED_PORT) + port = LAST_RESERVED_PORT + 1; + + new_addr.svm_port = port++; + + if (!__vsock_find_unbound_socket(&new_addr)) { + found = true; + break; + } + } + + if (!found) + return -EADDRNOTAVAIL; + } else { + /* If port is in reserved range, ensure caller + * has necessary privileges. + */ + if (addr->svm_port <= LAST_RESERVED_PORT && + !capable(CAP_NET_BIND_SERVICE)) { + return -EACCES; + } + + if (__vsock_find_unbound_socket(&new_addr)) + return -EADDRINUSE; + } + + vsock_addr_init(&vsk->local_addr, new_addr.svm_cid, new_addr.svm_port); + + return 0; +} +EXPORT_SYMBOL_GPL(vsock_bind_dgram_generic); + static int __vsock_bind_dgram(struct vsock_sock *vsk, struct sockaddr_vm *addr) { diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c new file mode 100644 index 000000000000..df65dca55fa1 --- /dev/null +++ b/net/vmw_vsock/virtio_transport.c @@ -0,0 +1,466 @@ +/* + * virtio transport for vsock + * + * Copyright (C) 2013-2015 Red Hat, Inc. + * Author: Asias He <asias@redhat.com> + * Stefan Hajnoczi <stefanha@redhat.com> + * + * Some of the code is take from Gerd Hoffmann <kraxel@redhat.com>'s + * early virtio-vsock proof-of-concept bits. + * + * This work is licensed under the terms of the GNU GPL, version 2. + */ +#include <linux/spinlock.h> +#include <linux/module.h> +#include <linux/list.h> +#include <linux/virtio.h> +#include <linux/virtio_ids.h> +#include <linux/virtio_config.h> +#include <linux/virtio_vsock.h> +#include <net/sock.h> +#include <linux/mutex.h> +#include <net/af_vsock.h> + +static struct workqueue_struct *virtio_vsock_workqueue; +static struct virtio_vsock *the_virtio_vsock; +static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */ +static void virtio_vsock_rx_fill(struct virtio_vsock *vsock); + +struct virtio_vsock { + /* Virtio device */ + struct virtio_device *vdev; + /* Virtio virtqueue */ + struct virtqueue *vqs[VSOCK_VQ_MAX]; + /* Wait queue for send pkt */ + wait_queue_head_t queue_wait; + /* Work item to send pkt */ + struct work_struct tx_work; + /* Work item to recv pkt */ + struct work_struct rx_work; + /* Mutex to protect send pkt*/ + struct mutex tx_lock; + /* Mutex to protect recv pkt*/ + struct mutex rx_lock; + /* Number of recv buffers */ + int rx_buf_nr; + /* Number of max recv buffers */ + int rx_buf_max_nr; + /* Used for global tx buf limitation */ + u32 total_tx_buf; + /* Guest context id, just like guest ip address */ + u32 guest_cid; +}; + +static struct virtio_vsock *virtio_vsock_get(void) +{ + return the_virtio_vsock; +} + +static u32 virtio_transport_get_local_cid(void) +{ + struct virtio_vsock *vsock = virtio_vsock_get(); + + return vsock->guest_cid; +} + +static int +virtio_transport_send_pkt(struct vsock_sock *vsk, + struct virtio_vsock_pkt_info *info) +{ + u32 src_cid, src_port, dst_cid, dst_port; + int ret, in_sg = 0, out_sg = 0; + struct virtio_transport *trans; + struct virtio_vsock_pkt *pkt; + struct virtio_vsock *vsock; + struct scatterlist hdr, buf, *sgs[2]; + struct virtqueue *vq; + u32 pkt_len = info->pkt_len; + DEFINE_WAIT(wait); + + vsock = virtio_vsock_get(); + if (!vsock) + return -ENODEV; + + src_cid = virtio_transport_get_local_cid(); + src_port = vsk->local_addr.svm_port; + if (!info->remote_cid) { + dst_cid = vsk->remote_addr.svm_cid; + dst_port = vsk->remote_addr.svm_port; + } else { + dst_cid = info->remote_cid; + dst_port = info->remote_port; + } + + trans = vsk->trans; + vq = vsock->vqs[VSOCK_VQ_TX]; + + if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) + pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; + pkt_len = virtio_transport_get_credit(trans, pkt_len); + /* Do not send zero length OP_RW pkt*/ + if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) + return pkt_len; + + /* Respect global tx buf limitation */ + mutex_lock(&vsock->tx_lock); + while (pkt_len + vsock->total_tx_buf > VIRTIO_VSOCK_MAX_TX_BUF_SIZE) { + prepare_to_wait_exclusive(&vsock->queue_wait, &wait, + TASK_UNINTERRUPTIBLE); + mutex_unlock(&vsock->tx_lock); + schedule(); + mutex_lock(&vsock->tx_lock); + finish_wait(&vsock->queue_wait, &wait); + } + vsock->total_tx_buf += pkt_len; + mutex_unlock(&vsock->tx_lock); + + pkt = virtio_transport_alloc_pkt(vsk, info, pkt_len, + src_cid, src_port, + dst_cid, dst_port); + if (!pkt) { + mutex_lock(&vsock->tx_lock); + vsock->total_tx_buf -= pkt_len; + mutex_unlock(&vsock->tx_lock); + virtio_transport_put_credit(trans, pkt_len); + return -ENOMEM; + } + + pr_debug("%s:info->pkt_len= %d\n", __func__, info->pkt_len); + + /* Will be released in virtio_transport_send_pkt_work */ + sock_hold(&trans->vsk->sk); + virtio_transport_inc_tx_pkt(pkt); + + /* Put pkt in the virtqueue */ + sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr)); + sgs[out_sg++] = &hdr; + if (info->msg && info->pkt_len > 0) { + sg_init_one(&buf, pkt->buf, pkt->len); + sgs[out_sg++] = &buf; + } + + mutex_lock(&vsock->tx_lock); + while ((ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, pkt, + GFP_KERNEL)) < 0) { + prepare_to_wait_exclusive(&vsock->queue_wait, &wait, + TASK_UNINTERRUPTIBLE); + mutex_unlock(&vsock->tx_lock); + schedule(); + mutex_lock(&vsock->tx_lock); + finish_wait(&vsock->queue_wait, &wait); + } + virtqueue_kick(vq); + mutex_unlock(&vsock->tx_lock); + + return pkt_len; +} + +static struct virtio_transport_pkt_ops virtio_ops = { + .send_pkt = virtio_transport_send_pkt, +}; + +static void virtio_vsock_rx_fill(struct virtio_vsock *vsock) +{ + int buf_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; + struct virtio_vsock_pkt *pkt; + struct scatterlist hdr, buf, *sgs[2]; + struct virtqueue *vq; + int ret; + + vq = vsock->vqs[VSOCK_VQ_RX]; + + do { + pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); + if (!pkt) { + pr_debug("%s: fail to allocate pkt\n", __func__); + goto out; + } + + /* TODO: use mergeable rx buffer */ + pkt->buf = kmalloc(buf_len, GFP_KERNEL); + if (!pkt->buf) { + pr_debug("%s: fail to allocate pkt->buf\n", __func__); + goto err; + } + + sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr)); + sgs[0] = &hdr; + + sg_init_one(&buf, pkt->buf, buf_len); + sgs[1] = &buf; + ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL); + if (ret) + goto err; + vsock->rx_buf_nr++; + } while (vq->num_free); + if (vsock->rx_buf_nr > vsock->rx_buf_max_nr) + vsock->rx_buf_max_nr = vsock->rx_buf_nr; +out: + virtqueue_kick(vq); + return; +err: + virtqueue_kick(vq); + virtio_transport_free_pkt(pkt); + return; +} + +static void virtio_transport_send_pkt_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, tx_work); + struct virtio_vsock_pkt *pkt; + bool added = false; + struct virtqueue *vq; + unsigned int len; + struct sock *sk; + + vq = vsock->vqs[VSOCK_VQ_TX]; + mutex_lock(&vsock->tx_lock); + do { + virtqueue_disable_cb(vq); + while ((pkt = virtqueue_get_buf(vq, &len)) != NULL) { + sk = &pkt->trans->vsk->sk; + virtio_transport_dec_tx_pkt(pkt); + /* Release refcnt taken in virtio_transport_send_pkt */ + sock_put(sk); + vsock->total_tx_buf -= pkt->len; + virtio_transport_free_pkt(pkt); + added = true; + } + } while (!virtqueue_enable_cb(vq)); + mutex_unlock(&vsock->tx_lock); + + if (added) + wake_up(&vsock->queue_wait); +} + +static void virtio_transport_recv_pkt_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, rx_work); + struct virtio_vsock_pkt *pkt; + struct virtqueue *vq; + unsigned int len; + + vq = vsock->vqs[VSOCK_VQ_RX]; + mutex_lock(&vsock->rx_lock); + do { + virtqueue_disable_cb(vq); + while ((pkt = virtqueue_get_buf(vq, &len)) != NULL) { + pkt->len = len; + virtio_transport_recv_pkt(pkt); + vsock->rx_buf_nr--; + } + } while (!virtqueue_enable_cb(vq)); + + if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2) + virtio_vsock_rx_fill(vsock); + mutex_unlock(&vsock->rx_lock); +} + +static void virtio_vsock_ctrl_done(struct virtqueue *vq) +{ +} + +static void virtio_vsock_tx_done(struct virtqueue *vq) +{ + struct virtio_vsock *vsock = vq->vdev->priv; + + if (!vsock) + return; + queue_work(virtio_vsock_workqueue, &vsock->tx_work); +} + +static void virtio_vsock_rx_done(struct virtqueue *vq) +{ + struct virtio_vsock *vsock = vq->vdev->priv; + + if (!vsock) + return; + queue_work(virtio_vsock_workqueue, &vsock->rx_work); +} + +static int +virtio_transport_socket_init(struct vsock_sock *vsk, struct vsock_sock *psk) +{ + struct virtio_transport *trans; + int ret; + + ret = virtio_transport_do_socket_init(vsk, psk); + if (ret) + return ret; + + trans = vsk->trans; + trans->ops = &virtio_ops; + return ret; +} + +static struct vsock_transport virtio_transport = { + .get_local_cid = virtio_transport_get_local_cid, + + .init = virtio_transport_socket_init, + .destruct = virtio_transport_destruct, + .release = virtio_transport_release, + .connect = virtio_transport_connect, + .shutdown = virtio_transport_shutdown, + + .dgram_bind = virtio_transport_dgram_bind, + .dgram_dequeue = virtio_transport_dgram_dequeue, + .dgram_enqueue = virtio_transport_dgram_enqueue, + .dgram_allow = virtio_transport_dgram_allow, + + .stream_dequeue = virtio_transport_stream_dequeue, + .stream_enqueue = virtio_transport_stream_enqueue, + .stream_has_data = virtio_transport_stream_has_data, + .stream_has_space = virtio_transport_stream_has_space, + .stream_rcvhiwat = virtio_transport_stream_rcvhiwat, + .stream_is_active = virtio_transport_stream_is_active, + .stream_allow = virtio_transport_stream_allow, + + .notify_poll_in = virtio_transport_notify_poll_in, + .notify_poll_out = virtio_transport_notify_poll_out, + .notify_recv_init = virtio_transport_notify_recv_init, + .notify_recv_pre_block = virtio_transport_notify_recv_pre_block, + .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue, + .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue, + .notify_send_init = virtio_transport_notify_send_init, + .notify_send_pre_block = virtio_transport_notify_send_pre_block, + .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, + .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, + + .set_buffer_size = virtio_transport_set_buffer_size, + .set_min_buffer_size = virtio_transport_set_min_buffer_size, + .set_max_buffer_size = virtio_transport_set_max_buffer_size, + .get_buffer_size = virtio_transport_get_buffer_size, + .get_min_buffer_size = virtio_transport_get_min_buffer_size, + .get_max_buffer_size = virtio_transport_get_max_buffer_size, +}; + +static int virtio_vsock_probe(struct virtio_device *vdev) +{ + vq_callback_t *callbacks[] = { + virtio_vsock_ctrl_done, + virtio_vsock_rx_done, + virtio_vsock_tx_done, + }; + const char *names[] = { + "ctrl", + "rx", + "tx", + }; + struct virtio_vsock *vsock = NULL; + u32 guest_cid; + int ret; + + ret = mutex_lock_interruptible(&the_virtio_vsock_mutex); + if (ret) + return ret; + + /* Only one virtio-vsock device per guest is supported */ + if (the_virtio_vsock) { + ret = -EBUSY; + goto out; + } + + vsock = kzalloc(sizeof(*vsock), GFP_KERNEL); + if (!vsock) { + ret = -ENOMEM; + goto out; + } + + vsock->vdev = vdev; + + ret = vsock->vdev->config->find_vqs(vsock->vdev, VSOCK_VQ_MAX, + vsock->vqs, callbacks, names); + if (ret < 0) + goto out; + + vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid), + &guest_cid, sizeof(guest_cid)); + vsock->guest_cid = le32_to_cpu(guest_cid); + pr_debug("%s:guest_cid=%d\n", __func__, vsock->guest_cid); + + ret = vsock_core_init(&virtio_transport); + if (ret < 0) + goto out_vqs; + + vsock->rx_buf_nr = 0; + vsock->rx_buf_max_nr = 0; + + vdev->priv = the_virtio_vsock = vsock; + init_waitqueue_head(&vsock->queue_wait); + mutex_init(&vsock->tx_lock); + mutex_init(&vsock->rx_lock); + INIT_WORK(&vsock->rx_work, virtio_transport_recv_pkt_work); + INIT_WORK(&vsock->tx_work, virtio_transport_send_pkt_work); + + mutex_lock(&vsock->rx_lock); + virtio_vsock_rx_fill(vsock); + mutex_unlock(&vsock->rx_lock); + + mutex_unlock(&the_virtio_vsock_mutex); + return 0; + +out_vqs: + vsock->vdev->config->del_vqs(vsock->vdev); +out: + kfree(vsock); + mutex_unlock(&the_virtio_vsock_mutex); + return ret; +} + +static void virtio_vsock_remove(struct virtio_device *vdev) +{ + struct virtio_vsock *vsock = vdev->priv; + + mutex_lock(&the_virtio_vsock_mutex); + the_virtio_vsock = NULL; + vsock_core_exit(); + mutex_unlock(&the_virtio_vsock_mutex); + + kfree(vsock); +} + +static struct virtio_device_id id_table[] = { + { VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID }, + { 0 }, +}; + +static unsigned int features[] = { +}; + +static struct virtio_driver virtio_vsock_driver = { + .feature_table = features, + .feature_table_size = ARRAY_SIZE(features), + .driver.name = KBUILD_MODNAME, + .driver.owner = THIS_MODULE, + .id_table = id_table, + .probe = virtio_vsock_probe, + .remove = virtio_vsock_remove, +}; + +static int __init virtio_vsock_init(void) +{ + int ret; + + virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0); + if (!virtio_vsock_workqueue) + return -ENOMEM; + ret = register_virtio_driver(&virtio_vsock_driver); + if (ret) + destroy_workqueue(virtio_vsock_workqueue); + return ret; +} + +static void __exit virtio_vsock_exit(void) +{ + unregister_virtio_driver(&virtio_vsock_driver); + destroy_workqueue(virtio_vsock_workqueue); +} + +module_init(virtio_vsock_init); +module_exit(virtio_vsock_exit); +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Asias He"); +MODULE_DESCRIPTION("virtio transport for vsock"); +MODULE_DEVICE_TABLE(virtio, id_table); diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c new file mode 100644 index 000000000000..28f790da6f15 --- /dev/null +++ b/net/vmw_vsock/virtio_transport_common.c @@ -0,0 +1,1272 @@ +/* + * common code for virtio vsock + * + * Copyright (C) 2013-2015 Red Hat, Inc. + * Author: Asias He <asias@redhat.com> + * Stefan Hajnoczi <stefanha@redhat.com> + * + * This work is licensed under the terms of the GNU GPL, version 2. + */ +#include <linux/module.h> +#include <linux/ctype.h> +#include <linux/list.h> +#include <linux/virtio.h> +#include <linux/virtio_ids.h> +#include <linux/virtio_config.h> +#include <linux/virtio_vsock.h> +#include <linux/random.h> +#include <linux/cryptohash.h> + +#include <net/sock.h> +#include <net/af_vsock.h> + +#define COOKIEBITS 24 +#define COOKIEMASK (((u32)1 << COOKIEBITS) - 1) +#define VSOCK_TIMEOUT_INIT 4 + +#define SHA_MESSAGE_WORDS 16 +#define SHA_VSOCK_WORDS 5 + +static u32 vsockcookie_secret[2][SHA_MESSAGE_WORDS - SHA_VSOCK_WORDS + + SHA_DIGEST_WORDS]; + +static DEFINE_PER_CPU(__u32[SHA_MESSAGE_WORDS + SHA_DIGEST_WORDS + + SHA_WORKSPACE_WORDS], vsock_cookie_scratch); + +static u32 cookie_hash(u32 saddr, u32 daddr, u16 sport, u16 dport, + u32 count, int c) +{ + __u32 *tmp = this_cpu_ptr(vsock_cookie_scratch); + + memcpy(tmp + SHA_VSOCK_WORDS, vsockcookie_secret[c], + sizeof(vsockcookie_secret[c])); + tmp[0] = saddr; + tmp[1] = daddr; + tmp[2] = sport; + tmp[3] = dport; + tmp[4] = count; + sha_transform(tmp + SHA_MESSAGE_WORDS, (__u8 *)tmp, + tmp + SHA_MESSAGE_WORDS + SHA_DIGEST_WORDS); + + return tmp[17]; +} + +static u32 +virtio_vsock_secure_cookie(u32 saddr, u32 daddr, u32 sport, u32 dport, + u32 count) +{ + u32 h1, h2; + + h1 = cookie_hash(saddr, daddr, sport, dport, 0, 0); + h2 = cookie_hash(saddr, daddr, sport, dport, count, 1); + + return h1 + (count << COOKIEBITS) + (h2 & COOKIEMASK); +} + +static u32 +virtio_vsock_check_cookie(u32 saddr, u32 daddr, u32 sport, u32 dport, + u32 count, u32 cookie, u32 maxdiff) +{ + u32 diff; + u32 ret; + + cookie -= cookie_hash(saddr, daddr, sport, dport, 0, 0); + + diff = (count - (cookie >> COOKIEBITS)) & ((u32)-1 >> COOKIEBITS); + pr_debug("%s: diff=%x\n", __func__, diff); + if (diff >= maxdiff) + return (u32)-1; + + ret = (cookie - + cookie_hash(saddr, daddr, sport, dport, count - diff, 1)) + & COOKIEMASK; + pr_debug("%s: ret=%x\n", __func__, diff); + + return ret; +} + +void virtio_vsock_dumppkt(const char *func, const struct virtio_vsock_pkt *pkt) +{ + pr_debug("%s: pkt=%p, op=%d, len=%d, %d:%d---%d:%d, len=%d\n", + func, pkt, + le16_to_cpu(pkt->hdr.op), + le32_to_cpu(pkt->hdr.len), + le32_to_cpu(pkt->hdr.src_cid), + le32_to_cpu(pkt->hdr.src_port), + le32_to_cpu(pkt->hdr.dst_cid), + le32_to_cpu(pkt->hdr.dst_port), + pkt->len); +} +EXPORT_SYMBOL_GPL(virtio_vsock_dumppkt); + +struct virtio_vsock_pkt * +virtio_transport_alloc_pkt(struct vsock_sock *vsk, + struct virtio_vsock_pkt_info *info, + size_t len, + u32 src_cid, + u32 src_port, + u32 dst_cid, + u32 dst_port) +{ + struct virtio_transport *trans = vsk->trans; + struct virtio_vsock_pkt *pkt; + int err; + + BUG_ON(!trans); + + pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); + if (!pkt) + return NULL; + + pkt->hdr.type = cpu_to_le16(info->type); + pkt->hdr.op = cpu_to_le16(info->op); + pkt->hdr.src_cid = cpu_to_le32(src_cid); + pkt->hdr.src_port = cpu_to_le32(src_port); + pkt->hdr.dst_cid = cpu_to_le32(dst_cid); + pkt->hdr.dst_port = cpu_to_le32(dst_port); + pkt->hdr.flags = cpu_to_le32(info->flags); + pkt->len = len; + pkt->trans = trans; + if (info->type == VIRTIO_VSOCK_TYPE_DGRAM) + pkt->hdr.len = cpu_to_le32(len + (info->dgram_len << 16)); + else if (info->type == VIRTIO_VSOCK_TYPE_STREAM) + pkt->hdr.len = cpu_to_le32(len); + + if (info->msg && len > 0) { + pkt->buf = kmalloc(len, GFP_KERNEL); + if (!pkt->buf) + goto out_pkt; + err = memcpy_from_msg(pkt->buf, info->msg, len); + if (err) + goto out; + } + + return pkt; + +out: + kfree(pkt->buf); +out_pkt: + kfree(pkt); + return NULL; +} +EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt); + +struct sock * +virtio_transport_get_pending(struct sock *listener, + struct virtio_vsock_pkt *pkt) +{ + struct vsock_sock *vlistener; + struct vsock_sock *vpending; + struct sockaddr_vm src; + struct sockaddr_vm dst; + struct sock *pending; + + vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid), le32_to_cpu(pkt->hdr.src_port)); + vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid), le32_to_cpu(pkt->hdr.dst_port)); + + vlistener = vsock_sk(listener); + list_for_each_entry(vpending, &vlistener->pending_links, + pending_links) { + if (vsock_addr_equals_addr(&src, &vpending->remote_addr) && + vsock_addr_equals_addr(&dst, &vpending->local_addr)) { + pending = sk_vsock(vpending); + sock_hold(pending); + return pending; + } + } + + return NULL; +} +EXPORT_SYMBOL_GPL(virtio_transport_get_pending); + +static void virtio_transport_inc_rx_pkt(struct virtio_vsock_pkt *pkt) +{ + pkt->trans->rx_bytes += pkt->len; +} + +static void virtio_transport_dec_rx_pkt(struct virtio_vsock_pkt *pkt) +{ + pkt->trans->rx_bytes -= pkt->len; + pkt->trans->fwd_cnt += pkt->len; +} + +void virtio_transport_inc_tx_pkt(struct virtio_vsock_pkt *pkt) +{ + mutex_lock(&pkt->trans->tx_lock); + pkt->hdr.fwd_cnt = cpu_to_le32(pkt->trans->fwd_cnt); + pkt->hdr.buf_alloc = cpu_to_le32(pkt->trans->buf_alloc); + mutex_unlock(&pkt->trans->tx_lock); +} +EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); + +void virtio_transport_dec_tx_pkt(struct virtio_vsock_pkt *pkt) +{ +} +EXPORT_SYMBOL_GPL(virtio_transport_dec_tx_pkt); + +u32 virtio_transport_get_credit(struct virtio_transport *trans, u32 credit) +{ + u32 ret; + + mutex_lock(&trans->tx_lock); + ret = trans->peer_buf_alloc - (trans->tx_cnt - trans->peer_fwd_cnt); + if (ret > credit) + ret = credit; + trans->tx_cnt += ret; + mutex_unlock(&trans->tx_lock); + + pr_debug("%s: ret=%d, buf_alloc=%d, peer_buf_alloc=%d," + "tx_cnt=%d, fwd_cnt=%d, peer_fwd_cnt=%d\n", __func__, + ret, trans->buf_alloc, trans->peer_buf_alloc, + trans->tx_cnt, trans->fwd_cnt, trans->peer_fwd_cnt); + + return ret; +} +EXPORT_SYMBOL_GPL(virtio_transport_get_credit); + +void virtio_transport_put_credit(struct virtio_transport *trans, u32 credit) +{ + mutex_lock(&trans->tx_lock); + trans->tx_cnt -= credit; + mutex_unlock(&trans->tx_lock); +} +EXPORT_SYMBOL_GPL(virtio_transport_put_credit); + +static int virtio_transport_send_credit_update(struct vsock_sock *vsk, int type, struct virtio_vsock_hdr *hdr) +{ + struct virtio_transport *trans = vsk->trans; + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, + .type = type, + }; + + if (hdr && type == VIRTIO_VSOCK_TYPE_DGRAM) { + info.remote_cid = le32_to_cpu(hdr->src_cid); + info.remote_port = le32_to_cpu(hdr->src_port); + } + + pr_debug("%s: sk=%p send_credit_update\n", __func__, vsk); + return trans->ops->send_pkt(vsk, &info); +} + +static int virtio_transport_send_credit_request(struct vsock_sock *vsk, int type) +{ + struct virtio_transport *trans = vsk->trans; + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_CREDIT_REQUEST, + .type = type, + }; + + pr_debug("%s: sk=%p send_credit_request\n", __func__, vsk); + return trans->ops->send_pkt(vsk, &info); +} + +static ssize_t +virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + struct virtio_transport *trans = vsk->trans; + struct virtio_vsock_pkt *pkt; + size_t bytes, total = 0; + int err = -EFAULT; + + mutex_lock(&trans->rx_lock); + while (total < len && trans->rx_bytes > 0 && + !list_empty(&trans->rx_queue)) { + pkt = list_first_entry(&trans->rx_queue, + struct virtio_vsock_pkt, list); + + bytes = len - total; + if (bytes > pkt->len - pkt->off) + bytes = pkt->len - pkt->off; + + err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes); + if (err) + goto out; + total += bytes; + pkt->off += bytes; + if (pkt->off == pkt->len) { + virtio_transport_dec_rx_pkt(pkt); + list_del(&pkt->list); + virtio_transport_free_pkt(pkt); + } + } + mutex_unlock(&trans->rx_lock); + + /* Send a credit pkt to peer */ + virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, + NULL); + + return total; + +out: + mutex_unlock(&trans->rx_lock); + if (total) + err = total; + return err; +} + +ssize_t +virtio_transport_stream_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len, int flags) +{ + if (flags & MSG_PEEK) + return -EOPNOTSUPP; + + return virtio_transport_stream_do_dequeue(vsk, msg, len); +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); + +struct dgram_skb { + struct list_head list; + struct sk_buff *skb; + u16 id; +}; + +static struct dgram_skb *dgram_id_to_skb(struct virtio_transport *trans, + u16 id) +{ + struct dgram_skb *dgram_skb; + + list_for_each_entry(dgram_skb, &trans->incomplete_dgrams, list) { + if (dgram_skb->id == id) + return dgram_skb; + } + + return NULL; +} + +static void +virtio_transport_recv_dgram(struct sock *sk, + struct virtio_vsock_pkt *pkt) +{ + struct sk_buff *skb = NULL; + struct vsock_sock *vsk; + struct virtio_transport *trans; + size_t size; + u16 dgram_id, pkt_off, dgram_len, pkt_len; + u32 flags, len; + struct dgram_skb *dgram_skb; + + vsk = vsock_sk(sk); + trans = vsk->trans; + + /* len: dgram_len | pkt_len */ + len = le32_to_cpu(pkt->hdr.len); + dgram_len = len >> 16; + pkt_len = len & 0xFFFF; + + /* flags: dgram_id | pkt_off */ + flags = le32_to_cpu(pkt->hdr.flags); + dgram_id = flags >> 16; + pkt_off = flags & 0xFFFF; + + pr_debug("%s: dgram_len=%d, pkt_len=%d, id=%d, off=%d\n", __func__, + dgram_len, pkt_len, dgram_id, pkt_off); + + dgram_skb = dgram_id_to_skb(trans, dgram_id); + if (dgram_skb) { + /* This pkt is for a existing dgram */ + skb = dgram_skb->skb; + pr_debug("%s:found skb\n", __func__); + } + + /* Packet payload must be within datagram bounds */ + if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) + goto drop; + if (pkt_len > dgram_len) + goto drop; + if (pkt_off > dgram_len) + goto drop; + if (dgram_len - pkt_off < pkt_len) + goto drop; + + if (!skb) { + /* This pkt is for a new dgram */ + pr_debug("%s:create skb\n", __func__); + + size = sizeof(pkt->hdr) + dgram_len; + /* Attach the packet to the socket's receive queue as an sk_buff. */ + dgram_skb = kzalloc(sizeof(struct dgram_skb), GFP_ATOMIC); + if (!dgram_skb) + goto drop; + + skb = alloc_skb(size, GFP_ATOMIC); + if (!skb) { + kfree(dgram_skb); + dgram_skb = NULL; + goto drop; + } + dgram_skb->id = dgram_id; + dgram_skb->skb = skb; + list_add_tail(&dgram_skb->list, &trans->incomplete_dgrams); + + /* sk_receive_skb() will do a sock_put(), so hold here. */ + sock_hold(sk); + skb_put(skb, size); + memcpy(skb->data, &pkt->hdr, sizeof(pkt->hdr)); + } + + memcpy(skb->data + sizeof(pkt->hdr) + pkt_off, pkt->buf, pkt_len); + + pr_debug("%s:C, off=%d, pkt_len=%d, dgram_len=%d\n", __func__, + pkt_off, pkt_len, dgram_len); + + /* We are done with this dgram */ + if (pkt_off + pkt_len == dgram_len) { + pr_debug("%s:dgram_id=%d is done\n", __func__, dgram_id); + list_del(&dgram_skb->list); + kfree(dgram_skb); + sk_receive_skb(sk, skb, 0); + } + virtio_transport_free_pkt(pkt); + return; + +drop: + if (dgram_skb) { + list_del(&dgram_skb->list); + kfree(dgram_skb); + kfree_skb(skb); + sock_put(sk); + } + virtio_transport_free_pkt(pkt); +} + +int +virtio_transport_dgram_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len, int flags) +{ + struct virtio_vsock_hdr *hdr; + struct sk_buff *skb; + int noblock; + int err; + int dgram_len; + + noblock = flags & MSG_DONTWAIT; + + if (flags & MSG_OOB || flags & MSG_ERRQUEUE) + return -EOPNOTSUPP; + + /* Retrieve the head sk_buff from the socket's receive queue. */ + err = 0; + skb = skb_recv_datagram(&vsk->sk, flags, noblock, &err); + if (err) + return err; + if (!skb) + return -EAGAIN; + + hdr = (struct virtio_vsock_hdr *)skb->data; + if (!hdr) + goto out; + + dgram_len = le32_to_cpu(hdr->len) >> 16; + /* Place the datagram payload in the user's iovec. */ + err = skb_copy_datagram_msg(skb, sizeof(*hdr), msg, dgram_len); + if (err) + goto out; + + if (msg->msg_name) { + /* Provide the address of the sender. */ + DECLARE_SOCKADDR(struct sockaddr_vm *, vm_addr, msg->msg_name); + vsock_addr_init(vm_addr, le32_to_cpu(hdr->src_cid), le32_to_cpu(hdr->src_port)); + msg->msg_namelen = sizeof(*vm_addr); + } + err = dgram_len; + + /* Send a credit pkt to peer */ + virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_DGRAM, hdr); + + pr_debug("%s:done, recved =%d\n", __func__, dgram_len); +out: + skb_free_datagram(&vsk->sk, skb); + return err; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue); + +s64 virtio_transport_stream_has_data(struct vsock_sock *vsk) +{ + struct virtio_transport *trans = vsk->trans; + s64 bytes; + + mutex_lock(&trans->rx_lock); + bytes = trans->rx_bytes; + mutex_unlock(&trans->rx_lock); + + return bytes; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data); + +static s64 virtio_transport_has_space(struct vsock_sock *vsk) +{ + struct virtio_transport *trans = vsk->trans; + s64 bytes; + + bytes = trans->peer_buf_alloc - (trans->tx_cnt - trans->peer_fwd_cnt); + if (bytes < 0) + bytes = 0; + + return bytes; +} + +s64 virtio_transport_stream_has_space(struct vsock_sock *vsk) +{ + struct virtio_transport *trans = vsk->trans; + s64 bytes; + + mutex_lock(&trans->tx_lock); + bytes = virtio_transport_has_space(vsk); + mutex_unlock(&trans->tx_lock); + + pr_debug("%s: bytes=%lld\n", __func__, bytes); + + return bytes; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space); + +int virtio_transport_do_socket_init(struct vsock_sock *vsk, + struct vsock_sock *psk) +{ + struct virtio_transport *trans; + + trans = kzalloc(sizeof(*trans), GFP_KERNEL); + if (!trans) + return -ENOMEM; + + vsk->trans = trans; + trans->vsk = vsk; + if (psk) { + struct virtio_transport *ptrans = psk->trans; + trans->buf_size = ptrans->buf_size; + trans->buf_size_min = ptrans->buf_size_min; + trans->buf_size_max = ptrans->buf_size_max; + trans->peer_buf_alloc = ptrans->peer_buf_alloc; + } else { + trans->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE; + trans->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE; + trans->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE; + } + + trans->buf_alloc = trans->buf_size; + + pr_debug("%s: trans->buf_alloc=%d\n", __func__, trans->buf_alloc); + + mutex_init(&trans->rx_lock); + mutex_init(&trans->tx_lock); + INIT_LIST_HEAD(&trans->rx_queue); + INIT_LIST_HEAD(&trans->incomplete_dgrams); + + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); + +u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk) +{ + struct virtio_transport *trans = vsk->trans; + + return trans->buf_size; +} +EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size); + +u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk) +{ + struct virtio_transport *trans = vsk->trans; + + return trans->buf_size_min; +} +EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size); + +u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk) +{ + struct virtio_transport *trans = vsk->trans; + + return trans->buf_size_max; +} +EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size); + +void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) +{ + struct virtio_transport *trans = vsk->trans; + + if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) + val = VIRTIO_VSOCK_MAX_BUF_SIZE; + if (val < trans->buf_size_min) + trans->buf_size_min = val; + if (val > trans->buf_size_max) + trans->buf_size_max = val; + trans->buf_size = val; + trans->buf_alloc = val; +} +EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size); + +void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val) +{ + struct virtio_transport *trans = vsk->trans; + + if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) + val = VIRTIO_VSOCK_MAX_BUF_SIZE; + if (val > trans->buf_size) + trans->buf_size = val; + trans->buf_size_min = val; +} +EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size); + +void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val) +{ + struct virtio_transport *trans = vsk->trans; + + if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) + val = VIRTIO_VSOCK_MAX_BUF_SIZE; + if (val < trans->buf_size) + trans->buf_size = val; + trans->buf_size_max = val; +} +EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size); + +int +virtio_transport_notify_poll_in(struct vsock_sock *vsk, + size_t target, + bool *data_ready_now) +{ + if (vsock_stream_has_data(vsk)) + *data_ready_now = true; + else + *data_ready_now = false; + + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in); + +int +virtio_transport_notify_poll_out(struct vsock_sock *vsk, + size_t target, + bool *space_avail_now) +{ + s64 free_space; + + free_space = vsock_stream_has_space(vsk); + if (free_space > 0) + *space_avail_now = true; + else if (free_space == 0) + *space_avail_now = false; + + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out); + +int virtio_transport_notify_recv_init(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init); + +int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block); + +int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue); + +int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, + size_t target, ssize_t copied, bool data_read, + struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue); + +int virtio_transport_notify_send_init(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init); + +int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block); + +int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue); + +int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, + ssize_t written, struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue); + +u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) +{ + struct virtio_transport *trans = vsk->trans; + + return trans->buf_size; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); + +bool virtio_transport_stream_is_active(struct vsock_sock *vsk) +{ + return true; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active); + +bool virtio_transport_stream_allow(u32 cid, u32 port) +{ + return true; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_allow); + +int virtio_transport_dgram_bind(struct vsock_sock *vsk, + struct sockaddr_vm *addr) +{ + return vsock_bind_dgram_generic(vsk, addr); +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind); + +bool virtio_transport_dgram_allow(u32 cid, u32 port) +{ + return true; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow); + +int virtio_transport_connect(struct vsock_sock *vsk) +{ + struct virtio_transport *trans = vsk->trans; + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_REQUEST, + .type = VIRTIO_VSOCK_TYPE_STREAM, + }; + + pr_debug("%s: vsk=%p send_request\n", __func__, vsk); + return trans->ops->send_pkt(vsk, &info); +} +EXPORT_SYMBOL_GPL(virtio_transport_connect); + +int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) +{ + struct virtio_transport *trans = vsk->trans; + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_SHUTDOWN, + .type = VIRTIO_VSOCK_TYPE_STREAM, + .flags = (mode & RCV_SHUTDOWN ? + VIRTIO_VSOCK_SHUTDOWN_RCV : 0) | + (mode & SEND_SHUTDOWN ? + VIRTIO_VSOCK_SHUTDOWN_SEND : 0), + }; + + pr_debug("%s: vsk=%p: send_shutdown\n", __func__, vsk); + return trans->ops->send_pkt(vsk, &info); +} +EXPORT_SYMBOL_GPL(virtio_transport_shutdown); + +void virtio_transport_release(struct vsock_sock *vsk) +{ + struct virtio_transport *trans = vsk->trans; + struct sock *sk = &vsk->sk; + struct dgram_skb *dgram_skb; + struct dgram_skb *dgram_skb_tmp; + + pr_debug("%s: vsk=%p\n", __func__, vsk); + + /* Tell other side to terminate connection */ + if (sk->sk_type == SOCK_STREAM && sk->sk_state == SS_CONNECTED) { + virtio_transport_shutdown(vsk, SHUTDOWN_MASK); + } + + /* Free incomplete dgrams */ + lock_sock(sk); + list_for_each_entry_safe(dgram_skb, dgram_skb_tmp, + &trans->incomplete_dgrams, list) { + list_del(&dgram_skb->list); + kfree_skb(dgram_skb->skb); + kfree(dgram_skb); + sock_put(sk); /* held in virtio_transport_recv_dgram() */ + } + release_sock(sk); +} +EXPORT_SYMBOL_GPL(virtio_transport_release); + +int +virtio_transport_dgram_enqueue(struct vsock_sock *vsk, + struct sockaddr_vm *remote_addr, + struct msghdr *msg, + size_t dgram_len) +{ + struct virtio_transport *trans = vsk->trans; + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RW, + .type = VIRTIO_VSOCK_TYPE_DGRAM, + .msg = msg, + }; + size_t total_written = 0, pkt_off = 0, written; + u16 dgram_id; + + /* The max size of a single dgram we support is 64KB */ + if (dgram_len > VIRTIO_VSOCK_MAX_DGRAM_SIZE) + return -EMSGSIZE; + + info.dgram_len = dgram_len; + vsk->remote_addr = *remote_addr; + + dgram_id = trans->dgram_id++; + + /* TODO: To optimize, if we have enough credit to send the pkt already, + * do not ask the peer to send credit to use */ + virtio_transport_send_credit_request(vsk, VIRTIO_VSOCK_TYPE_DGRAM); + + while (total_written < dgram_len) { + info.pkt_len = dgram_len - total_written; + info.flags = dgram_id << 16 | pkt_off; + written = trans->ops->send_pkt(vsk, &info); + if (written < 0) + return -ENOMEM; + if (written == 0) { + /* TODO: if written = 0, we need a sleep & wakeup + * instead of sleep */ + pr_debug("%s: SHOULD WAIT written==0", __func__); + msleep(10); + } + total_written += written; + pkt_off += written; + pr_debug("%s:id=%d, dgram_len=%zu, off=%zu, total_written=%zu, written=%zu\n", + __func__, dgram_id, dgram_len, pkt_off, total_written, written); + } + + return dgram_len; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue); + +ssize_t +virtio_transport_stream_enqueue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + struct virtio_transport *trans = vsk->trans; + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RW, + .type = VIRTIO_VSOCK_TYPE_STREAM, + .msg = msg, + .pkt_len = len, + }; + + return trans->ops->send_pkt(vsk, &info); +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue); + +void virtio_transport_destruct(struct vsock_sock *vsk) +{ + struct virtio_transport *trans = vsk->trans; + + pr_debug("%s: vsk=%p\n", __func__, vsk); + kfree(trans); +} +EXPORT_SYMBOL_GPL(virtio_transport_destruct); + +static int virtio_transport_send_ack(struct vsock_sock *vsk, u32 cookie) +{ + struct virtio_transport *trans = vsk->trans; + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_ACK, + .type = VIRTIO_VSOCK_TYPE_STREAM, + .flags = cpu_to_le32(cookie), + }; + + pr_debug("%s: sk=%p send_offer\n", __func__, vsk); + return trans->ops->send_pkt(vsk, &info); +} + +static int virtio_transport_send_reset(struct vsock_sock *vsk, + struct virtio_vsock_pkt *pkt) +{ + struct virtio_transport *trans = vsk->trans; + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RST, + .type = VIRTIO_VSOCK_TYPE_STREAM, + }; + + pr_debug("%s\n", __func__); + + /* Send RST only if the original pkt is not a RST pkt */ + if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) + return 0; + + return trans->ops->send_pkt(vsk, &info); +} + +static int +virtio_transport_recv_connecting(struct sock *sk, + struct virtio_vsock_pkt *pkt) +{ + struct vsock_sock *vsk = vsock_sk(sk); + int err; + int skerr; + u32 cookie; + + pr_debug("%s: vsk=%p\n", __func__, vsk); + switch (le16_to_cpu(pkt->hdr.op)) { + case VIRTIO_VSOCK_OP_RESPONSE: + cookie = le32_to_cpu(pkt->hdr.flags); + pr_debug("%s: got RESPONSE and send ACK, cookie=%x\n", __func__, cookie); + err = virtio_transport_send_ack(vsk, cookie); + if (err < 0) { + skerr = -err; + goto destroy; + } + sk->sk_state = SS_CONNECTED; + sk->sk_socket->state = SS_CONNECTED; + vsock_insert_connected(vsk); + sk->sk_state_change(sk); + break; + case VIRTIO_VSOCK_OP_INVALID: + pr_debug("%s: got invalid\n", __func__); + break; + case VIRTIO_VSOCK_OP_RST: + pr_debug("%s: got rst\n", __func__); + skerr = ECONNRESET; + err = 0; + goto destroy; + default: + pr_debug("%s: got def\n", __func__); + skerr = EPROTO; + err = -EINVAL; + goto destroy; + } + return 0; + +destroy: + virtio_transport_send_reset(vsk, pkt); + sk->sk_state = SS_UNCONNECTED; + sk->sk_err = skerr; + sk->sk_error_report(sk); + return err; +} + +static int +virtio_transport_recv_connected(struct sock *sk, + struct virtio_vsock_pkt *pkt) +{ + struct vsock_sock *vsk = vsock_sk(sk); + struct virtio_transport *trans = vsk->trans; + int err = 0; + + switch (le16_to_cpu(pkt->hdr.op)) { + case VIRTIO_VSOCK_OP_RW: + pkt->len = le32_to_cpu(pkt->hdr.len); + pkt->off = 0; + pkt->trans = trans; + + mutex_lock(&trans->rx_lock); + virtio_transport_inc_rx_pkt(pkt); + list_add_tail(&pkt->list, &trans->rx_queue); + mutex_unlock(&trans->rx_lock); + + sk->sk_data_ready(sk); + return err; + case VIRTIO_VSOCK_OP_CREDIT_UPDATE: + sk->sk_write_space(sk); + break; + case VIRTIO_VSOCK_OP_SHUTDOWN: + pr_debug("%s: got shutdown\n", __func__); + if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV) + vsk->peer_shutdown |= RCV_SHUTDOWN; + if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND) + vsk->peer_shutdown |= SEND_SHUTDOWN; + if (le32_to_cpu(pkt->hdr.flags)) + sk->sk_state_change(sk); + break; + case VIRTIO_VSOCK_OP_RST: + pr_debug("%s: got rst\n", __func__); + sock_set_flag(sk, SOCK_DONE); + vsk->peer_shutdown = SHUTDOWN_MASK; + if (vsock_stream_has_data(vsk) <= 0) + sk->sk_state = SS_DISCONNECTING; + sk->sk_state_change(sk); + break; + default: + err = -EINVAL; + break; + } + + virtio_transport_free_pkt(pkt); + return err; +} + +static int +virtio_transport_send_response(struct vsock_sock *vsk, + struct virtio_vsock_pkt *pkt) +{ + struct virtio_transport *trans = vsk->trans; + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RESPONSE, + .type = VIRTIO_VSOCK_TYPE_STREAM, + .remote_cid = le32_to_cpu(pkt->hdr.src_cid), + .remote_port = le32_to_cpu(pkt->hdr.src_port), + }; + u32 cookie; + + cookie = virtio_vsock_secure_cookie(le32_to_cpu(pkt->hdr.src_cid), + le32_to_cpu(pkt->hdr.dst_cid), + le32_to_cpu(pkt->hdr.src_port), + le32_to_cpu(pkt->hdr.dst_port), + jiffies / (HZ * 60)); + info.flags = cpu_to_le32(cookie); + + pr_debug("%s: send_response, cookie=%x\n", __func__, le32_to_cpu(cookie)); + + return trans->ops->send_pkt(vsk, &info); +} + +/* Handle server socket */ +static int +virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) +{ + struct vsock_sock *vsk = vsock_sk(sk); + struct vsock_sock *vpending; + struct sock *pending; + int err; + u32 cookie; + + switch (le16_to_cpu(pkt->hdr.op)) { + case VIRTIO_VSOCK_OP_REQUEST: + err = virtio_transport_send_response(vsk, pkt); + if (err < 0) { + // FIXME vsk should be vpending + virtio_transport_send_reset(vsk, pkt); + return err; + } + break; + case VIRTIO_VSOCK_OP_ACK: + cookie = le32_to_cpu(pkt->hdr.flags); + err = virtio_vsock_check_cookie(le32_to_cpu(pkt->hdr.src_cid), + le32_to_cpu(pkt->hdr.dst_cid), + le32_to_cpu(pkt->hdr.src_port), + le32_to_cpu(pkt->hdr.dst_port), + jiffies / (HZ * 60), + le32_to_cpu(pkt->hdr.flags), + VSOCK_TIMEOUT_INIT); + pr_debug("%s: cookie=%x, err=%d\n", __func__, cookie, err); + if (err) + return err; + + /* So no pending socket are responsible for this pkt, create one */ + pr_debug("%s: create pending\n", __func__); + pending = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, + sk->sk_type, 0); + if (!pending) { + virtio_transport_send_reset(vsk, pkt); + return -ENOMEM; + } + sk->sk_ack_backlog++; + pending->sk_state = SS_CONNECTING; + + vpending = vsock_sk(pending); + vsock_addr_init(&vpending->local_addr, le32_to_cpu(pkt->hdr.dst_cid), + le32_to_cpu(pkt->hdr.dst_port)); + vsock_addr_init(&vpending->remote_addr, le32_to_cpu(pkt->hdr.src_cid), + le32_to_cpu(pkt->hdr.src_port)); + vsock_add_pending(sk, pending); + + pr_debug("%s: get pending\n", __func__); + pending = virtio_transport_get_pending(sk, pkt); + vpending = vsock_sk(pending); + lock_sock(pending); + switch (pending->sk_state) { + case SS_CONNECTING: + if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_ACK) { + pr_debug("%s: op=%d != OP_ACK\n", __func__, + le16_to_cpu(pkt->hdr.op)); + virtio_transport_send_reset(vpending, pkt); + pending->sk_err = EPROTO; + pending->sk_state = SS_UNCONNECTED; + sock_put(pending); + } else { + pending->sk_state = SS_CONNECTED; + vsock_insert_connected(vpending); + + vsock_remove_pending(sk, pending); + vsock_enqueue_accept(sk, pending); + + sk->sk_data_ready(sk); + } + err = 0; + break; + default: + pr_debug("%s: sk->sk_ack_backlog=%d\n", __func__, + sk->sk_ack_backlog); + virtio_transport_send_reset(vpending, pkt); + err = -EINVAL; + break; + } + if (err < 0) + vsock_remove_pending(sk, pending); + release_sock(pending); + + /* Release refcnt obtained in virtio_transport_get_pending */ + sock_put(pending); + break; + default: + break; + } + + return 0; +} + +static void virtio_transport_space_update(struct sock *sk, + struct virtio_vsock_pkt *pkt) +{ + struct vsock_sock *vsk = vsock_sk(sk); + struct virtio_transport *trans = vsk->trans; + bool space_available; + + /* buf_alloc and fwd_cnt is always included in the hdr */ + mutex_lock(&trans->tx_lock); + trans->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); + trans->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); + space_available = virtio_transport_has_space(vsk); + mutex_unlock(&trans->tx_lock); + + if (space_available) + sk->sk_write_space(sk); +} + +/* We are under the virtio-vsock's vsock->rx_lock or + * vhost-vsock's vq->mutex lock */ +void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) +{ + struct virtio_transport *trans; + struct sockaddr_vm src, dst; + struct vsock_sock *vsk; + struct sock *sk; + + vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid), le32_to_cpu(pkt->hdr.src_port)); + vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid), le32_to_cpu(pkt->hdr.dst_port)); + + virtio_vsock_dumppkt(__func__, pkt); + + if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_DGRAM) { + sk = vsock_find_unbound_socket(&dst); + if (!sk) + goto free_pkt; + + vsk = vsock_sk(sk); + trans = vsk->trans; + BUG_ON(!trans); + + virtio_transport_space_update(sk, pkt); + + lock_sock(sk); + switch (le16_to_cpu(pkt->hdr.op)) { + case VIRTIO_VSOCK_OP_CREDIT_UPDATE: + virtio_transport_free_pkt(pkt); + break; + case VIRTIO_VSOCK_OP_CREDIT_REQUEST: + virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_DGRAM, + &pkt->hdr); + virtio_transport_free_pkt(pkt); + break; + case VIRTIO_VSOCK_OP_RW: + virtio_transport_recv_dgram(sk, pkt); + break; + default: + virtio_transport_free_pkt(pkt); + break; + } + release_sock(sk); + + /* Release refcnt obtained when we fetched this socket out of + * the unbound list. + */ + sock_put(sk); + return; + } else if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM) { + /* The socket must be in connected or bound table + * otherwise send reset back + */ + sk = vsock_find_connected_socket(&src, &dst); + if (!sk) { + sk = vsock_find_bound_socket(&dst); + if (!sk) { + pr_debug("%s: can not find bound_socket\n", __func__); + virtio_vsock_dumppkt(__func__, pkt); + /* Ignore this pkt instead of sending reset back */ + /* TODO send a RST unless this packet is a RST (to avoid infinite loops) */ + goto free_pkt; + } + } + + vsk = vsock_sk(sk); + trans = vsk->trans; + BUG_ON(!trans); + + virtio_transport_space_update(sk, pkt); + + lock_sock(sk); + switch (sk->sk_state) { + case VSOCK_SS_LISTEN: + virtio_transport_recv_listen(sk, pkt); + virtio_transport_free_pkt(pkt); + break; + case SS_CONNECTING: + virtio_transport_recv_connecting(sk, pkt); + virtio_transport_free_pkt(pkt); + break; + case SS_CONNECTED: + virtio_transport_recv_connected(sk, pkt); + break; + default: + virtio_transport_free_pkt(pkt); + break; + } + release_sock(sk); + + /* Release refcnt obtained when we fetched this socket out of the + * bound or connected list. + */ + sock_put(sk); + } + return; + +free_pkt: + virtio_transport_free_pkt(pkt); +} +EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt); + +void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt) +{ + kfree(pkt->buf); + kfree(pkt); +} +EXPORT_SYMBOL_GPL(virtio_transport_free_pkt); + +static int __init virtio_vsock_common_init(void) +{ + get_random_bytes(vsockcookie_secret, sizeof(vsockcookie_secret)); + return 0; +} + +static void __exit virtio_vsock_common_exit(void) +{ +} + +module_init(virtio_vsock_common_init); +module_exit(virtio_vsock_common_exit); +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Asias He"); +MODULE_DESCRIPTION("common code for virtio vsock"); |