summaryrefslogtreecommitdiff
path: root/drivers/vhost/scsi.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/scsi.c')
-rw-r--r--drivers/vhost/scsi.c103
1 files changed, 83 insertions, 20 deletions
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index 7aeff435c1d8..38d243d914d0 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -571,6 +571,9 @@ static void vhost_scsi_complete_cmd_work(struct vhost_work *work)
int ret;
llnode = llist_del_all(&svq->completion_list);
+
+ mutex_lock(&svq->vq.mutex);
+
llist_for_each_entry_safe(cmd, t, llnode, tvc_completion_list) {
se_cmd = &cmd->tvc_se_cmd;
@@ -604,6 +607,8 @@ static void vhost_scsi_complete_cmd_work(struct vhost_work *work)
vhost_scsi_release_cmd_res(se_cmd);
}
+ mutex_unlock(&svq->vq.mutex);
+
if (signal)
vhost_signal(&svq->vs->dev, &svq->vq);
}
@@ -630,7 +635,7 @@ vhost_scsi_get_cmd(struct vhost_virtqueue *vq, struct vhost_scsi_tpg *tpg,
tag = sbitmap_get(&svq->scsi_tags);
if (tag < 0) {
- pr_err("Unable to obtain tag for vhost_scsi_cmd\n");
+ pr_warn_once("Guest sent too many cmds. Returning TASK_SET_FULL.\n");
return ERR_PTR(-ENOMEM);
}
@@ -757,7 +762,7 @@ vhost_scsi_copy_iov_to_sgl(struct vhost_scsi_cmd *cmd, struct iov_iter *iter,
size_t len = iov_iter_count(iter);
unsigned int nbytes = 0;
struct page *page;
- int i;
+ int i, ret;
if (cmd->tvc_data_direction == DMA_FROM_DEVICE) {
cmd->saved_iter_addr = dup_iter(&cmd->saved_iter, iter,
@@ -770,6 +775,7 @@ vhost_scsi_copy_iov_to_sgl(struct vhost_scsi_cmd *cmd, struct iov_iter *iter,
page = alloc_page(GFP_KERNEL);
if (!page) {
i--;
+ ret = -ENOMEM;
goto err;
}
@@ -777,8 +783,10 @@ vhost_scsi_copy_iov_to_sgl(struct vhost_scsi_cmd *cmd, struct iov_iter *iter,
sg_set_page(&sg[i], page, nbytes, 0);
if (cmd->tvc_data_direction == DMA_TO_DEVICE &&
- copy_page_from_iter(page, 0, nbytes, iter) != nbytes)
+ copy_page_from_iter(page, 0, nbytes, iter) != nbytes) {
+ ret = -EFAULT;
goto err;
+ }
len -= nbytes;
}
@@ -793,7 +801,7 @@ err:
for (; i >= 0; i--)
__free_page(sg_page(&sg[i]));
kfree(cmd->saved_iter_addr);
- return -ENOMEM;
+ return ret;
}
static int
@@ -930,24 +938,69 @@ static void vhost_scsi_target_queue_cmd(struct vhost_scsi_cmd *cmd)
}
static void
-vhost_scsi_send_bad_target(struct vhost_scsi *vs,
- struct vhost_virtqueue *vq,
- int head, unsigned out)
+vhost_scsi_send_status(struct vhost_scsi *vs, struct vhost_virtqueue *vq,
+ struct vhost_scsi_ctx *vc, u8 status)
{
- struct virtio_scsi_cmd_resp __user *resp;
struct virtio_scsi_cmd_resp rsp;
+ struct iov_iter iov_iter;
int ret;
memset(&rsp, 0, sizeof(rsp));
- rsp.response = VIRTIO_SCSI_S_BAD_TARGET;
- resp = vq->iov[out].iov_base;
- ret = __copy_to_user(resp, &rsp, sizeof(rsp));
- if (!ret)
- vhost_add_used_and_signal(&vs->dev, vq, head, 0);
+ rsp.status = status;
+
+ iov_iter_init(&iov_iter, ITER_DEST, &vq->iov[vc->out], vc->in,
+ sizeof(rsp));
+
+ ret = copy_to_iter(&rsp, sizeof(rsp), &iov_iter);
+
+ if (likely(ret == sizeof(rsp)))
+ vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0);
else
pr_err("Faulted on virtio_scsi_cmd_resp\n");
}
+#define TYPE_IO_CMD 0
+#define TYPE_CTRL_TMF 1
+#define TYPE_CTRL_AN 2
+
+static void
+vhost_scsi_send_bad_target(struct vhost_scsi *vs,
+ struct vhost_virtqueue *vq,
+ struct vhost_scsi_ctx *vc, int type)
+{
+ union {
+ struct virtio_scsi_cmd_resp cmd;
+ struct virtio_scsi_ctrl_tmf_resp tmf;
+ struct virtio_scsi_ctrl_an_resp an;
+ } rsp;
+ struct iov_iter iov_iter;
+ size_t rsp_size;
+ int ret;
+
+ memset(&rsp, 0, sizeof(rsp));
+
+ if (type == TYPE_IO_CMD) {
+ rsp_size = sizeof(struct virtio_scsi_cmd_resp);
+ rsp.cmd.response = VIRTIO_SCSI_S_BAD_TARGET;
+ } else if (type == TYPE_CTRL_TMF) {
+ rsp_size = sizeof(struct virtio_scsi_ctrl_tmf_resp);
+ rsp.tmf.response = VIRTIO_SCSI_S_BAD_TARGET;
+ } else {
+ rsp_size = sizeof(struct virtio_scsi_ctrl_an_resp);
+ rsp.an.response = VIRTIO_SCSI_S_BAD_TARGET;
+ }
+
+ iov_iter_init(&iov_iter, ITER_DEST, &vq->iov[vc->out], vc->in,
+ rsp_size);
+
+ ret = copy_to_iter(&rsp, rsp_size, &iov_iter);
+
+ if (likely(ret == rsp_size))
+ vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0);
+ else
+ pr_err("Faulted on virtio scsi type=%d\n", type);
+}
+
static int
vhost_scsi_get_desc(struct vhost_scsi *vs, struct vhost_virtqueue *vq,
struct vhost_scsi_ctx *vc)
@@ -1216,8 +1269,8 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
exp_data_len + prot_bytes,
data_direction);
if (IS_ERR(cmd)) {
- vq_err(vq, "vhost_scsi_get_cmd failed %ld\n",
- PTR_ERR(cmd));
+ ret = PTR_ERR(cmd);
+ vq_err(vq, "vhost_scsi_get_tag failed %dd\n", ret);
goto err;
}
cmd->tvc_vhost = vs;
@@ -1232,9 +1285,9 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
" %d\n", cmd, exp_data_len, prot_bytes, data_direction);
if (data_direction != DMA_NONE) {
- if (unlikely(vhost_scsi_mapal(cmd, prot_bytes,
- &prot_iter, exp_data_len,
- &data_iter))) {
+ ret = vhost_scsi_mapal(cmd, prot_bytes, &prot_iter,
+ exp_data_len, &data_iter);
+ if (unlikely(ret)) {
vq_err(vq, "Failed to map iov to sgl\n");
vhost_scsi_release_cmd_res(&cmd->tvc_se_cmd);
goto err;
@@ -1254,11 +1307,15 @@ err:
* EINVAL: Invalid response buffer, drop the request
* EIO: Respond with bad target
* EAGAIN: Pending request
+ * ENOMEM: Could not allocate resources for request
*/
if (ret == -ENXIO)
break;
else if (ret == -EIO)
- vhost_scsi_send_bad_target(vs, vq, vc.head, vc.out);
+ vhost_scsi_send_bad_target(vs, vq, &vc, TYPE_IO_CMD);
+ else if (ret == -ENOMEM)
+ vhost_scsi_send_status(vs, vq, &vc,
+ SAM_STAT_TASK_SET_FULL);
} while (likely(!vhost_exceeds_weight(vq, ++c, 0)));
out:
mutex_unlock(&vq->mutex);
@@ -1297,8 +1354,11 @@ static void vhost_scsi_tmf_resp_work(struct vhost_work *work)
else
resp_code = VIRTIO_SCSI_S_FUNCTION_REJECTED;
+ mutex_lock(&tmf->svq->vq.mutex);
vhost_scsi_send_tmf_resp(tmf->vhost, &tmf->svq->vq, tmf->in_iovs,
tmf->vq_desc, &tmf->resp_iov, resp_code);
+ mutex_unlock(&tmf->svq->vq.mutex);
+
vhost_scsi_release_tmf_res(tmf);
}
@@ -1488,7 +1548,10 @@ err:
if (ret == -ENXIO)
break;
else if (ret == -EIO)
- vhost_scsi_send_bad_target(vs, vq, vc.head, vc.out);
+ vhost_scsi_send_bad_target(vs, vq, &vc,
+ v_req.type == VIRTIO_SCSI_T_TMF ?
+ TYPE_CTRL_TMF :
+ TYPE_CTRL_AN);
} while (likely(!vhost_exceeds_weight(vq, ++c, 0)));
out:
mutex_unlock(&vq->mutex);