diff options
Diffstat (limited to 'drivers')
24 files changed, 1522 insertions, 253 deletions
diff --git a/drivers/crypto/virtio/Kconfig b/drivers/crypto/virtio/Kconfig index b894e3a8be4f..5f8915f4a9ff 100644 --- a/drivers/crypto/virtio/Kconfig +++ b/drivers/crypto/virtio/Kconfig @@ -3,8 +3,11 @@ config CRYPTO_DEV_VIRTIO tristate "VirtIO crypto driver" depends on VIRTIO select CRYPTO_AEAD + select CRYPTO_AKCIPHER2 select CRYPTO_SKCIPHER select CRYPTO_ENGINE + select CRYPTO_RSA + select MPILIB help This driver provides support for virtio crypto device. If you choose 'M' here, this module will be called virtio_crypto. diff --git a/drivers/crypto/virtio/Makefile b/drivers/crypto/virtio/Makefile index cbfccccfa135..bfa6cbae342e 100644 --- a/drivers/crypto/virtio/Makefile +++ b/drivers/crypto/virtio/Makefile @@ -1,6 +1,7 @@ # SPDX-License-Identifier: GPL-2.0 obj-$(CONFIG_CRYPTO_DEV_VIRTIO) += virtio_crypto.o virtio_crypto-objs := \ - virtio_crypto_algs.o \ + virtio_crypto_skcipher_algs.o \ + virtio_crypto_akcipher_algs.o \ virtio_crypto_mgr.o \ virtio_crypto_core.o diff --git a/drivers/crypto/virtio/virtio_crypto_akcipher_algs.c b/drivers/crypto/virtio/virtio_crypto_akcipher_algs.c new file mode 100644 index 000000000000..f3ec9420215e --- /dev/null +++ b/drivers/crypto/virtio/virtio_crypto_akcipher_algs.c @@ -0,0 +1,585 @@ +// SPDX-License-Identifier: GPL-2.0-or-later + /* Asymmetric algorithms supported by virtio crypto device + * + * Authors: zhenwei pi <pizhenwei@bytedance.com> + * lei he <helei.sig11@bytedance.com> + * + * Copyright 2022 Bytedance CO., LTD. + */ + +#include <linux/mpi.h> +#include <linux/scatterlist.h> +#include <crypto/algapi.h> +#include <crypto/internal/akcipher.h> +#include <crypto/internal/rsa.h> +#include <linux/err.h> +#include <crypto/scatterwalk.h> +#include <linux/atomic.h> + +#include <uapi/linux/virtio_crypto.h> +#include "virtio_crypto_common.h" + +struct virtio_crypto_rsa_ctx { + MPI n; +}; + +struct virtio_crypto_akcipher_ctx { + struct crypto_engine_ctx enginectx; + struct virtio_crypto *vcrypto; + struct crypto_akcipher *tfm; + bool session_valid; + __u64 session_id; + union { + struct virtio_crypto_rsa_ctx rsa_ctx; + }; +}; + +struct virtio_crypto_akcipher_request { + struct virtio_crypto_request base; + struct virtio_crypto_akcipher_ctx *akcipher_ctx; + struct akcipher_request *akcipher_req; + void *src_buf; + void *dst_buf; + uint32_t opcode; +}; + +struct virtio_crypto_akcipher_algo { + uint32_t algonum; + uint32_t service; + unsigned int active_devs; + struct akcipher_alg algo; +}; + +static DEFINE_MUTEX(algs_lock); + +static void virtio_crypto_akcipher_finalize_req( + struct virtio_crypto_akcipher_request *vc_akcipher_req, + struct akcipher_request *req, int err) +{ + virtcrypto_clear_request(&vc_akcipher_req->base); + + crypto_finalize_akcipher_request(vc_akcipher_req->base.dataq->engine, req, err); +} + +static void virtio_crypto_dataq_akcipher_callback(struct virtio_crypto_request *vc_req, int len) +{ + struct virtio_crypto_akcipher_request *vc_akcipher_req = + container_of(vc_req, struct virtio_crypto_akcipher_request, base); + struct akcipher_request *akcipher_req; + int error; + + switch (vc_req->status) { + case VIRTIO_CRYPTO_OK: + error = 0; + break; + case VIRTIO_CRYPTO_INVSESS: + case VIRTIO_CRYPTO_ERR: + error = -EINVAL; + break; + case VIRTIO_CRYPTO_BADMSG: + error = -EBADMSG; + break; + + case VIRTIO_CRYPTO_KEY_REJECTED: + error = -EKEYREJECTED; + break; + + default: + error = -EIO; + break; + } + + akcipher_req = vc_akcipher_req->akcipher_req; + if (vc_akcipher_req->opcode != VIRTIO_CRYPTO_AKCIPHER_VERIFY) + sg_copy_from_buffer(akcipher_req->dst, sg_nents(akcipher_req->dst), + vc_akcipher_req->dst_buf, akcipher_req->dst_len); + virtio_crypto_akcipher_finalize_req(vc_akcipher_req, akcipher_req, error); +} + +static int virtio_crypto_alg_akcipher_init_session(struct virtio_crypto_akcipher_ctx *ctx, + struct virtio_crypto_ctrl_header *header, void *para, + const uint8_t *key, unsigned int keylen) +{ + struct scatterlist outhdr_sg, key_sg, inhdr_sg, *sgs[3]; + struct virtio_crypto *vcrypto = ctx->vcrypto; + uint8_t *pkey; + unsigned int inlen; + int err; + unsigned int num_out = 0, num_in = 0; + + pkey = kmemdup(key, keylen, GFP_ATOMIC); + if (!pkey) + return -ENOMEM; + + spin_lock(&vcrypto->ctrl_lock); + memcpy(&vcrypto->ctrl.header, header, sizeof(vcrypto->ctrl.header)); + memcpy(&vcrypto->ctrl.u, para, sizeof(vcrypto->ctrl.u)); + vcrypto->input.status = cpu_to_le32(VIRTIO_CRYPTO_ERR); + + sg_init_one(&outhdr_sg, &vcrypto->ctrl, sizeof(vcrypto->ctrl)); + sgs[num_out++] = &outhdr_sg; + + sg_init_one(&key_sg, pkey, keylen); + sgs[num_out++] = &key_sg; + + sg_init_one(&inhdr_sg, &vcrypto->input, sizeof(vcrypto->input)); + sgs[num_out + num_in++] = &inhdr_sg; + + err = virtqueue_add_sgs(vcrypto->ctrl_vq, sgs, num_out, num_in, vcrypto, GFP_ATOMIC); + if (err < 0) + goto out; + + virtqueue_kick(vcrypto->ctrl_vq); + while (!virtqueue_get_buf(vcrypto->ctrl_vq, &inlen) && + !virtqueue_is_broken(vcrypto->ctrl_vq)) + cpu_relax(); + + if (le32_to_cpu(vcrypto->input.status) != VIRTIO_CRYPTO_OK) { + err = -EINVAL; + goto out; + } + + ctx->session_id = le64_to_cpu(vcrypto->input.session_id); + ctx->session_valid = true; + err = 0; + +out: + spin_unlock(&vcrypto->ctrl_lock); + kfree_sensitive(pkey); + + if (err < 0) + pr_err("virtio_crypto: Create session failed status: %u\n", + le32_to_cpu(vcrypto->input.status)); + + return err; +} + +static int virtio_crypto_alg_akcipher_close_session(struct virtio_crypto_akcipher_ctx *ctx) +{ + struct scatterlist outhdr_sg, inhdr_sg, *sgs[2]; + struct virtio_crypto_destroy_session_req *destroy_session; + struct virtio_crypto *vcrypto = ctx->vcrypto; + unsigned int num_out = 0, num_in = 0, inlen; + int err; + + spin_lock(&vcrypto->ctrl_lock); + if (!ctx->session_valid) { + err = 0; + goto out; + } + vcrypto->ctrl_status.status = VIRTIO_CRYPTO_ERR; + vcrypto->ctrl.header.opcode = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_DESTROY_SESSION); + vcrypto->ctrl.header.queue_id = 0; + + destroy_session = &vcrypto->ctrl.u.destroy_session; + destroy_session->session_id = cpu_to_le64(ctx->session_id); + + sg_init_one(&outhdr_sg, &vcrypto->ctrl, sizeof(vcrypto->ctrl)); + sgs[num_out++] = &outhdr_sg; + + sg_init_one(&inhdr_sg, &vcrypto->ctrl_status.status, sizeof(vcrypto->ctrl_status.status)); + sgs[num_out + num_in++] = &inhdr_sg; + + err = virtqueue_add_sgs(vcrypto->ctrl_vq, sgs, num_out, num_in, vcrypto, GFP_ATOMIC); + if (err < 0) + goto out; + + virtqueue_kick(vcrypto->ctrl_vq); + while (!virtqueue_get_buf(vcrypto->ctrl_vq, &inlen) && + !virtqueue_is_broken(vcrypto->ctrl_vq)) + cpu_relax(); + + if (vcrypto->ctrl_status.status != VIRTIO_CRYPTO_OK) { + err = -EINVAL; + goto out; + } + + err = 0; + ctx->session_valid = false; + +out: + spin_unlock(&vcrypto->ctrl_lock); + if (err < 0) { + pr_err("virtio_crypto: Close session failed status: %u, session_id: 0x%llx\n", + vcrypto->ctrl_status.status, destroy_session->session_id); + } + + return err; +} + +static int __virtio_crypto_akcipher_do_req(struct virtio_crypto_akcipher_request *vc_akcipher_req, + struct akcipher_request *req, struct data_queue *data_vq) +{ + struct virtio_crypto_akcipher_ctx *ctx = vc_akcipher_req->akcipher_ctx; + struct virtio_crypto_request *vc_req = &vc_akcipher_req->base; + struct virtio_crypto *vcrypto = ctx->vcrypto; + struct virtio_crypto_op_data_req *req_data = vc_req->req_data; + struct scatterlist *sgs[4], outhdr_sg, inhdr_sg, srcdata_sg, dstdata_sg; + void *src_buf = NULL, *dst_buf = NULL; + unsigned int num_out = 0, num_in = 0; + int node = dev_to_node(&vcrypto->vdev->dev); + unsigned long flags; + int ret = -ENOMEM; + bool verify = vc_akcipher_req->opcode == VIRTIO_CRYPTO_AKCIPHER_VERIFY; + unsigned int src_len = verify ? req->src_len + req->dst_len : req->src_len; + + /* out header */ + sg_init_one(&outhdr_sg, req_data, sizeof(*req_data)); + sgs[num_out++] = &outhdr_sg; + + /* src data */ + src_buf = kcalloc_node(src_len, 1, GFP_KERNEL, node); + if (!src_buf) + goto err; + + if (verify) { + /* for verify operation, both src and dst data work as OUT direction */ + sg_copy_to_buffer(req->src, sg_nents(req->src), src_buf, src_len); + sg_init_one(&srcdata_sg, src_buf, src_len); + sgs[num_out++] = &srcdata_sg; + } else { + sg_copy_to_buffer(req->src, sg_nents(req->src), src_buf, src_len); + sg_init_one(&srcdata_sg, src_buf, src_len); + sgs[num_out++] = &srcdata_sg; + + /* dst data */ + dst_buf = kcalloc_node(req->dst_len, 1, GFP_KERNEL, node); + if (!dst_buf) + goto err; + + sg_init_one(&dstdata_sg, dst_buf, req->dst_len); + sgs[num_out + num_in++] = &dstdata_sg; + } + + vc_akcipher_req->src_buf = src_buf; + vc_akcipher_req->dst_buf = dst_buf; + + /* in header */ + sg_init_one(&inhdr_sg, &vc_req->status, sizeof(vc_req->status)); + sgs[num_out + num_in++] = &inhdr_sg; + + spin_lock_irqsave(&data_vq->lock, flags); + ret = virtqueue_add_sgs(data_vq->vq, sgs, num_out, num_in, vc_req, GFP_ATOMIC); + virtqueue_kick(data_vq->vq); + spin_unlock_irqrestore(&data_vq->lock, flags); + if (ret) + goto err; + + return 0; + +err: + kfree(src_buf); + kfree(dst_buf); + + return -ENOMEM; +} + +static int virtio_crypto_rsa_do_req(struct crypto_engine *engine, void *vreq) +{ + struct akcipher_request *req = container_of(vreq, struct akcipher_request, base); + struct virtio_crypto_akcipher_request *vc_akcipher_req = akcipher_request_ctx(req); + struct virtio_crypto_request *vc_req = &vc_akcipher_req->base; + struct virtio_crypto_akcipher_ctx *ctx = vc_akcipher_req->akcipher_ctx; + struct virtio_crypto *vcrypto = ctx->vcrypto; + struct data_queue *data_vq = vc_req->dataq; + struct virtio_crypto_op_header *header; + struct virtio_crypto_akcipher_data_req *akcipher_req; + int ret; + + vc_req->sgs = NULL; + vc_req->req_data = kzalloc_node(sizeof(*vc_req->req_data), + GFP_KERNEL, dev_to_node(&vcrypto->vdev->dev)); + if (!vc_req->req_data) + return -ENOMEM; + + /* build request header */ + header = &vc_req->req_data->header; + header->opcode = cpu_to_le32(vc_akcipher_req->opcode); + header->algo = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_RSA); + header->session_id = cpu_to_le64(ctx->session_id); + + /* build request akcipher data */ + akcipher_req = &vc_req->req_data->u.akcipher_req; + akcipher_req->para.src_data_len = cpu_to_le32(req->src_len); + akcipher_req->para.dst_data_len = cpu_to_le32(req->dst_len); + + ret = __virtio_crypto_akcipher_do_req(vc_akcipher_req, req, data_vq); + if (ret < 0) { + kfree_sensitive(vc_req->req_data); + vc_req->req_data = NULL; + return ret; + } + + return 0; +} + +static int virtio_crypto_rsa_req(struct akcipher_request *req, uint32_t opcode) +{ + struct crypto_akcipher *atfm = crypto_akcipher_reqtfm(req); + struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(atfm); + struct virtio_crypto_akcipher_request *vc_akcipher_req = akcipher_request_ctx(req); + struct virtio_crypto_request *vc_req = &vc_akcipher_req->base; + struct virtio_crypto *vcrypto = ctx->vcrypto; + /* Use the first data virtqueue as default */ + struct data_queue *data_vq = &vcrypto->data_vq[0]; + + vc_req->dataq = data_vq; + vc_req->alg_cb = virtio_crypto_dataq_akcipher_callback; + vc_akcipher_req->akcipher_ctx = ctx; + vc_akcipher_req->akcipher_req = req; + vc_akcipher_req->opcode = opcode; + + return crypto_transfer_akcipher_request_to_engine(data_vq->engine, req); +} + +static int virtio_crypto_rsa_encrypt(struct akcipher_request *req) +{ + return virtio_crypto_rsa_req(req, VIRTIO_CRYPTO_AKCIPHER_ENCRYPT); +} + +static int virtio_crypto_rsa_decrypt(struct akcipher_request *req) +{ + return virtio_crypto_rsa_req(req, VIRTIO_CRYPTO_AKCIPHER_DECRYPT); +} + +static int virtio_crypto_rsa_sign(struct akcipher_request *req) +{ + return virtio_crypto_rsa_req(req, VIRTIO_CRYPTO_AKCIPHER_SIGN); +} + +static int virtio_crypto_rsa_verify(struct akcipher_request *req) +{ + return virtio_crypto_rsa_req(req, VIRTIO_CRYPTO_AKCIPHER_VERIFY); +} + +static int virtio_crypto_rsa_set_key(struct crypto_akcipher *tfm, + const void *key, + unsigned int keylen, + bool private, + int padding_algo, + int hash_algo) +{ + struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(tfm); + struct virtio_crypto_rsa_ctx *rsa_ctx = &ctx->rsa_ctx; + struct virtio_crypto *vcrypto; + struct virtio_crypto_ctrl_header header; + struct virtio_crypto_akcipher_session_para para; + struct rsa_key rsa_key = {0}; + int node = virtio_crypto_get_current_node(); + uint32_t keytype; + int ret; + + /* mpi_free will test n, just free it. */ + mpi_free(rsa_ctx->n); + rsa_ctx->n = NULL; + + if (private) { + keytype = VIRTIO_CRYPTO_AKCIPHER_KEY_TYPE_PRIVATE; + ret = rsa_parse_priv_key(&rsa_key, key, keylen); + } else { + keytype = VIRTIO_CRYPTO_AKCIPHER_KEY_TYPE_PUBLIC; + ret = rsa_parse_pub_key(&rsa_key, key, keylen); + } + + if (ret) + return ret; + + rsa_ctx->n = mpi_read_raw_data(rsa_key.n, rsa_key.n_sz); + if (!rsa_ctx->n) + return -ENOMEM; + + if (!ctx->vcrypto) { + vcrypto = virtcrypto_get_dev_node(node, VIRTIO_CRYPTO_SERVICE_AKCIPHER, + VIRTIO_CRYPTO_AKCIPHER_RSA); + if (!vcrypto) { + pr_err("virtio_crypto: Could not find a virtio device in the system or unsupported algo\n"); + return -ENODEV; + } + + ctx->vcrypto = vcrypto; + } else { + virtio_crypto_alg_akcipher_close_session(ctx); + } + + /* set ctrl header */ + header.opcode = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_CREATE_SESSION); + header.algo = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_RSA); + header.queue_id = 0; + + /* set RSA para */ + para.algo = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_RSA); + para.keytype = cpu_to_le32(keytype); + para.keylen = cpu_to_le32(keylen); + para.u.rsa.padding_algo = cpu_to_le32(padding_algo); + para.u.rsa.hash_algo = cpu_to_le32(hash_algo); + + return virtio_crypto_alg_akcipher_init_session(ctx, &header, ¶, key, keylen); +} + +static int virtio_crypto_rsa_raw_set_priv_key(struct crypto_akcipher *tfm, + const void *key, + unsigned int keylen) +{ + return virtio_crypto_rsa_set_key(tfm, key, keylen, 1, + VIRTIO_CRYPTO_RSA_RAW_PADDING, + VIRTIO_CRYPTO_RSA_NO_HASH); +} + + +static int virtio_crypto_p1pad_rsa_sha1_set_priv_key(struct crypto_akcipher *tfm, + const void *key, + unsigned int keylen) +{ + return virtio_crypto_rsa_set_key(tfm, key, keylen, 1, + VIRTIO_CRYPTO_RSA_PKCS1_PADDING, + VIRTIO_CRYPTO_RSA_SHA1); +} + +static int virtio_crypto_rsa_raw_set_pub_key(struct crypto_akcipher *tfm, + const void *key, + unsigned int keylen) +{ + return virtio_crypto_rsa_set_key(tfm, key, keylen, 0, + VIRTIO_CRYPTO_RSA_RAW_PADDING, + VIRTIO_CRYPTO_RSA_NO_HASH); +} + +static int virtio_crypto_p1pad_rsa_sha1_set_pub_key(struct crypto_akcipher *tfm, + const void *key, + unsigned int keylen) +{ + return virtio_crypto_rsa_set_key(tfm, key, keylen, 0, + VIRTIO_CRYPTO_RSA_PKCS1_PADDING, + VIRTIO_CRYPTO_RSA_SHA1); +} + +static unsigned int virtio_crypto_rsa_max_size(struct crypto_akcipher *tfm) +{ + struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(tfm); + struct virtio_crypto_rsa_ctx *rsa_ctx = &ctx->rsa_ctx; + + return mpi_get_size(rsa_ctx->n); +} + +static int virtio_crypto_rsa_init_tfm(struct crypto_akcipher *tfm) +{ + struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(tfm); + + ctx->tfm = tfm; + ctx->enginectx.op.do_one_request = virtio_crypto_rsa_do_req; + ctx->enginectx.op.prepare_request = NULL; + ctx->enginectx.op.unprepare_request = NULL; + + return 0; +} + +static void virtio_crypto_rsa_exit_tfm(struct crypto_akcipher *tfm) +{ + struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(tfm); + struct virtio_crypto_rsa_ctx *rsa_ctx = &ctx->rsa_ctx; + + virtio_crypto_alg_akcipher_close_session(ctx); + virtcrypto_dev_put(ctx->vcrypto); + mpi_free(rsa_ctx->n); + rsa_ctx->n = NULL; +} + +static struct virtio_crypto_akcipher_algo virtio_crypto_akcipher_algs[] = { + { + .algonum = VIRTIO_CRYPTO_AKCIPHER_RSA, + .service = VIRTIO_CRYPTO_SERVICE_AKCIPHER, + .algo = { + .encrypt = virtio_crypto_rsa_encrypt, + .decrypt = virtio_crypto_rsa_decrypt, + .set_pub_key = virtio_crypto_rsa_raw_set_pub_key, + .set_priv_key = virtio_crypto_rsa_raw_set_priv_key, + .max_size = virtio_crypto_rsa_max_size, + .init = virtio_crypto_rsa_init_tfm, + .exit = virtio_crypto_rsa_exit_tfm, + .reqsize = sizeof(struct virtio_crypto_akcipher_request), + .base = { + .cra_name = "rsa", + .cra_driver_name = "virtio-crypto-rsa", + .cra_priority = 150, + .cra_module = THIS_MODULE, + .cra_ctxsize = sizeof(struct virtio_crypto_akcipher_ctx), + }, + }, + }, + { + .algonum = VIRTIO_CRYPTO_AKCIPHER_RSA, + .service = VIRTIO_CRYPTO_SERVICE_AKCIPHER, + .algo = { + .encrypt = virtio_crypto_rsa_encrypt, + .decrypt = virtio_crypto_rsa_decrypt, + .sign = virtio_crypto_rsa_sign, + .verify = virtio_crypto_rsa_verify, + .set_pub_key = virtio_crypto_p1pad_rsa_sha1_set_pub_key, + .set_priv_key = virtio_crypto_p1pad_rsa_sha1_set_priv_key, + .max_size = virtio_crypto_rsa_max_size, + .init = virtio_crypto_rsa_init_tfm, + .exit = virtio_crypto_rsa_exit_tfm, + .reqsize = sizeof(struct virtio_crypto_akcipher_request), + .base = { + .cra_name = "pkcs1pad(rsa,sha1)", + .cra_driver_name = "virtio-pkcs1-rsa-with-sha1", + .cra_priority = 150, + .cra_module = THIS_MODULE, + .cra_ctxsize = sizeof(struct virtio_crypto_akcipher_ctx), + }, + }, + }, +}; + +int virtio_crypto_akcipher_algs_register(struct virtio_crypto *vcrypto) +{ + int ret = 0; + int i = 0; + + mutex_lock(&algs_lock); + + for (i = 0; i < ARRAY_SIZE(virtio_crypto_akcipher_algs); i++) { + uint32_t service = virtio_crypto_akcipher_algs[i].service; + uint32_t algonum = virtio_crypto_akcipher_algs[i].algonum; + + if (!virtcrypto_algo_is_supported(vcrypto, service, algonum)) + continue; + + if (virtio_crypto_akcipher_algs[i].active_devs == 0) { + ret = crypto_register_akcipher(&virtio_crypto_akcipher_algs[i].algo); + if (ret) + goto unlock; + } + + virtio_crypto_akcipher_algs[i].active_devs++; + dev_info(&vcrypto->vdev->dev, "Registered akcipher algo %s\n", + virtio_crypto_akcipher_algs[i].algo.base.cra_name); + } + +unlock: + mutex_unlock(&algs_lock); + return ret; +} + +void virtio_crypto_akcipher_algs_unregister(struct virtio_crypto *vcrypto) +{ + int i = 0; + + mutex_lock(&algs_lock); + + for (i = 0; i < ARRAY_SIZE(virtio_crypto_akcipher_algs); i++) { + uint32_t service = virtio_crypto_akcipher_algs[i].service; + uint32_t algonum = virtio_crypto_akcipher_algs[i].algonum; + + if (virtio_crypto_akcipher_algs[i].active_devs == 0 || + !virtcrypto_algo_is_supported(vcrypto, service, algonum)) + continue; + + if (virtio_crypto_akcipher_algs[i].active_devs == 1) + crypto_unregister_akcipher(&virtio_crypto_akcipher_algs[i].algo); + + virtio_crypto_akcipher_algs[i].active_devs--; + } + + mutex_unlock(&algs_lock); +} diff --git a/drivers/crypto/virtio/virtio_crypto_common.h b/drivers/crypto/virtio/virtio_crypto_common.h index a24f85c589e7..e693d4ee83a6 100644 --- a/drivers/crypto/virtio/virtio_crypto_common.h +++ b/drivers/crypto/virtio/virtio_crypto_common.h @@ -56,6 +56,7 @@ struct virtio_crypto { u32 mac_algo_l; u32 mac_algo_h; u32 aead_algo; + u32 akcipher_algo; /* Maximum length of cipher key */ u32 max_cipher_key_len; @@ -129,7 +130,9 @@ static inline int virtio_crypto_get_current_node(void) return node; } -int virtio_crypto_algs_register(struct virtio_crypto *vcrypto); -void virtio_crypto_algs_unregister(struct virtio_crypto *vcrypto); +int virtio_crypto_skcipher_algs_register(struct virtio_crypto *vcrypto); +void virtio_crypto_skcipher_algs_unregister(struct virtio_crypto *vcrypto); +int virtio_crypto_akcipher_algs_register(struct virtio_crypto *vcrypto); +void virtio_crypto_akcipher_algs_unregister(struct virtio_crypto *vcrypto); #endif /* _VIRTIO_CRYPTO_COMMON_H */ diff --git a/drivers/crypto/virtio/virtio_crypto_core.c b/drivers/crypto/virtio/virtio_crypto_core.c index 8e977b7627cb..c6f482db0bc0 100644 --- a/drivers/crypto/virtio/virtio_crypto_core.c +++ b/drivers/crypto/virtio/virtio_crypto_core.c @@ -297,6 +297,7 @@ static int virtcrypto_probe(struct virtio_device *vdev) u32 mac_algo_l = 0; u32 mac_algo_h = 0; u32 aead_algo = 0; + u32 akcipher_algo = 0; u32 crypto_services = 0; if (!virtio_has_feature(vdev, VIRTIO_F_VERSION_1)) @@ -348,6 +349,9 @@ static int virtcrypto_probe(struct virtio_device *vdev) mac_algo_h, &mac_algo_h); virtio_cread_le(vdev, struct virtio_crypto_config, aead_algo, &aead_algo); + if (crypto_services & (1 << VIRTIO_CRYPTO_SERVICE_AKCIPHER)) + virtio_cread_le(vdev, struct virtio_crypto_config, + akcipher_algo, &akcipher_algo); /* Add virtio crypto device to global table */ err = virtcrypto_devmgr_add_dev(vcrypto); @@ -374,7 +378,7 @@ static int virtcrypto_probe(struct virtio_device *vdev) vcrypto->mac_algo_h = mac_algo_h; vcrypto->hash_algo = hash_algo; vcrypto->aead_algo = aead_algo; - + vcrypto->akcipher_algo = akcipher_algo; dev_info(&vdev->dev, "max_queues: %u, max_cipher_key_len: %u, max_auth_key_len: %u, max_size 0x%llx\n", diff --git a/drivers/crypto/virtio/virtio_crypto_mgr.c b/drivers/crypto/virtio/virtio_crypto_mgr.c index 6860f8180c7c..70e778aac0f2 100644 --- a/drivers/crypto/virtio/virtio_crypto_mgr.c +++ b/drivers/crypto/virtio/virtio_crypto_mgr.c @@ -237,8 +237,14 @@ struct virtio_crypto *virtcrypto_get_dev_node(int node, uint32_t service, */ int virtcrypto_dev_start(struct virtio_crypto *vcrypto) { - if (virtio_crypto_algs_register(vcrypto)) { - pr_err("virtio_crypto: Failed to register crypto algs\n"); + if (virtio_crypto_skcipher_algs_register(vcrypto)) { + pr_err("virtio_crypto: Failed to register crypto skcipher algs\n"); + return -EFAULT; + } + + if (virtio_crypto_akcipher_algs_register(vcrypto)) { + pr_err("virtio_crypto: Failed to register crypto akcipher algs\n"); + virtio_crypto_skcipher_algs_unregister(vcrypto); return -EFAULT; } @@ -257,7 +263,8 @@ int virtcrypto_dev_start(struct virtio_crypto *vcrypto) */ void virtcrypto_dev_stop(struct virtio_crypto *vcrypto) { - virtio_crypto_algs_unregister(vcrypto); + virtio_crypto_skcipher_algs_unregister(vcrypto); + virtio_crypto_akcipher_algs_unregister(vcrypto); } /* @@ -312,6 +319,10 @@ bool virtcrypto_algo_is_supported(struct virtio_crypto *vcrypto, case VIRTIO_CRYPTO_SERVICE_AEAD: algo_mask = vcrypto->aead_algo; break; + + case VIRTIO_CRYPTO_SERVICE_AKCIPHER: + algo_mask = vcrypto->akcipher_algo; + break; } if (!(algo_mask & (1u << algo))) diff --git a/drivers/crypto/virtio/virtio_crypto_algs.c b/drivers/crypto/virtio/virtio_crypto_skcipher_algs.c index 583c0b535d13..a618c46a52b8 100644 --- a/drivers/crypto/virtio/virtio_crypto_algs.c +++ b/drivers/crypto/virtio/virtio_crypto_skcipher_algs.c @@ -613,7 +613,7 @@ static struct virtio_crypto_algo virtio_crypto_algs[] = { { }, } }; -int virtio_crypto_algs_register(struct virtio_crypto *vcrypto) +int virtio_crypto_skcipher_algs_register(struct virtio_crypto *vcrypto) { int ret = 0; int i = 0; @@ -644,7 +644,7 @@ unlock: return ret; } -void virtio_crypto_algs_unregister(struct virtio_crypto *vcrypto) +void virtio_crypto_skcipher_algs_unregister(struct virtio_crypto *vcrypto) { int i = 0; diff --git a/drivers/net/virtio_net.c b/drivers/net/virtio_net.c index 11f26b00a226..87838cbe38cf 100644 --- a/drivers/net/virtio_net.c +++ b/drivers/net/virtio_net.c @@ -169,6 +169,24 @@ struct receive_queue { struct xdp_rxq_info xdp_rxq; }; +/* This structure can contain rss message with maximum settings for indirection table and keysize + * Note, that default structure that describes RSS configuration virtio_net_rss_config + * contains same info but can't handle table values. + * In any case, structure would be passed to virtio hw through sg_buf split by parts + * because table sizes may be differ according to the device configuration. + */ +#define VIRTIO_NET_RSS_MAX_KEY_SIZE 40 +#define VIRTIO_NET_RSS_MAX_TABLE_LEN 128 +struct virtio_net_ctrl_rss { + u32 hash_types; + u16 indirection_table_mask; + u16 unclassified_queue; + u16 indirection_table[VIRTIO_NET_RSS_MAX_TABLE_LEN]; + u16 max_tx_vq; + u8 hash_key_length; + u8 key[VIRTIO_NET_RSS_MAX_KEY_SIZE]; +}; + /* Control VQ buffers: protected by the rtnl lock */ struct control_buf { struct virtio_net_ctrl_hdr hdr; @@ -178,6 +196,7 @@ struct control_buf { u8 allmulti; __virtio16 vid; __virtio64 offloads; + struct virtio_net_ctrl_rss rss; }; struct virtnet_info { @@ -206,6 +225,14 @@ struct virtnet_info { /* Host will merge rx buffers for big packets (shake it! shake it!) */ bool mergeable_rx_bufs; + /* Host supports rss and/or hash report */ + bool has_rss; + bool has_rss_hash_report; + u8 rss_key_size; + u16 rss_indir_table_size; + u32 rss_hash_types_supported; + u32 rss_hash_types_saved; + /* Has control virtqueue */ bool has_cvq; @@ -242,13 +269,13 @@ struct virtnet_info { }; struct padded_vnet_hdr { - struct virtio_net_hdr_mrg_rxbuf hdr; + struct virtio_net_hdr_v1_hash hdr; /* * hdr is in a separate sg buffer, and data sg buffer shares same page * with this header sg. This padding makes next sg 16 byte aligned * after the header. */ - char padding[4]; + char padding[12]; }; static bool is_xdp_frame(void *ptr) @@ -396,7 +423,7 @@ static struct sk_buff *page_to_skb(struct virtnet_info *vi, hdr_len = vi->hdr_len; if (vi->mergeable_rx_bufs) - hdr_padded_len = sizeof(*hdr); + hdr_padded_len = hdr_len; else hdr_padded_len = sizeof(struct padded_vnet_hdr); @@ -1123,6 +1150,35 @@ xdp_xmit: return NULL; } +static void virtio_skb_set_hash(const struct virtio_net_hdr_v1_hash *hdr_hash, + struct sk_buff *skb) +{ + enum pkt_hash_types rss_hash_type; + + if (!hdr_hash || !skb) + return; + + switch ((int)hdr_hash->hash_report) { + case VIRTIO_NET_HASH_REPORT_TCPv4: + case VIRTIO_NET_HASH_REPORT_UDPv4: + case VIRTIO_NET_HASH_REPORT_TCPv6: + case VIRTIO_NET_HASH_REPORT_UDPv6: + case VIRTIO_NET_HASH_REPORT_TCPv6_EX: + case VIRTIO_NET_HASH_REPORT_UDPv6_EX: + rss_hash_type = PKT_HASH_TYPE_L4; + break; + case VIRTIO_NET_HASH_REPORT_IPv4: + case VIRTIO_NET_HASH_REPORT_IPv6: + case VIRTIO_NET_HASH_REPORT_IPv6_EX: + rss_hash_type = PKT_HASH_TYPE_L3; + break; + case VIRTIO_NET_HASH_REPORT_NONE: + default: + rss_hash_type = PKT_HASH_TYPE_NONE; + } + skb_set_hash(skb, (unsigned int)hdr_hash->hash_value, rss_hash_type); +} + static void receive_buf(struct virtnet_info *vi, struct receive_queue *rq, void *buf, unsigned int len, void **ctx, unsigned int *xdp_xmit, @@ -1157,6 +1213,8 @@ static void receive_buf(struct virtnet_info *vi, struct receive_queue *rq, return; hdr = skb_vnet_hdr(skb); + if (dev->features & NETIF_F_RXHASH && vi->has_rss_hash_report) + virtio_skb_set_hash((const struct virtio_net_hdr_v1_hash *)hdr, skb); if (hdr->hdr.flags & VIRTIO_NET_HDR_F_DATA_VALID) skb->ip_summed = CHECKSUM_UNNECESSARY; @@ -1266,7 +1324,8 @@ static unsigned int get_mergeable_buf_len(struct receive_queue *rq, struct ewma_pkt_len *avg_pkt_len, unsigned int room) { - const size_t hdr_len = sizeof(struct virtio_net_hdr_mrg_rxbuf); + struct virtnet_info *vi = rq->vq->vdev->priv; + const size_t hdr_len = vi->hdr_len; unsigned int len; if (room) @@ -2183,6 +2242,174 @@ static void virtnet_get_ringparam(struct net_device *dev, ring->tx_pending = ring->tx_max_pending; } +static bool virtnet_commit_rss_command(struct virtnet_info *vi) +{ + struct net_device *dev = vi->dev; + struct scatterlist sgs[4]; + unsigned int sg_buf_size; + + /* prepare sgs */ + sg_init_table(sgs, 4); + + sg_buf_size = offsetof(struct virtio_net_ctrl_rss, indirection_table); + sg_set_buf(&sgs[0], &vi->ctrl->rss, sg_buf_size); + + sg_buf_size = sizeof(uint16_t) * (vi->ctrl->rss.indirection_table_mask + 1); + sg_set_buf(&sgs[1], vi->ctrl->rss.indirection_table, sg_buf_size); + + sg_buf_size = offsetof(struct virtio_net_ctrl_rss, key) + - offsetof(struct virtio_net_ctrl_rss, max_tx_vq); + sg_set_buf(&sgs[2], &vi->ctrl->rss.max_tx_vq, sg_buf_size); + + sg_buf_size = vi->rss_key_size; + sg_set_buf(&sgs[3], vi->ctrl->rss.key, sg_buf_size); + + if (!virtnet_send_command(vi, VIRTIO_NET_CTRL_MQ, + vi->has_rss ? VIRTIO_NET_CTRL_MQ_RSS_CONFIG + : VIRTIO_NET_CTRL_MQ_HASH_CONFIG, sgs)) { + dev_warn(&dev->dev, "VIRTIONET issue with committing RSS sgs\n"); + return false; + } + return true; +} + +static void virtnet_init_default_rss(struct virtnet_info *vi) +{ + u32 indir_val = 0; + int i = 0; + + vi->ctrl->rss.hash_types = vi->rss_hash_types_supported; + vi->rss_hash_types_saved = vi->rss_hash_types_supported; + vi->ctrl->rss.indirection_table_mask = vi->rss_indir_table_size + ? vi->rss_indir_table_size - 1 : 0; + vi->ctrl->rss.unclassified_queue = 0; + + for (; i < vi->rss_indir_table_size; ++i) { + indir_val = ethtool_rxfh_indir_default(i, vi->curr_queue_pairs); + vi->ctrl->rss.indirection_table[i] = indir_val; + } + + vi->ctrl->rss.max_tx_vq = vi->curr_queue_pairs; + vi->ctrl->rss.hash_key_length = vi->rss_key_size; + + netdev_rss_key_fill(vi->ctrl->rss.key, vi->rss_key_size); +} + +static void virtnet_get_hashflow(const struct virtnet_info *vi, struct ethtool_rxnfc *info) +{ + info->data = 0; + switch (info->flow_type) { + case TCP_V4_FLOW: + if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_TCPv4) { + info->data = RXH_IP_SRC | RXH_IP_DST | + RXH_L4_B_0_1 | RXH_L4_B_2_3; + } else if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_IPv4) { + info->data = RXH_IP_SRC | RXH_IP_DST; + } + break; + case TCP_V6_FLOW: + if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_TCPv6) { + info->data = RXH_IP_SRC | RXH_IP_DST | + RXH_L4_B_0_1 | RXH_L4_B_2_3; + } else if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_IPv6) { + info->data = RXH_IP_SRC | RXH_IP_DST; + } + break; + case UDP_V4_FLOW: + if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_UDPv4) { + info->data = RXH_IP_SRC | RXH_IP_DST | + RXH_L4_B_0_1 | RXH_L4_B_2_3; + } else if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_IPv4) { + info->data = RXH_IP_SRC | RXH_IP_DST; + } + break; + case UDP_V6_FLOW: + if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_UDPv6) { + info->data = RXH_IP_SRC | RXH_IP_DST | + RXH_L4_B_0_1 | RXH_L4_B_2_3; + } else if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_IPv6) { + info->data = RXH_IP_SRC | RXH_IP_DST; + } + break; + case IPV4_FLOW: + if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_IPv4) + info->data = RXH_IP_SRC | RXH_IP_DST; + + break; + case IPV6_FLOW: + if (vi->rss_hash_types_saved & VIRTIO_NET_RSS_HASH_TYPE_IPv6) + info->data = RXH_IP_SRC | RXH_IP_DST; + + break; + default: + info->data = 0; + break; + } +} + +static bool virtnet_set_hashflow(struct virtnet_info *vi, struct ethtool_rxnfc *info) +{ + u32 new_hashtypes = vi->rss_hash_types_saved; + bool is_disable = info->data & RXH_DISCARD; + bool is_l4 = info->data == (RXH_IP_SRC | RXH_IP_DST | RXH_L4_B_0_1 | RXH_L4_B_2_3); + + /* supports only 'sd', 'sdfn' and 'r' */ + if (!((info->data == (RXH_IP_SRC | RXH_IP_DST)) | is_l4 | is_disable)) + return false; + + switch (info->flow_type) { + case TCP_V4_FLOW: + new_hashtypes &= ~(VIRTIO_NET_RSS_HASH_TYPE_IPv4 | VIRTIO_NET_RSS_HASH_TYPE_TCPv4); + if (!is_disable) + new_hashtypes |= VIRTIO_NET_RSS_HASH_TYPE_IPv4 + | (is_l4 ? VIRTIO_NET_RSS_HASH_TYPE_TCPv4 : 0); + break; + case UDP_V4_FLOW: + new_hashtypes &= ~(VIRTIO_NET_RSS_HASH_TYPE_IPv4 | VIRTIO_NET_RSS_HASH_TYPE_UDPv4); + if (!is_disable) + new_hashtypes |= VIRTIO_NET_RSS_HASH_TYPE_IPv4 + | (is_l4 ? VIRTIO_NET_RSS_HASH_TYPE_UDPv4 : 0); + break; + case IPV4_FLOW: + new_hashtypes &= ~VIRTIO_NET_RSS_HASH_TYPE_IPv4; + if (!is_disable) + new_hashtypes = VIRTIO_NET_RSS_HASH_TYPE_IPv4; + break; + case TCP_V6_FLOW: + new_hashtypes &= ~(VIRTIO_NET_RSS_HASH_TYPE_IPv6 | VIRTIO_NET_RSS_HASH_TYPE_TCPv6); + if (!is_disable) + new_hashtypes |= VIRTIO_NET_RSS_HASH_TYPE_IPv6 + | (is_l4 ? VIRTIO_NET_RSS_HASH_TYPE_TCPv6 : 0); + break; + case UDP_V6_FLOW: + new_hashtypes &= ~(VIRTIO_NET_RSS_HASH_TYPE_IPv6 | VIRTIO_NET_RSS_HASH_TYPE_UDPv6); + if (!is_disable) + new_hashtypes |= VIRTIO_NET_RSS_HASH_TYPE_IPv6 + | (is_l4 ? VIRTIO_NET_RSS_HASH_TYPE_UDPv6 : 0); + break; + case IPV6_FLOW: + new_hashtypes &= ~VIRTIO_NET_RSS_HASH_TYPE_IPv6; + if (!is_disable) + new_hashtypes = VIRTIO_NET_RSS_HASH_TYPE_IPv6; + break; + default: + /* unsupported flow */ + return false; + } + + /* if unsupported hashtype was set */ + if (new_hashtypes != (new_hashtypes & vi->rss_hash_types_supported)) + return false; + + if (new_hashtypes != vi->rss_hash_types_saved) { + vi->rss_hash_types_saved = new_hashtypes; + vi->ctrl->rss.hash_types = vi->rss_hash_types_saved; + if (vi->dev->features & NETIF_F_RXHASH) + return virtnet_commit_rss_command(vi); + } + + return true; +} static void virtnet_get_drvinfo(struct net_device *dev, struct ethtool_drvinfo *info) @@ -2411,6 +2638,92 @@ static void virtnet_update_settings(struct virtnet_info *vi) vi->duplex = duplex; } +static u32 virtnet_get_rxfh_key_size(struct net_device *dev) +{ + return ((struct virtnet_info *)netdev_priv(dev))->rss_key_size; +} + +static u32 virtnet_get_rxfh_indir_size(struct net_device *dev) +{ + return ((struct virtnet_info *)netdev_priv(dev))->rss_indir_table_size; +} + +static int virtnet_get_rxfh(struct net_device *dev, u32 *indir, u8 *key, u8 *hfunc) +{ + struct virtnet_info *vi = netdev_priv(dev); + int i; + + if (indir) { + for (i = 0; i < vi->rss_indir_table_size; ++i) + indir[i] = vi->ctrl->rss.indirection_table[i]; + } + + if (key) + memcpy(key, vi->ctrl->rss.key, vi->rss_key_size); + + if (hfunc) + *hfunc = ETH_RSS_HASH_TOP; + + return 0; +} + +static int virtnet_set_rxfh(struct net_device *dev, const u32 *indir, const u8 *key, const u8 hfunc) +{ + struct virtnet_info *vi = netdev_priv(dev); + int i; + + if (hfunc != ETH_RSS_HASH_NO_CHANGE && hfunc != ETH_RSS_HASH_TOP) + return -EOPNOTSUPP; + + if (indir) { + for (i = 0; i < vi->rss_indir_table_size; ++i) + vi->ctrl->rss.indirection_table[i] = indir[i]; + } + if (key) + memcpy(vi->ctrl->rss.key, key, vi->rss_key_size); + + virtnet_commit_rss_command(vi); + + return 0; +} + +static int virtnet_get_rxnfc(struct net_device *dev, struct ethtool_rxnfc *info, u32 *rule_locs) +{ + struct virtnet_info *vi = netdev_priv(dev); + int rc = 0; + + switch (info->cmd) { + case ETHTOOL_GRXRINGS: + info->data = vi->curr_queue_pairs; + break; + case ETHTOOL_GRXFH: + virtnet_get_hashflow(vi, info); + break; + default: + rc = -EOPNOTSUPP; + } + + return rc; +} + +static int virtnet_set_rxnfc(struct net_device *dev, struct ethtool_rxnfc *info) +{ + struct virtnet_info *vi = netdev_priv(dev); + int rc = 0; + + switch (info->cmd) { + case ETHTOOL_SRXFH: + if (!virtnet_set_hashflow(vi, info)) + rc = -EINVAL; + + break; + default: + rc = -EOPNOTSUPP; + } + + return rc; +} + static const struct ethtool_ops virtnet_ethtool_ops = { .supported_coalesce_params = ETHTOOL_COALESCE_MAX_FRAMES, .get_drvinfo = virtnet_get_drvinfo, @@ -2426,6 +2739,12 @@ static const struct ethtool_ops virtnet_ethtool_ops = { .set_link_ksettings = virtnet_set_link_ksettings, .set_coalesce = virtnet_set_coalesce, .get_coalesce = virtnet_get_coalesce, + .get_rxfh_key_size = virtnet_get_rxfh_key_size, + .get_rxfh_indir_size = virtnet_get_rxfh_indir_size, + .get_rxfh = virtnet_get_rxfh, + .set_rxfh = virtnet_set_rxfh, + .get_rxnfc = virtnet_get_rxnfc, + .set_rxnfc = virtnet_set_rxnfc, }; static void virtnet_freeze_down(struct virtio_device *vdev) @@ -2678,6 +2997,16 @@ static int virtnet_set_features(struct net_device *dev, vi->guest_offloads = offloads; } + if ((dev->features ^ features) & NETIF_F_RXHASH) { + if (features & NETIF_F_RXHASH) + vi->ctrl->rss.hash_types = vi->rss_hash_types_saved; + else + vi->ctrl->rss.hash_types = VIRTIO_NET_HASH_REPORT_NONE; + + if (!virtnet_commit_rss_command(vi)) + return -EINVAL; + } + return 0; } @@ -2851,7 +3180,7 @@ static void virtnet_del_vqs(struct virtnet_info *vi) */ static unsigned int mergeable_min_buf_len(struct virtnet_info *vi, struct virtqueue *vq) { - const unsigned int hdr_len = sizeof(struct virtio_net_hdr_mrg_rxbuf); + const unsigned int hdr_len = vi->hdr_len; unsigned int rq_size = virtqueue_get_vring_size(vq); unsigned int packet_len = vi->big_packets ? IP_MAX_MTU : vi->dev->max_mtu; unsigned int buf_len = hdr_len + ETH_HLEN + VLAN_HLEN + packet_len; @@ -3072,6 +3401,10 @@ static bool virtnet_validate_features(struct virtio_device *vdev) "VIRTIO_NET_F_CTRL_VQ") || VIRTNET_FAIL_ON(vdev, VIRTIO_NET_F_MQ, "VIRTIO_NET_F_CTRL_VQ") || VIRTNET_FAIL_ON(vdev, VIRTIO_NET_F_CTRL_MAC_ADDR, + "VIRTIO_NET_F_CTRL_VQ") || + VIRTNET_FAIL_ON(vdev, VIRTIO_NET_F_RSS, + "VIRTIO_NET_F_CTRL_VQ") || + VIRTNET_FAIL_ON(vdev, VIRTIO_NET_F_HASH_REPORT, "VIRTIO_NET_F_CTRL_VQ"))) { return false; } @@ -3112,13 +3445,14 @@ static int virtnet_probe(struct virtio_device *vdev) u16 max_queue_pairs; int mtu; - /* Find if host supports multiqueue virtio_net device */ - err = virtio_cread_feature(vdev, VIRTIO_NET_F_MQ, - struct virtio_net_config, - max_virtqueue_pairs, &max_queue_pairs); + /* Find if host supports multiqueue/rss virtio_net device */ + max_queue_pairs = 1; + if (virtio_has_feature(vdev, VIRTIO_NET_F_MQ) || virtio_has_feature(vdev, VIRTIO_NET_F_RSS)) + max_queue_pairs = + virtio_cread16(vdev, offsetof(struct virtio_net_config, max_virtqueue_pairs)); /* We need at least 2 queue's */ - if (err || max_queue_pairs < VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MIN || + if (max_queue_pairs < VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MIN || max_queue_pairs > VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MAX || !virtio_has_feature(vdev, VIRTIO_NET_F_CTRL_VQ)) max_queue_pairs = 1; @@ -3206,8 +3540,33 @@ static int virtnet_probe(struct virtio_device *vdev) if (virtio_has_feature(vdev, VIRTIO_NET_F_MRG_RXBUF)) vi->mergeable_rx_bufs = true; - if (virtio_has_feature(vdev, VIRTIO_NET_F_MRG_RXBUF) || - virtio_has_feature(vdev, VIRTIO_F_VERSION_1)) + if (virtio_has_feature(vdev, VIRTIO_NET_F_HASH_REPORT)) + vi->has_rss_hash_report = true; + + if (virtio_has_feature(vdev, VIRTIO_NET_F_RSS)) + vi->has_rss = true; + + if (vi->has_rss || vi->has_rss_hash_report) { + vi->rss_indir_table_size = + virtio_cread16(vdev, offsetof(struct virtio_net_config, + rss_max_indirection_table_length)); + vi->rss_key_size = + virtio_cread8(vdev, offsetof(struct virtio_net_config, rss_max_key_size)); + + vi->rss_hash_types_supported = + virtio_cread32(vdev, offsetof(struct virtio_net_config, supported_hash_types)); + vi->rss_hash_types_supported &= + ~(VIRTIO_NET_RSS_HASH_TYPE_IP_EX | + VIRTIO_NET_RSS_HASH_TYPE_TCP_EX | + VIRTIO_NET_RSS_HASH_TYPE_UDP_EX); + + dev->hw_features |= NETIF_F_RXHASH; + } + + if (vi->has_rss_hash_report) + vi->hdr_len = sizeof(struct virtio_net_hdr_v1_hash); + else if (virtio_has_feature(vdev, VIRTIO_NET_F_MRG_RXBUF) || + virtio_has_feature(vdev, VIRTIO_F_VERSION_1)) vi->hdr_len = sizeof(struct virtio_net_hdr_mrg_rxbuf); else vi->hdr_len = sizeof(struct virtio_net_hdr); @@ -3274,6 +3633,9 @@ static int virtnet_probe(struct virtio_device *vdev) } } + if (vi->has_rss || vi->has_rss_hash_report) + virtnet_init_default_rss(vi); + err = register_netdev(dev); if (err) { pr_debug("virtio_net: registering device failed\n"); @@ -3405,7 +3767,8 @@ static struct virtio_device_id id_table[] = { VIRTIO_NET_F_GUEST_ANNOUNCE, VIRTIO_NET_F_MQ, \ VIRTIO_NET_F_CTRL_MAC_ADDR, \ VIRTIO_NET_F_MTU, VIRTIO_NET_F_CTRL_GUEST_OFFLOADS, \ - VIRTIO_NET_F_SPEED_DUPLEX, VIRTIO_NET_F_STANDBY + VIRTIO_NET_F_SPEED_DUPLEX, VIRTIO_NET_F_STANDBY, \ + VIRTIO_NET_F_RSS, VIRTIO_NET_F_HASH_REPORT static unsigned int features[] = { VIRTNET_FEATURES, diff --git a/drivers/vdpa/ifcvf/ifcvf_base.c b/drivers/vdpa/ifcvf/ifcvf_base.c index 7d41dfe48ade..48c4dadb0c7c 100644 --- a/drivers/vdpa/ifcvf/ifcvf_base.c +++ b/drivers/vdpa/ifcvf/ifcvf_base.c @@ -10,45 +10,29 @@ #include "ifcvf_base.h" -static inline u8 ifc_ioread8(u8 __iomem *addr) -{ - return ioread8(addr); -} -static inline u16 ifc_ioread16 (__le16 __iomem *addr) +struct ifcvf_adapter *vf_to_adapter(struct ifcvf_hw *hw) { - return ioread16(addr); + return container_of(hw, struct ifcvf_adapter, vf); } -static inline u32 ifc_ioread32(__le32 __iomem *addr) +u16 ifcvf_set_vq_vector(struct ifcvf_hw *hw, u16 qid, int vector) { - return ioread32(addr); -} + struct virtio_pci_common_cfg __iomem *cfg = hw->common_cfg; -static inline void ifc_iowrite8(u8 value, u8 __iomem *addr) -{ - iowrite8(value, addr); -} + vp_iowrite16(qid, &cfg->queue_select); + vp_iowrite16(vector, &cfg->queue_msix_vector); -static inline void ifc_iowrite16(u16 value, __le16 __iomem *addr) -{ - iowrite16(value, addr); + return vp_ioread16(&cfg->queue_msix_vector); } -static inline void ifc_iowrite32(u32 value, __le32 __iomem *addr) +u16 ifcvf_set_config_vector(struct ifcvf_hw *hw, int vector) { - iowrite32(value, addr); -} + struct virtio_pci_common_cfg __iomem *cfg = hw->common_cfg; -static void ifc_iowrite64_twopart(u64 val, - __le32 __iomem *lo, __le32 __iomem *hi) -{ - ifc_iowrite32((u32)val, lo); - ifc_iowrite32(val >> 32, hi); -} + cfg = hw->common_cfg; + vp_iowrite16(vector, &cfg->msix_config); -struct ifcvf_adapter *vf_to_adapter(struct ifcvf_hw *hw) -{ - return container_of(hw, struct ifcvf_adapter, vf); + return vp_ioread16(&cfg->msix_config); } static void __iomem *get_cap_addr(struct ifcvf_hw *hw, @@ -158,15 +142,16 @@ next: return -EIO; } - hw->nr_vring = ifc_ioread16(&hw->common_cfg->num_queues); + hw->nr_vring = vp_ioread16(&hw->common_cfg->num_queues); for (i = 0; i < hw->nr_vring; i++) { - ifc_iowrite16(i, &hw->common_cfg->queue_select); - notify_off = ifc_ioread16(&hw->common_cfg->queue_notify_off); + vp_iowrite16(i, &hw->common_cfg->queue_select); + notify_off = vp_ioread16(&hw->common_cfg->queue_notify_off); hw->vring[i].notify_addr = hw->notify_base + notify_off * hw->notify_off_multiplier; hw->vring[i].notify_pa = hw->notify_base_pa + notify_off * hw->notify_off_multiplier; + hw->vring[i].irq = -EINVAL; } hw->lm_cfg = hw->base[IFCVF_LM_BAR]; @@ -176,17 +161,20 @@ next: hw->common_cfg, hw->notify_base, hw->isr, hw->dev_cfg, hw->notify_off_multiplier); + hw->vqs_reused_irq = -EINVAL; + hw->config_irq = -EINVAL; + return 0; } u8 ifcvf_get_status(struct ifcvf_hw *hw) { - return ifc_ioread8(&hw->common_cfg->device_status); + return vp_ioread8(&hw->common_cfg->device_status); } void ifcvf_set_status(struct ifcvf_hw *hw, u8 status) { - ifc_iowrite8(status, &hw->common_cfg->device_status); + vp_iowrite8(status, &hw->common_cfg->device_status); } void ifcvf_reset(struct ifcvf_hw *hw) @@ -214,11 +202,11 @@ u64 ifcvf_get_hw_features(struct ifcvf_hw *hw) u32 features_lo, features_hi; u64 features; - ifc_iowrite32(0, &cfg->device_feature_select); - features_lo = ifc_ioread32(&cfg->device_feature); + vp_iowrite32(0, &cfg->device_feature_select); + features_lo = vp_ioread32(&cfg->device_feature); - ifc_iowrite32(1, &cfg->device_feature_select); - features_hi = ifc_ioread32(&cfg->device_feature); + vp_iowrite32(1, &cfg->device_feature_select); + features_hi = vp_ioread32(&cfg->device_feature); features = ((u64)features_hi << 32) | features_lo; @@ -271,12 +259,12 @@ void ifcvf_read_dev_config(struct ifcvf_hw *hw, u64 offset, WARN_ON(offset + length > hw->config_size); do { - old_gen = ifc_ioread8(&hw->common_cfg->config_generation); + old_gen = vp_ioread8(&hw->common_cfg->config_generation); p = dst; for (i = 0; i < length; i++) - *p++ = ifc_ioread8(hw->dev_cfg + offset + i); + *p++ = vp_ioread8(hw->dev_cfg + offset + i); - new_gen = ifc_ioread8(&hw->common_cfg->config_generation); + new_gen = vp_ioread8(&hw->common_cfg->config_generation); } while (old_gen != new_gen); } @@ -289,18 +277,18 @@ void ifcvf_write_dev_config(struct ifcvf_hw *hw, u64 offset, p = src; WARN_ON(offset + length > hw->config_size); for (i = 0; i < length; i++) - ifc_iowrite8(*p++, hw->dev_cfg + offset + i); + vp_iowrite8(*p++, hw->dev_cfg + offset + i); } static void ifcvf_set_features(struct ifcvf_hw *hw, u64 features) { struct virtio_pci_common_cfg __iomem *cfg = hw->common_cfg; - ifc_iowrite32(0, &cfg->guest_feature_select); - ifc_iowrite32((u32)features, &cfg->guest_feature); + vp_iowrite32(0, &cfg->guest_feature_select); + vp_iowrite32((u32)features, &cfg->guest_feature); - ifc_iowrite32(1, &cfg->guest_feature_select); - ifc_iowrite32(features >> 32, &cfg->guest_feature); + vp_iowrite32(1, &cfg->guest_feature_select); + vp_iowrite32(features >> 32, &cfg->guest_feature); } static int ifcvf_config_features(struct ifcvf_hw *hw) @@ -329,7 +317,7 @@ u16 ifcvf_get_vq_state(struct ifcvf_hw *hw, u16 qid) ifcvf_lm = (struct ifcvf_lm_cfg __iomem *)hw->lm_cfg; q_pair_id = qid / hw->nr_vring; avail_idx_addr = &ifcvf_lm->vring_lm_cfg[q_pair_id].idx_addr[qid % 2]; - last_avail_idx = ifc_ioread16(avail_idx_addr); + last_avail_idx = vp_ioread16(avail_idx_addr); return last_avail_idx; } @@ -344,7 +332,7 @@ int ifcvf_set_vq_state(struct ifcvf_hw *hw, u16 qid, u16 num) q_pair_id = qid / hw->nr_vring; avail_idx_addr = &ifcvf_lm->vring_lm_cfg[q_pair_id].idx_addr[qid % 2]; hw->vring[qid].last_avail_idx = num; - ifc_iowrite16(num, avail_idx_addr); + vp_iowrite16(num, avail_idx_addr); return 0; } @@ -352,41 +340,23 @@ int ifcvf_set_vq_state(struct ifcvf_hw *hw, u16 qid, u16 num) static int ifcvf_hw_enable(struct ifcvf_hw *hw) { struct virtio_pci_common_cfg __iomem *cfg; - struct ifcvf_adapter *ifcvf; u32 i; - ifcvf = vf_to_adapter(hw); cfg = hw->common_cfg; - ifc_iowrite16(IFCVF_MSI_CONFIG_OFF, &cfg->msix_config); - - if (ifc_ioread16(&cfg->msix_config) == VIRTIO_MSI_NO_VECTOR) { - IFCVF_ERR(ifcvf->pdev, "No msix vector for device config\n"); - return -EINVAL; - } - for (i = 0; i < hw->nr_vring; i++) { if (!hw->vring[i].ready) break; - ifc_iowrite16(i, &cfg->queue_select); - ifc_iowrite64_twopart(hw->vring[i].desc, &cfg->queue_desc_lo, + vp_iowrite16(i, &cfg->queue_select); + vp_iowrite64_twopart(hw->vring[i].desc, &cfg->queue_desc_lo, &cfg->queue_desc_hi); - ifc_iowrite64_twopart(hw->vring[i].avail, &cfg->queue_avail_lo, + vp_iowrite64_twopart(hw->vring[i].avail, &cfg->queue_avail_lo, &cfg->queue_avail_hi); - ifc_iowrite64_twopart(hw->vring[i].used, &cfg->queue_used_lo, + vp_iowrite64_twopart(hw->vring[i].used, &cfg->queue_used_lo, &cfg->queue_used_hi); - ifc_iowrite16(hw->vring[i].size, &cfg->queue_size); - ifc_iowrite16(i + IFCVF_MSI_QUEUE_OFF, &cfg->queue_msix_vector); - - if (ifc_ioread16(&cfg->queue_msix_vector) == - VIRTIO_MSI_NO_VECTOR) { - IFCVF_ERR(ifcvf->pdev, - "No msix vector for queue %u\n", i); - return -EINVAL; - } - + vp_iowrite16(hw->vring[i].size, &cfg->queue_size); ifcvf_set_vq_state(hw, i, hw->vring[i].last_avail_idx); - ifc_iowrite16(1, &cfg->queue_enable); + vp_iowrite16(1, &cfg->queue_enable); } return 0; @@ -394,18 +364,12 @@ static int ifcvf_hw_enable(struct ifcvf_hw *hw) static void ifcvf_hw_disable(struct ifcvf_hw *hw) { - struct virtio_pci_common_cfg __iomem *cfg; u32 i; - cfg = hw->common_cfg; - ifc_iowrite16(VIRTIO_MSI_NO_VECTOR, &cfg->msix_config); - + ifcvf_set_config_vector(hw, VIRTIO_MSI_NO_VECTOR); for (i = 0; i < hw->nr_vring; i++) { - ifc_iowrite16(i, &cfg->queue_select); - ifc_iowrite16(VIRTIO_MSI_NO_VECTOR, &cfg->queue_msix_vector); + ifcvf_set_vq_vector(hw, i, VIRTIO_MSI_NO_VECTOR); } - - ifc_ioread16(&cfg->queue_msix_vector); } int ifcvf_start_hw(struct ifcvf_hw *hw) @@ -433,5 +397,5 @@ void ifcvf_stop_hw(struct ifcvf_hw *hw) void ifcvf_notify_queue(struct ifcvf_hw *hw, u16 qid) { - ifc_iowrite16(qid, hw->vring[qid].notify_addr); + vp_iowrite16(qid, hw->vring[qid].notify_addr); } diff --git a/drivers/vdpa/ifcvf/ifcvf_base.h b/drivers/vdpa/ifcvf/ifcvf_base.h index c486873f370a..115b61f4924b 100644 --- a/drivers/vdpa/ifcvf/ifcvf_base.h +++ b/drivers/vdpa/ifcvf/ifcvf_base.h @@ -14,6 +14,7 @@ #include <linux/pci.h> #include <linux/pci_regs.h> #include <linux/vdpa.h> +#include <linux/virtio_pci_modern.h> #include <uapi/linux/virtio_net.h> #include <uapi/linux/virtio_blk.h> #include <uapi/linux/virtio_config.h> @@ -27,8 +28,6 @@ #define IFCVF_QUEUE_ALIGNMENT PAGE_SIZE #define IFCVF_QUEUE_MAX 32768 -#define IFCVF_MSI_CONFIG_OFF 0 -#define IFCVF_MSI_QUEUE_OFF 1 #define IFCVF_PCI_MAX_RESOURCE 6 #define IFCVF_LM_CFG_SIZE 0x40 @@ -42,6 +41,13 @@ #define ifcvf_private_to_vf(adapter) \ (&((struct ifcvf_adapter *)adapter)->vf) +/* all vqs and config interrupt has its own vector */ +#define MSIX_VECTOR_PER_VQ_AND_CONFIG 1 +/* all vqs share a vector, and config interrupt has a separate vector */ +#define MSIX_VECTOR_SHARED_VQ_AND_CONFIG 2 +/* all vqs and config interrupt share a vector */ +#define MSIX_VECTOR_DEV_SHARED 3 + struct vring_info { u64 desc; u64 avail; @@ -60,25 +66,27 @@ struct ifcvf_hw { u8 __iomem *isr; /* Live migration */ u8 __iomem *lm_cfg; - u16 nr_vring; /* Notification bar number */ u8 notify_bar; + u8 msix_vector_status; + /* virtio-net or virtio-blk device config size */ + u32 config_size; /* Notificaiton bar address */ void __iomem *notify_base; phys_addr_t notify_base_pa; u32 notify_off_multiplier; + u32 dev_type; u64 req_features; u64 hw_features; - u32 dev_type; struct virtio_pci_common_cfg __iomem *common_cfg; void __iomem *dev_cfg; struct vring_info vring[IFCVF_MAX_QUEUES]; void __iomem * const *base; char config_msix_name[256]; struct vdpa_callback config_cb; - unsigned int config_irq; - /* virtio-net or virtio-blk device config size */ - u32 config_size; + int config_irq; + int vqs_reused_irq; + u16 nr_vring; }; struct ifcvf_adapter { @@ -123,4 +131,6 @@ int ifcvf_set_vq_state(struct ifcvf_hw *hw, u16 qid, u16 num); struct ifcvf_adapter *vf_to_adapter(struct ifcvf_hw *hw); int ifcvf_probed_virtio_net(struct ifcvf_hw *hw); u32 ifcvf_get_config_size(struct ifcvf_hw *hw); +u16 ifcvf_set_vq_vector(struct ifcvf_hw *hw, u16 qid, int vector); +u16 ifcvf_set_config_vector(struct ifcvf_hw *hw, int vector); #endif /* _IFCVF_H_ */ diff --git a/drivers/vdpa/ifcvf/ifcvf_main.c b/drivers/vdpa/ifcvf/ifcvf_main.c index d1a6b5ab543c..4366320fb68d 100644 --- a/drivers/vdpa/ifcvf/ifcvf_main.c +++ b/drivers/vdpa/ifcvf/ifcvf_main.c @@ -27,7 +27,7 @@ static irqreturn_t ifcvf_config_changed(int irq, void *arg) return IRQ_HANDLED; } -static irqreturn_t ifcvf_intr_handler(int irq, void *arg) +static irqreturn_t ifcvf_vq_intr_handler(int irq, void *arg) { struct vring_info *vring = arg; @@ -37,76 +37,324 @@ static irqreturn_t ifcvf_intr_handler(int irq, void *arg) return IRQ_HANDLED; } +static irqreturn_t ifcvf_vqs_reused_intr_handler(int irq, void *arg) +{ + struct ifcvf_hw *vf = arg; + struct vring_info *vring; + int i; + + for (i = 0; i < vf->nr_vring; i++) { + vring = &vf->vring[i]; + if (vring->cb.callback) + vring->cb.callback(vring->cb.private); + } + + return IRQ_HANDLED; +} + +static irqreturn_t ifcvf_dev_intr_handler(int irq, void *arg) +{ + struct ifcvf_hw *vf = arg; + u8 isr; + + isr = vp_ioread8(vf->isr); + if (isr & VIRTIO_PCI_ISR_CONFIG) + ifcvf_config_changed(irq, arg); + + return ifcvf_vqs_reused_intr_handler(irq, arg); +} + static void ifcvf_free_irq_vectors(void *data) { pci_free_irq_vectors(data); } -static void ifcvf_free_irq(struct ifcvf_adapter *adapter, int queues) +static void ifcvf_free_per_vq_irq(struct ifcvf_adapter *adapter) { struct pci_dev *pdev = adapter->pdev; struct ifcvf_hw *vf = &adapter->vf; int i; + for (i = 0; i < vf->nr_vring; i++) { + if (vf->vring[i].irq != -EINVAL) { + devm_free_irq(&pdev->dev, vf->vring[i].irq, &vf->vring[i]); + vf->vring[i].irq = -EINVAL; + } + } +} - for (i = 0; i < queues; i++) { - devm_free_irq(&pdev->dev, vf->vring[i].irq, &vf->vring[i]); - vf->vring[i].irq = -EINVAL; +static void ifcvf_free_vqs_reused_irq(struct ifcvf_adapter *adapter) +{ + struct pci_dev *pdev = adapter->pdev; + struct ifcvf_hw *vf = &adapter->vf; + + if (vf->vqs_reused_irq != -EINVAL) { + devm_free_irq(&pdev->dev, vf->vqs_reused_irq, vf); + vf->vqs_reused_irq = -EINVAL; } - devm_free_irq(&pdev->dev, vf->config_irq, vf); +} + +static void ifcvf_free_vq_irq(struct ifcvf_adapter *adapter) +{ + struct ifcvf_hw *vf = &adapter->vf; + + if (vf->msix_vector_status == MSIX_VECTOR_PER_VQ_AND_CONFIG) + ifcvf_free_per_vq_irq(adapter); + else + ifcvf_free_vqs_reused_irq(adapter); +} + +static void ifcvf_free_config_irq(struct ifcvf_adapter *adapter) +{ + struct pci_dev *pdev = adapter->pdev; + struct ifcvf_hw *vf = &adapter->vf; + + if (vf->config_irq == -EINVAL) + return; + + /* If the irq is shared by all vqs and the config interrupt, + * it is already freed in ifcvf_free_vq_irq, so here only + * need to free config irq when msix_vector_status != MSIX_VECTOR_DEV_SHARED + */ + if (vf->msix_vector_status != MSIX_VECTOR_DEV_SHARED) { + devm_free_irq(&pdev->dev, vf->config_irq, vf); + vf->config_irq = -EINVAL; + } +} + +static void ifcvf_free_irq(struct ifcvf_adapter *adapter) +{ + struct pci_dev *pdev = adapter->pdev; + + ifcvf_free_vq_irq(adapter); + ifcvf_free_config_irq(adapter); ifcvf_free_irq_vectors(pdev); } -static int ifcvf_request_irq(struct ifcvf_adapter *adapter) +/* ifcvf MSIX vectors allocator, this helper tries to allocate + * vectors for all virtqueues and the config interrupt. + * It returns the number of allocated vectors, negative + * return value when fails. + */ +static int ifcvf_alloc_vectors(struct ifcvf_adapter *adapter) { struct pci_dev *pdev = adapter->pdev; struct ifcvf_hw *vf = &adapter->vf; - int vector, i, ret, irq; - u16 max_intr; + int max_intr, ret; /* all queues and config interrupt */ max_intr = vf->nr_vring + 1; + ret = pci_alloc_irq_vectors(pdev, 1, max_intr, PCI_IRQ_MSIX | PCI_IRQ_AFFINITY); - ret = pci_alloc_irq_vectors(pdev, max_intr, - max_intr, PCI_IRQ_MSIX); if (ret < 0) { IFCVF_ERR(pdev, "Failed to alloc IRQ vectors\n"); return ret; } + if (ret < max_intr) + IFCVF_INFO(pdev, + "Requested %u vectors, however only %u allocated, lower performance\n", + max_intr, ret); + + return ret; +} + +static int ifcvf_request_per_vq_irq(struct ifcvf_adapter *adapter) +{ + struct pci_dev *pdev = adapter->pdev; + struct ifcvf_hw *vf = &adapter->vf; + int i, vector, ret, irq; + + vf->vqs_reused_irq = -EINVAL; + for (i = 0; i < vf->nr_vring; i++) { + snprintf(vf->vring[i].msix_name, 256, "ifcvf[%s]-%d\n", pci_name(pdev), i); + vector = i; + irq = pci_irq_vector(pdev, vector); + ret = devm_request_irq(&pdev->dev, irq, + ifcvf_vq_intr_handler, 0, + vf->vring[i].msix_name, + &vf->vring[i]); + if (ret) { + IFCVF_ERR(pdev, "Failed to request irq for vq %d\n", i); + goto err; + } + + vf->vring[i].irq = irq; + ret = ifcvf_set_vq_vector(vf, i, vector); + if (ret == VIRTIO_MSI_NO_VECTOR) { + IFCVF_ERR(pdev, "No msix vector for vq %u\n", i); + goto err; + } + } + + return 0; +err: + ifcvf_free_irq(adapter); + + return -EFAULT; +} + +static int ifcvf_request_vqs_reused_irq(struct ifcvf_adapter *adapter) +{ + struct pci_dev *pdev = adapter->pdev; + struct ifcvf_hw *vf = &adapter->vf; + int i, vector, ret, irq; + + vector = 0; + snprintf(vf->vring[0].msix_name, 256, "ifcvf[%s]-vqs-reused-irq\n", pci_name(pdev)); + irq = pci_irq_vector(pdev, vector); + ret = devm_request_irq(&pdev->dev, irq, + ifcvf_vqs_reused_intr_handler, 0, + vf->vring[0].msix_name, vf); + if (ret) { + IFCVF_ERR(pdev, "Failed to request reused irq for the device\n"); + goto err; + } + + vf->vqs_reused_irq = irq; + for (i = 0; i < vf->nr_vring; i++) { + vf->vring[i].irq = -EINVAL; + ret = ifcvf_set_vq_vector(vf, i, vector); + if (ret == VIRTIO_MSI_NO_VECTOR) { + IFCVF_ERR(pdev, "No msix vector for vq %u\n", i); + goto err; + } + } + + return 0; +err: + ifcvf_free_irq(adapter); + + return -EFAULT; +} + +static int ifcvf_request_dev_irq(struct ifcvf_adapter *adapter) +{ + struct pci_dev *pdev = adapter->pdev; + struct ifcvf_hw *vf = &adapter->vf; + int i, vector, ret, irq; + + vector = 0; + snprintf(vf->vring[0].msix_name, 256, "ifcvf[%s]-dev-irq\n", pci_name(pdev)); + irq = pci_irq_vector(pdev, vector); + ret = devm_request_irq(&pdev->dev, irq, + ifcvf_dev_intr_handler, 0, + vf->vring[0].msix_name, vf); + if (ret) { + IFCVF_ERR(pdev, "Failed to request irq for the device\n"); + goto err; + } + + vf->vqs_reused_irq = irq; + for (i = 0; i < vf->nr_vring; i++) { + vf->vring[i].irq = -EINVAL; + ret = ifcvf_set_vq_vector(vf, i, vector); + if (ret == VIRTIO_MSI_NO_VECTOR) { + IFCVF_ERR(pdev, "No msix vector for vq %u\n", i); + goto err; + } + } + + vf->config_irq = irq; + ret = ifcvf_set_config_vector(vf, vector); + if (ret == VIRTIO_MSI_NO_VECTOR) { + IFCVF_ERR(pdev, "No msix vector for device config\n"); + goto err; + } + + return 0; +err: + ifcvf_free_irq(adapter); + + return -EFAULT; + +} + +static int ifcvf_request_vq_irq(struct ifcvf_adapter *adapter) +{ + struct ifcvf_hw *vf = &adapter->vf; + int ret; + + if (vf->msix_vector_status == MSIX_VECTOR_PER_VQ_AND_CONFIG) + ret = ifcvf_request_per_vq_irq(adapter); + else + ret = ifcvf_request_vqs_reused_irq(adapter); + + return ret; +} + +static int ifcvf_request_config_irq(struct ifcvf_adapter *adapter) +{ + struct pci_dev *pdev = adapter->pdev; + struct ifcvf_hw *vf = &adapter->vf; + int config_vector, ret; + + if (vf->msix_vector_status == MSIX_VECTOR_DEV_SHARED) + return 0; + + if (vf->msix_vector_status == MSIX_VECTOR_PER_VQ_AND_CONFIG) + /* vector 0 ~ vf->nr_vring for vqs, num vf->nr_vring vector for config interrupt */ + config_vector = vf->nr_vring; + + if (vf->msix_vector_status == MSIX_VECTOR_SHARED_VQ_AND_CONFIG) + /* vector 0 for vqs and 1 for config interrupt */ + config_vector = 1; + snprintf(vf->config_msix_name, 256, "ifcvf[%s]-config\n", pci_name(pdev)); - vector = 0; - vf->config_irq = pci_irq_vector(pdev, vector); + vf->config_irq = pci_irq_vector(pdev, config_vector); ret = devm_request_irq(&pdev->dev, vf->config_irq, ifcvf_config_changed, 0, vf->config_msix_name, vf); if (ret) { IFCVF_ERR(pdev, "Failed to request config irq\n"); - return ret; + goto err; } - for (i = 0; i < vf->nr_vring; i++) { - snprintf(vf->vring[i].msix_name, 256, "ifcvf[%s]-%d\n", - pci_name(pdev), i); - vector = i + IFCVF_MSI_QUEUE_OFF; - irq = pci_irq_vector(pdev, vector); - ret = devm_request_irq(&pdev->dev, irq, - ifcvf_intr_handler, 0, - vf->vring[i].msix_name, - &vf->vring[i]); - if (ret) { - IFCVF_ERR(pdev, - "Failed to request irq for vq %d\n", i); - ifcvf_free_irq(adapter, i); + ret = ifcvf_set_config_vector(vf, config_vector); + if (ret == VIRTIO_MSI_NO_VECTOR) { + IFCVF_ERR(pdev, "No msix vector for device config\n"); + goto err; + } - return ret; - } + return 0; +err: + ifcvf_free_irq(adapter); - vf->vring[i].irq = irq; + return -EFAULT; +} + +static int ifcvf_request_irq(struct ifcvf_adapter *adapter) +{ + struct ifcvf_hw *vf = &adapter->vf; + int nvectors, ret, max_intr; + + nvectors = ifcvf_alloc_vectors(adapter); + if (nvectors <= 0) + return -EFAULT; + + vf->msix_vector_status = MSIX_VECTOR_PER_VQ_AND_CONFIG; + max_intr = vf->nr_vring + 1; + if (nvectors < max_intr) + vf->msix_vector_status = MSIX_VECTOR_SHARED_VQ_AND_CONFIG; + + if (nvectors == 1) { + vf->msix_vector_status = MSIX_VECTOR_DEV_SHARED; + ret = ifcvf_request_dev_irq(adapter); + + return ret; } + ret = ifcvf_request_vq_irq(adapter); + if (ret) + return ret; + + ret = ifcvf_request_config_irq(adapter); + + if (ret) + return ret; + return 0; } @@ -263,7 +511,7 @@ static int ifcvf_vdpa_reset(struct vdpa_device *vdpa_dev) if (status_old & VIRTIO_CONFIG_S_DRIVER_OK) { ifcvf_stop_datapath(adapter); - ifcvf_free_irq(adapter, vf->nr_vring); + ifcvf_free_irq(adapter); } ifcvf_reset_vring(adapter); @@ -348,7 +596,7 @@ static u32 ifcvf_vdpa_get_generation(struct vdpa_device *vdpa_dev) { struct ifcvf_hw *vf = vdpa_to_vf(vdpa_dev); - return ioread8(&vf->common_cfg->config_generation); + return vp_ioread8(&vf->common_cfg->config_generation); } static u32 ifcvf_vdpa_get_device_id(struct vdpa_device *vdpa_dev) @@ -410,7 +658,10 @@ static int ifcvf_vdpa_get_vq_irq(struct vdpa_device *vdpa_dev, { struct ifcvf_hw *vf = vdpa_to_vf(vdpa_dev); - return vf->vring[qid].irq; + if (vf->vqs_reused_irq < 0) + return vf->vring[qid].irq; + else + return -EINVAL; } static struct vdpa_notification_area ifcvf_get_vq_notification(struct vdpa_device *vdpa_dev, diff --git a/drivers/vdpa/mlx5/net/mlx5_vnet.c b/drivers/vdpa/mlx5/net/mlx5_vnet.c index d0f91078600e..2f4fb09f1e89 100644 --- a/drivers/vdpa/mlx5/net/mlx5_vnet.c +++ b/drivers/vdpa/mlx5/net/mlx5_vnet.c @@ -1475,7 +1475,7 @@ static virtio_net_ctrl_ack handle_ctrl_mac(struct mlx5_vdpa_dev *mvdev, u8 cmd) virtio_net_ctrl_ack status = VIRTIO_NET_ERR; struct mlx5_core_dev *pfmdev; size_t read; - u8 mac[ETH_ALEN]; + u8 mac[ETH_ALEN], mac_back[ETH_ALEN]; pfmdev = pci_get_drvdata(pci_physfn(mvdev->mdev->pdev)); switch (cmd) { @@ -1489,6 +1489,9 @@ static virtio_net_ctrl_ack handle_ctrl_mac(struct mlx5_vdpa_dev *mvdev, u8 cmd) break; } + if (is_zero_ether_addr(mac)) + break; + if (!is_zero_ether_addr(ndev->config.mac)) { if (mlx5_mpfs_del_mac(pfmdev, ndev->config.mac)) { mlx5_vdpa_warn(mvdev, "failed to delete old MAC %pM from MPFS table\n", @@ -1503,7 +1506,47 @@ static virtio_net_ctrl_ack handle_ctrl_mac(struct mlx5_vdpa_dev *mvdev, u8 cmd) break; } + /* backup the original mac address so that if failed to add the forward rules + * we could restore it + */ + memcpy(mac_back, ndev->config.mac, ETH_ALEN); + memcpy(ndev->config.mac, mac, ETH_ALEN); + + /* Need recreate the flow table entry, so that the packet could forward back + */ + remove_fwd_to_tir(ndev); + + if (add_fwd_to_tir(ndev)) { + mlx5_vdpa_warn(mvdev, "failed to insert forward rules, try to restore\n"); + + /* Although it hardly run here, we still need double check */ + if (is_zero_ether_addr(mac_back)) { + mlx5_vdpa_warn(mvdev, "restore mac failed: Original MAC is zero\n"); + break; + } + + /* Try to restore original mac address to MFPS table, and try to restore + * the forward rule entry. + */ + if (mlx5_mpfs_del_mac(pfmdev, ndev->config.mac)) { + mlx5_vdpa_warn(mvdev, "restore mac failed: delete MAC %pM from MPFS table failed\n", + ndev->config.mac); + } + + if (mlx5_mpfs_add_mac(pfmdev, mac_back)) { + mlx5_vdpa_warn(mvdev, "restore mac failed: insert old MAC %pM into MPFS table failed\n", + mac_back); + } + + memcpy(ndev->config.mac, mac_back, ETH_ALEN); + + if (add_fwd_to_tir(ndev)) + mlx5_vdpa_warn(mvdev, "restore forward rules failed: insert forward rules failed\n"); + + break; + } + status = VIRTIO_NET_OK; break; @@ -1669,7 +1712,7 @@ static void mlx5_vdpa_kick_vq(struct vdpa_device *vdev, u16 idx) return; if (unlikely(is_ctrl_vq_idx(mvdev, idx))) { - if (!mvdev->cvq.ready) + if (!mvdev->wq || !mvdev->cvq.ready) return; wqent = kzalloc(sizeof(*wqent), GFP_ATOMIC); @@ -2565,6 +2608,28 @@ static int event_handler(struct notifier_block *nb, unsigned long event, void *p return ret; } +static int config_func_mtu(struct mlx5_core_dev *mdev, u16 mtu) +{ + int inlen = MLX5_ST_SZ_BYTES(modify_nic_vport_context_in); + void *in; + int err; + + in = kvzalloc(inlen, GFP_KERNEL); + if (!in) + return -ENOMEM; + + MLX5_SET(modify_nic_vport_context_in, in, field_select.mtu, 1); + MLX5_SET(modify_nic_vport_context_in, in, nic_vport_context.mtu, + mtu + MLX5V_ETH_HARD_MTU); + MLX5_SET(modify_nic_vport_context_in, in, opcode, + MLX5_CMD_OP_MODIFY_NIC_VPORT_CONTEXT); + + err = mlx5_cmd_exec_in(mdev, modify_nic_vport_context, in); + + kvfree(in); + return err; +} + static int mlx5_vdpa_dev_add(struct vdpa_mgmt_dev *v_mdev, const char *name, const struct vdpa_dev_set_config *add_config) { @@ -2624,6 +2689,13 @@ static int mlx5_vdpa_dev_add(struct vdpa_mgmt_dev *v_mdev, const char *name, init_mvqs(ndev); mutex_init(&ndev->reslock); config = &ndev->config; + + if (add_config->mask & BIT_ULL(VDPA_ATTR_DEV_NET_CFG_MTU)) { + err = config_func_mtu(mdev, add_config->net.mtu); + if (err) + goto err_mtu; + } + err = query_mtu(mdev, &mtu); if (err) goto err_mtu; @@ -2707,9 +2779,12 @@ static void mlx5_vdpa_dev_del(struct vdpa_mgmt_dev *v_mdev, struct vdpa_device * struct mlx5_vdpa_mgmtdev *mgtdev = container_of(v_mdev, struct mlx5_vdpa_mgmtdev, mgtdev); struct mlx5_vdpa_dev *mvdev = to_mvdev(dev); struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev); + struct workqueue_struct *wq; mlx5_notifier_unregister(mvdev->mdev, &ndev->nb); - destroy_workqueue(mvdev->wq); + wq = mvdev->wq; + mvdev->wq = NULL; + destroy_workqueue(wq); _vdpa_unregister_device(dev); mgtdev->ndev = NULL; } @@ -2741,7 +2816,8 @@ static int mlx5v_probe(struct auxiliary_device *adev, mgtdev->mgtdev.device = mdev->device; mgtdev->mgtdev.id_table = id_table; mgtdev->mgtdev.config_attr_mask = BIT_ULL(VDPA_ATTR_DEV_NET_CFG_MACADDR) | - BIT_ULL(VDPA_ATTR_DEV_NET_CFG_MAX_VQP); + BIT_ULL(VDPA_ATTR_DEV_NET_CFG_MAX_VQP) | + BIT_ULL(VDPA_ATTR_DEV_NET_CFG_MTU); mgtdev->mgtdev.max_supported_vqs = MLX5_CAP_DEV_VDPA_EMULATION(mdev, max_num_virtio_queues) + 1; mgtdev->mgtdev.supported_features = get_supported_features(mdev); diff --git a/drivers/vdpa/vdpa.c b/drivers/vdpa/vdpa.c index 1ea525433a5c..2b75c00b1005 100644 --- a/drivers/vdpa/vdpa.c +++ b/drivers/vdpa/vdpa.c @@ -232,7 +232,7 @@ static int vdpa_name_match(struct device *dev, const void *data) return (strcmp(dev_name(&vdev->dev), data) == 0); } -static int __vdpa_register_device(struct vdpa_device *vdev, int nvqs) +static int __vdpa_register_device(struct vdpa_device *vdev, u32 nvqs) { struct device *dev; @@ -257,7 +257,7 @@ static int __vdpa_register_device(struct vdpa_device *vdev, int nvqs) * * Return: Returns an error when fail to add device to vDPA bus */ -int _vdpa_register_device(struct vdpa_device *vdev, int nvqs) +int _vdpa_register_device(struct vdpa_device *vdev, u32 nvqs) { if (!vdev->mdev) return -EINVAL; @@ -274,7 +274,7 @@ EXPORT_SYMBOL_GPL(_vdpa_register_device); * * Return: Returns an error when fail to add to vDPA bus */ -int vdpa_register_device(struct vdpa_device *vdev, int nvqs) +int vdpa_register_device(struct vdpa_device *vdev, u32 nvqs) { int err; diff --git a/drivers/vhost/iotlb.c b/drivers/vhost/iotlb.c index 40b098320b2a..5829cf2d0552 100644 --- a/drivers/vhost/iotlb.c +++ b/drivers/vhost/iotlb.c @@ -62,8 +62,12 @@ int vhost_iotlb_add_range_ctx(struct vhost_iotlb *iotlb, */ if (start == 0 && last == ULONG_MAX) { u64 mid = last / 2; + int err = vhost_iotlb_add_range_ctx(iotlb, start, mid, addr, + perm, opaque); + + if (err) + return err; - vhost_iotlb_add_range_ctx(iotlb, start, mid, addr, perm, opaque); addr += mid + 1; start = mid + 1; } diff --git a/drivers/vhost/vdpa.c b/drivers/vhost/vdpa.c index ec5249e8c32d..4c2f0bd06285 100644 --- a/drivers/vhost/vdpa.c +++ b/drivers/vhost/vdpa.c @@ -42,7 +42,7 @@ struct vhost_vdpa { struct device dev; struct cdev cdev; atomic_t opened; - int nvqs; + u32 nvqs; int virtio_id; int minor; struct eventfd_ctx *config_ctx; @@ -97,8 +97,11 @@ static void vhost_vdpa_setup_vq_irq(struct vhost_vdpa *v, u16 qid) return; irq = ops->get_vq_irq(vdpa, qid); + if (irq < 0) + return; + irq_bypass_unregister_producer(&vq->call_ctx.producer); - if (!vq->call_ctx.ctx || irq < 0) + if (!vq->call_ctx.ctx) return; vq->call_ctx.producer.token = vq->call_ctx.ctx; @@ -158,7 +161,8 @@ static long vhost_vdpa_set_status(struct vhost_vdpa *v, u8 __user *statusp) struct vdpa_device *vdpa = v->vdpa; const struct vdpa_config_ops *ops = vdpa->config; u8 status, status_old; - int ret, nvqs = v->nvqs; + u32 nvqs = v->nvqs; + int ret; u16 i; if (copy_from_user(&status, statusp, sizeof(status))) @@ -355,6 +359,30 @@ static long vhost_vdpa_get_iova_range(struct vhost_vdpa *v, u32 __user *argp) return 0; } +static long vhost_vdpa_get_config_size(struct vhost_vdpa *v, u32 __user *argp) +{ + struct vdpa_device *vdpa = v->vdpa; + const struct vdpa_config_ops *ops = vdpa->config; + u32 size; + + size = ops->get_config_size(vdpa); + + if (copy_to_user(argp, &size, sizeof(size))) + return -EFAULT; + + return 0; +} + +static long vhost_vdpa_get_vqs_count(struct vhost_vdpa *v, u32 __user *argp) +{ + struct vdpa_device *vdpa = v->vdpa; + + if (copy_to_user(argp, &vdpa->nvqs, sizeof(vdpa->nvqs))) + return -EFAULT; + + return 0; +} + static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd, void __user *argp) { @@ -492,6 +520,12 @@ static long vhost_vdpa_unlocked_ioctl(struct file *filep, case VHOST_VDPA_GET_IOVA_RANGE: r = vhost_vdpa_get_iova_range(v, argp); break; + case VHOST_VDPA_GET_CONFIG_SIZE: + r = vhost_vdpa_get_config_size(v, argp); + break; + case VHOST_VDPA_GET_VQS_COUNT: + r = vhost_vdpa_get_vqs_count(v, argp); + break; default: r = vhost_dev_ioctl(&v->vdev, cmd, argp); if (r == -ENOIOCTLCMD) @@ -948,7 +982,8 @@ static int vhost_vdpa_open(struct inode *inode, struct file *filep) struct vhost_vdpa *v; struct vhost_dev *dev; struct vhost_virtqueue **vqs; - int nvqs, i, r, opened; + int r, opened; + u32 i, nvqs; v = container_of(inode->i_cdev, struct vhost_vdpa, cdev); @@ -1001,7 +1036,7 @@ err: static void vhost_vdpa_clean_irq(struct vhost_vdpa *v) { - int i; + u32 i; for (i = 0; i < v->nvqs; i++) vhost_vdpa_unsetup_vq_irq(v, i); diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 1768362115c6..d02173fb290c 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -2550,8 +2550,9 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) &vq->avail->idx, r); return false; } + vq->avail_idx = vhost16_to_cpu(vq, avail_idx); - return vhost16_to_cpu(vq, avail_idx) != vq->avail_idx; + return vq->avail_idx != vq->last_avail_idx; } EXPORT_SYMBOL_GPL(vhost_enable_notify); diff --git a/drivers/virtio/Kconfig b/drivers/virtio/Kconfig index 492fc26f0b65..b5adf6abd241 100644 --- a/drivers/virtio/Kconfig +++ b/drivers/virtio/Kconfig @@ -105,7 +105,7 @@ config VIRTIO_BALLOON config VIRTIO_MEM tristate "Virtio mem driver" - depends on X86_64 + depends on X86_64 || ARM64 depends on VIRTIO depends on MEMORY_HOTPLUG depends on MEMORY_HOTREMOVE @@ -115,8 +115,9 @@ config VIRTIO_MEM This driver provides access to virtio-mem paravirtualized memory devices, allowing to hotplug and hotunplug memory. - This driver was only tested under x86-64, but should theoretically - work on all architectures that support memory hotplug and hotremove. + This driver was only tested under x86-64 and arm64, but should + theoretically work on all architectures that support memory hotplug + and hotremove. If unsure, say M. diff --git a/drivers/virtio/virtio.c b/drivers/virtio/virtio.c index 22f15f444f75..75c8d560bbd3 100644 --- a/drivers/virtio/virtio.c +++ b/drivers/virtio/virtio.c @@ -526,8 +526,9 @@ int virtio_device_restore(struct virtio_device *dev) goto err; } - /* Finally, tell the device we're all set */ - virtio_add_status(dev, VIRTIO_CONFIG_S_DRIVER_OK); + /* If restore didn't do it, mark device DRIVER_OK ourselves. */ + if (!(dev->config->get_status(dev) & VIRTIO_CONFIG_S_DRIVER_OK)) + virtio_device_ready(dev); virtio_config_enable(dev); diff --git a/drivers/virtio/virtio_pci_common.c b/drivers/virtio/virtio_pci_common.c index fdbde1db5ec5..d724f676608b 100644 --- a/drivers/virtio/virtio_pci_common.c +++ b/drivers/virtio/virtio_pci_common.c @@ -24,46 +24,17 @@ MODULE_PARM_DESC(force_legacy, "Force legacy mode for transitional virtio 1 devices"); #endif -/* disable irq handlers */ -void vp_disable_cbs(struct virtio_device *vdev) +/* wait for pending irq handlers */ +void vp_synchronize_vectors(struct virtio_device *vdev) { struct virtio_pci_device *vp_dev = to_vp_device(vdev); int i; - if (vp_dev->intx_enabled) { - /* - * The below synchronize() guarantees that any - * interrupt for this line arriving after - * synchronize_irq() has completed is guaranteed to see - * intx_soft_enabled == false. - */ - WRITE_ONCE(vp_dev->intx_soft_enabled, false); + if (vp_dev->intx_enabled) synchronize_irq(vp_dev->pci_dev->irq); - } - - for (i = 0; i < vp_dev->msix_vectors; ++i) - disable_irq(pci_irq_vector(vp_dev->pci_dev, i)); -} - -/* enable irq handlers */ -void vp_enable_cbs(struct virtio_device *vdev) -{ - struct virtio_pci_device *vp_dev = to_vp_device(vdev); - int i; - - if (vp_dev->intx_enabled) { - disable_irq(vp_dev->pci_dev->irq); - /* - * The above disable_irq() provides TSO ordering and - * as such promotes the below store to store-release. - */ - WRITE_ONCE(vp_dev->intx_soft_enabled, true); - enable_irq(vp_dev->pci_dev->irq); - return; - } for (i = 0; i < vp_dev->msix_vectors; ++i) - enable_irq(pci_irq_vector(vp_dev->pci_dev, i)); + synchronize_irq(pci_irq_vector(vp_dev->pci_dev, i)); } /* the notify function used when creating a virt queue */ @@ -113,9 +84,6 @@ static irqreturn_t vp_interrupt(int irq, void *opaque) struct virtio_pci_device *vp_dev = opaque; u8 isr; - if (!READ_ONCE(vp_dev->intx_soft_enabled)) - return IRQ_NONE; - /* reading the ISR has the effect of also clearing it so it's very * important to save off the value. */ isr = ioread8(vp_dev->isr); @@ -173,8 +141,7 @@ static int vp_request_msix_vectors(struct virtio_device *vdev, int nvectors, snprintf(vp_dev->msix_names[v], sizeof *vp_dev->msix_names, "%s-config", name); err = request_irq(pci_irq_vector(vp_dev->pci_dev, v), - vp_config_changed, IRQF_NO_AUTOEN, - vp_dev->msix_names[v], + vp_config_changed, 0, vp_dev->msix_names[v], vp_dev); if (err) goto error; @@ -193,8 +160,7 @@ static int vp_request_msix_vectors(struct virtio_device *vdev, int nvectors, snprintf(vp_dev->msix_names[v], sizeof *vp_dev->msix_names, "%s-virtqueues", name); err = request_irq(pci_irq_vector(vp_dev->pci_dev, v), - vp_vring_interrupt, IRQF_NO_AUTOEN, - vp_dev->msix_names[v], + vp_vring_interrupt, 0, vp_dev->msix_names[v], vp_dev); if (err) goto error; @@ -371,7 +337,7 @@ static int vp_find_vqs_msix(struct virtio_device *vdev, unsigned nvqs, "%s-%s", dev_name(&vp_dev->vdev.dev), names[i]); err = request_irq(pci_irq_vector(vp_dev->pci_dev, msix_vec), - vring_interrupt, IRQF_NO_AUTOEN, + vring_interrupt, 0, vp_dev->msix_names[msix_vec], vqs[i]); if (err) diff --git a/drivers/virtio/virtio_pci_common.h b/drivers/virtio/virtio_pci_common.h index 23f6c5c678d5..eb17a29fc7ef 100644 --- a/drivers/virtio/virtio_pci_common.h +++ b/drivers/virtio/virtio_pci_common.h @@ -63,7 +63,6 @@ struct virtio_pci_device { /* MSI-X support */ int msix_enabled; int intx_enabled; - bool intx_soft_enabled; cpumask_var_t *msix_affinity_masks; /* Name strings for interrupts. This size should be enough, * and I'm too lazy to allocate each name separately. */ @@ -102,10 +101,8 @@ static struct virtio_pci_device *to_vp_device(struct virtio_device *vdev) return container_of(vdev, struct virtio_pci_device, vdev); } -/* disable irq handlers */ -void vp_disable_cbs(struct virtio_device *vdev); -/* enable irq handlers */ -void vp_enable_cbs(struct virtio_device *vdev); +/* wait for pending irq handlers */ +void vp_synchronize_vectors(struct virtio_device *vdev); /* the notify function used when creating a virt queue */ bool vp_notify(struct virtqueue *vq); /* the config->del_vqs() implementation */ diff --git a/drivers/virtio/virtio_pci_legacy.c b/drivers/virtio/virtio_pci_legacy.c index 34141b9abe27..6f4e34ce96b8 100644 --- a/drivers/virtio/virtio_pci_legacy.c +++ b/drivers/virtio/virtio_pci_legacy.c @@ -98,8 +98,8 @@ static void vp_reset(struct virtio_device *vdev) /* Flush out the status write, and flush in device writes, * including MSi-X interrupts, if any. */ vp_legacy_get_status(&vp_dev->ldev); - /* Disable VQ/configuration callbacks. */ - vp_disable_cbs(vdev); + /* Flush pending VQ/configuration callbacks. */ + vp_synchronize_vectors(vdev); } static u16 vp_config_vector(struct virtio_pci_device *vp_dev, u16 vector) @@ -185,7 +185,6 @@ static void del_vq(struct virtio_pci_vq_info *info) } static const struct virtio_config_ops virtio_pci_config_ops = { - .enable_cbs = vp_enable_cbs, .get = vp_get, .set = vp_set, .get_status = vp_get_status, diff --git a/drivers/virtio/virtio_pci_modern.c b/drivers/virtio/virtio_pci_modern.c index 5455bc041fb6..a2671a20ef77 100644 --- a/drivers/virtio/virtio_pci_modern.c +++ b/drivers/virtio/virtio_pci_modern.c @@ -172,8 +172,8 @@ static void vp_reset(struct virtio_device *vdev) */ while (vp_modern_get_status(mdev)) msleep(1); - /* Disable VQ/configuration callbacks. */ - vp_disable_cbs(vdev); + /* Flush pending VQ/configuration callbacks. */ + vp_synchronize_vectors(vdev); } static u16 vp_config_vector(struct virtio_pci_device *vp_dev, u16 vector) @@ -293,7 +293,7 @@ static int virtio_pci_find_shm_cap(struct pci_dev *dev, u8 required_id, for (pos = pci_find_capability(dev, PCI_CAP_ID_VNDR); pos > 0; pos = pci_find_next_capability(dev, pos, PCI_CAP_ID_VNDR)) { - u8 type, cap_len, id; + u8 type, cap_len, id, res_bar; u32 tmp32; u64 res_offset, res_length; @@ -315,9 +315,14 @@ static int virtio_pci_find_shm_cap(struct pci_dev *dev, u8 required_id, if (id != required_id) continue; - /* Type, and ID match, looks good */ pci_read_config_byte(dev, pos + offsetof(struct virtio_pci_cap, - bar), bar); + bar), &res_bar); + if (res_bar >= PCI_STD_NUM_BARS) + continue; + + /* Type and ID match, and the BAR value isn't reserved. + * Looks good. + */ /* Read the lower 32bit of length and offset */ pci_read_config_dword(dev, pos + offsetof(struct virtio_pci_cap, @@ -337,6 +342,7 @@ static int virtio_pci_find_shm_cap(struct pci_dev *dev, u8 required_id, length_hi), &tmp32); res_length |= ((u64)tmp32) << 32; + *bar = res_bar; *offset = res_offset; *len = res_length; @@ -380,7 +386,6 @@ static bool vp_get_shm_region(struct virtio_device *vdev, } static const struct virtio_config_ops virtio_pci_config_nodev_ops = { - .enable_cbs = vp_enable_cbs, .get = NULL, .set = NULL, .generation = vp_generation, @@ -398,7 +403,6 @@ static const struct virtio_config_ops virtio_pci_config_nodev_ops = { }; static const struct virtio_config_ops virtio_pci_config_ops = { - .enable_cbs = vp_enable_cbs, .get = vp_get, .set = vp_set, .generation = vp_generation, diff --git a/drivers/virtio/virtio_pci_modern_dev.c b/drivers/virtio/virtio_pci_modern_dev.c index e8b3ff2b9fbc..591738ad3d56 100644 --- a/drivers/virtio/virtio_pci_modern_dev.c +++ b/drivers/virtio/virtio_pci_modern_dev.c @@ -35,6 +35,13 @@ vp_modern_map_capability(struct virtio_pci_modern_device *mdev, int off, pci_read_config_dword(dev, off + offsetof(struct virtio_pci_cap, length), &length); + /* Check if the BAR may have changed since we requested the region. */ + if (bar >= PCI_STD_NUM_BARS || !(mdev->modern_bars & (1 << bar))) { + dev_err(&dev->dev, + "virtio_pci: bar unexpectedly changed to %u\n", bar); + return NULL; + } + if (length <= start) { dev_err(&dev->dev, "virtio_pci: bad capability len %u (>%u expected)\n", @@ -120,7 +127,7 @@ static inline int virtio_pci_find_capability(struct pci_dev *dev, u8 cfg_type, &bar); /* Ignore structures with reserved BAR values */ - if (bar > 0x5) + if (bar >= PCI_STD_NUM_BARS) continue; if (type == cfg_type) { diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c index 962f1477b1fa..cfb028ca238e 100644 --- a/drivers/virtio/virtio_ring.c +++ b/drivers/virtio/virtio_ring.c @@ -379,19 +379,11 @@ static void vring_unmap_one_split_indirect(const struct vring_virtqueue *vq, flags = virtio16_to_cpu(vq->vq.vdev, desc->flags); - if (flags & VRING_DESC_F_INDIRECT) { - dma_unmap_single(vring_dma_dev(vq), - virtio64_to_cpu(vq->vq.vdev, desc->addr), - virtio32_to_cpu(vq->vq.vdev, desc->len), - (flags & VRING_DESC_F_WRITE) ? - DMA_FROM_DEVICE : DMA_TO_DEVICE); - } else { - dma_unmap_page(vring_dma_dev(vq), - virtio64_to_cpu(vq->vq.vdev, desc->addr), - virtio32_to_cpu(vq->vq.vdev, desc->len), - (flags & VRING_DESC_F_WRITE) ? - DMA_FROM_DEVICE : DMA_TO_DEVICE); - } + dma_unmap_page(vring_dma_dev(vq), + virtio64_to_cpu(vq->vq.vdev, desc->addr), + virtio32_to_cpu(vq->vq.vdev, desc->len), + (flags & VRING_DESC_F_WRITE) ? + DMA_FROM_DEVICE : DMA_TO_DEVICE); } static unsigned int vring_unmap_one_split(const struct vring_virtqueue *vq, @@ -984,24 +976,24 @@ static struct virtqueue *vring_create_virtqueue_split( * Packed ring specific functions - *_packed(). */ -static void vring_unmap_state_packed(const struct vring_virtqueue *vq, - struct vring_desc_extra *state) +static void vring_unmap_extra_packed(const struct vring_virtqueue *vq, + struct vring_desc_extra *extra) { u16 flags; if (!vq->use_dma_api) return; - flags = state->flags; + flags = extra->flags; if (flags & VRING_DESC_F_INDIRECT) { dma_unmap_single(vring_dma_dev(vq), - state->addr, state->len, + extra->addr, extra->len, (flags & VRING_DESC_F_WRITE) ? DMA_FROM_DEVICE : DMA_TO_DEVICE); } else { dma_unmap_page(vring_dma_dev(vq), - state->addr, state->len, + extra->addr, extra->len, (flags & VRING_DESC_F_WRITE) ? DMA_FROM_DEVICE : DMA_TO_DEVICE); } @@ -1017,19 +1009,11 @@ static void vring_unmap_desc_packed(const struct vring_virtqueue *vq, flags = le16_to_cpu(desc->flags); - if (flags & VRING_DESC_F_INDIRECT) { - dma_unmap_single(vring_dma_dev(vq), - le64_to_cpu(desc->addr), - le32_to_cpu(desc->len), - (flags & VRING_DESC_F_WRITE) ? - DMA_FROM_DEVICE : DMA_TO_DEVICE); - } else { - dma_unmap_page(vring_dma_dev(vq), - le64_to_cpu(desc->addr), - le32_to_cpu(desc->len), - (flags & VRING_DESC_F_WRITE) ? - DMA_FROM_DEVICE : DMA_TO_DEVICE); - } + dma_unmap_page(vring_dma_dev(vq), + le64_to_cpu(desc->addr), + le32_to_cpu(desc->len), + (flags & VRING_DESC_F_WRITE) ? + DMA_FROM_DEVICE : DMA_TO_DEVICE); } static struct vring_packed_desc *alloc_indirect_packed(unsigned int total_sg, @@ -1303,8 +1287,7 @@ unmap_release: for (n = 0; n < total_sg; n++) { if (i == err_idx) break; - vring_unmap_state_packed(vq, - &vq->packed.desc_extra[curr]); + vring_unmap_extra_packed(vq, &vq->packed.desc_extra[curr]); curr = vq->packed.desc_extra[curr].next; i++; if (i >= vq->packed.vring.num) @@ -1383,8 +1366,8 @@ static void detach_buf_packed(struct vring_virtqueue *vq, if (unlikely(vq->use_dma_api)) { curr = id; for (i = 0; i < state->num; i++) { - vring_unmap_state_packed(vq, - &vq->packed.desc_extra[curr]); + vring_unmap_extra_packed(vq, + &vq->packed.desc_extra[curr]); curr = vq->packed.desc_extra[curr].next; } } |