diff options
| -rw-r--r-- | drivers/infiniband/hw/hns/hns_roce_srq.c | 35 |
1 files changed, 16 insertions, 19 deletions
diff --git a/drivers/infiniband/hw/hns/hns_roce_srq.c b/drivers/infiniband/hw/hns/hns_roce_srq.c index 601f8cdfce96..cb848e8e6bbd 100644 --- a/drivers/infiniband/hw/hns/hns_roce_srq.c +++ b/drivers/infiniband/hw/hns/hns_roce_srq.c @@ -340,22 +340,16 @@ static int set_srq_param(struct hns_roce_srq *srq, } static int alloc_srq_buf(struct hns_roce_dev *hr_dev, struct hns_roce_srq *srq, - struct ib_udata *udata) + struct ib_udata *udata, + struct hns_roce_ib_create_srq *ucmd) { - struct hns_roce_ib_create_srq ucmd = {}; int ret; - if (udata) { - ret = ib_copy_validate_udata_in(udata, ucmd, que_addr); - if (ret) - return ret; - } - - ret = alloc_srq_idx(hr_dev, srq, udata, ucmd.que_addr); + ret = alloc_srq_idx(hr_dev, srq, udata, ucmd->que_addr); if (ret) return ret; - ret = alloc_srq_wqe_buf(hr_dev, srq, udata, ucmd.buf_addr); + ret = alloc_srq_wqe_buf(hr_dev, srq, udata, ucmd->buf_addr); if (ret) goto err_idx; @@ -404,22 +398,18 @@ static void free_srq_db(struct hns_roce_dev *hr_dev, struct hns_roce_srq *srq, static int alloc_srq_db(struct hns_roce_dev *hr_dev, struct hns_roce_srq *srq, struct ib_udata *udata, + struct hns_roce_ib_create_srq *ucmd, struct hns_roce_ib_create_srq_resp *resp) { - struct hns_roce_ib_create_srq ucmd; struct hns_roce_ucontext *uctx; int ret; if (udata) { - ret = ib_copy_validate_udata_in(udata, ucmd, que_addr); - if (ret) - return ret; - if ((hr_dev->caps.flags & HNS_ROCE_CAP_FLAG_SRQ_RECORD_DB) && - (ucmd.req_cap_flags & HNS_ROCE_SRQ_CAP_RECORD_DB)) { + (ucmd->req_cap_flags & HNS_ROCE_SRQ_CAP_RECORD_DB)) { uctx = rdma_udata_to_drv_context(udata, struct hns_roce_ucontext, ibucontext); - ret = hns_roce_db_map_user(uctx, ucmd.db_addr, + ret = hns_roce_db_map_user(uctx, ucmd->db_addr, &srq->rdb); if (ret) return ret; @@ -448,6 +438,7 @@ int hns_roce_create_srq(struct ib_srq *ib_srq, struct hns_roce_dev *hr_dev = to_hr_dev(ib_srq->device); struct hns_roce_ib_create_srq_resp resp = {}; struct hns_roce_srq *srq = to_hr_srq(ib_srq); + struct hns_roce_ib_create_srq ucmd = {}; int ret; mutex_init(&srq->mutex); @@ -457,11 +448,17 @@ int hns_roce_create_srq(struct ib_srq *ib_srq, if (ret) goto err_out; - ret = alloc_srq_buf(hr_dev, srq, udata); + if (udata) { + ret = ib_copy_validate_udata_in(udata, ucmd, que_addr); + if (ret) + goto err_out; + } + + ret = alloc_srq_buf(hr_dev, srq, udata, &ucmd); if (ret) goto err_out; - ret = alloc_srq_db(hr_dev, srq, udata, &resp); + ret = alloc_srq_db(hr_dev, srq, udata, &ucmd, &resp); if (ret) goto err_srq_buf; |
