summaryrefslogtreecommitdiff
path: root/drivers/infiniband/sw/rdmavt/qp.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/infiniband/sw/rdmavt/qp.c')
-rw-r--r--drivers/infiniband/sw/rdmavt/qp.c152
1 files changed, 111 insertions, 41 deletions
diff --git a/drivers/infiniband/sw/rdmavt/qp.c b/drivers/infiniband/sw/rdmavt/qp.c
index 0d804a58f954..1384060f175d 100644
--- a/drivers/infiniband/sw/rdmavt/qp.c
+++ b/drivers/infiniband/sw/rdmavt/qp.c
@@ -803,6 +803,46 @@ static void rvt_remove_qp(struct rvt_dev_info *rdi, struct rvt_qp *qp)
}
/**
+ * rvt_alloc_rq - allocate memory for user or kernel buffer
+ * @rq: receive queue data structure
+ * @size: number of request queue entries
+ * @node: The NUMA node
+ * @udata: True if user data is available or not false
+ *
+ * Return: If memory allocation failed, return -ENONEM
+ * This function is used by both shared receive
+ * queues and non-shared receive queues to allocate
+ * memory.
+ */
+int rvt_alloc_rq(struct rvt_rq *rq, u32 size, int node,
+ struct ib_udata *udata)
+{
+ if (udata) {
+ rq->wq = vmalloc_user(sizeof(struct rvt_rwq) + size);
+ if (!rq->wq)
+ goto bail;
+ /* need kwq with no buffers */
+ rq->kwq = kzalloc_node(sizeof(*rq->kwq), GFP_KERNEL, node);
+ if (!rq->kwq)
+ goto bail;
+ rq->kwq->curr_wq = rq->wq->wq;
+ } else {
+ /* need kwq with buffers */
+ rq->kwq =
+ vzalloc_node(sizeof(struct rvt_krwq) + size, node);
+ if (!rq->kwq)
+ goto bail;
+ rq->kwq->curr_wq = rq->kwq->wq;
+ }
+
+ spin_lock_init(&rq->lock);
+ return 0;
+bail:
+ rvt_free_rq(rq);
+ return -ENOMEM;
+}
+
+/**
* rvt_init_qp - initialize the QP state to the reset state
* @qp: the QP to init or reinit
* @type: the QP type
@@ -852,10 +892,6 @@ static void rvt_init_qp(struct rvt_dev_info *rdi, struct rvt_qp *qp,
qp->s_tail_ack_queue = 0;
qp->s_acked_ack_queue = 0;
qp->s_num_rd_atomic = 0;
- if (qp->r_rq.wq) {
- qp->r_rq.wq->head = 0;
- qp->r_rq.wq->tail = 0;
- }
qp->r_sge.num_sge = 0;
atomic_set(&qp->s_reserved_used, 0);
}
@@ -1046,17 +1082,12 @@ struct ib_qp *rvt_create_qp(struct ib_pd *ibpd,
qp->r_rq.max_sge = init_attr->cap.max_recv_sge;
sz = (sizeof(struct ib_sge) * qp->r_rq.max_sge) +
sizeof(struct rvt_rwqe);
- if (udata)
- qp->r_rq.wq = vmalloc_user(
- sizeof(struct rvt_rwq) +
- qp->r_rq.size * sz);
- else
- qp->r_rq.wq = vzalloc_node(
- sizeof(struct rvt_rwq) +
- qp->r_rq.size * sz,
- rdi->dparms.node);
- if (!qp->r_rq.wq)
+ err = rvt_alloc_rq(&qp->r_rq, qp->r_rq.size * sz,
+ rdi->dparms.node, udata);
+ if (err) {
+ ret = ERR_PTR(err);
goto bail_driver_priv;
+ }
}
/*
@@ -1202,8 +1233,7 @@ bail_qpn:
rvt_free_qpn(&rdi->qp_dev->qpn_table, qp->ibqp.qp_num);
bail_rq_wq:
- if (!qp->ip)
- vfree(qp->r_rq.wq);
+ rvt_free_rq(&qp->r_rq);
bail_driver_priv:
rdi->driver_f.qp_priv_free(rdi, qp);
@@ -1269,19 +1299,26 @@ int rvt_error_qp(struct rvt_qp *qp, enum ib_wc_status err)
}
wc.status = IB_WC_WR_FLUSH_ERR;
- if (qp->r_rq.wq) {
- struct rvt_rwq *wq;
+ if (qp->r_rq.kwq) {
u32 head;
u32 tail;
+ struct rvt_rwq *wq = NULL;
+ struct rvt_krwq *kwq = NULL;
spin_lock(&qp->r_rq.lock);
-
+ /* qp->ip used to validate if there is a user buffer mmaped */
+ if (qp->ip) {
+ wq = qp->r_rq.wq;
+ head = RDMA_READ_UAPI_ATOMIC(wq->head);
+ tail = RDMA_READ_UAPI_ATOMIC(wq->tail);
+ } else {
+ kwq = qp->r_rq.kwq;
+ head = kwq->head;
+ tail = kwq->tail;
+ }
/* sanity check pointers before trusting them */
- wq = qp->r_rq.wq;
- head = wq->head;
if (head >= qp->r_rq.size)
head = 0;
- tail = wq->tail;
if (tail >= qp->r_rq.size)
tail = 0;
while (tail != head) {
@@ -1290,8 +1327,10 @@ int rvt_error_qp(struct rvt_qp *qp, enum ib_wc_status err)
tail = 0;
rvt_cq_enter(ibcq_to_rvtcq(qp->ibqp.recv_cq), &wc, 1);
}
- wq->tail = tail;
-
+ if (qp->ip)
+ RDMA_WRITE_UAPI_ATOMIC(wq->tail, tail);
+ else
+ kwq->tail = tail;
spin_unlock(&qp->r_rq.lock);
} else if (qp->ibqp.event_handler) {
ret = 1;
@@ -1634,8 +1673,7 @@ int rvt_destroy_qp(struct ib_qp *ibqp, struct ib_udata *udata)
if (qp->ip)
kref_put(&qp->ip->ref, rvt_release_mmap_info);
- else
- vfree(qp->r_rq.wq);
+ kvfree(qp->r_rq.kwq);
rdi->driver_f.qp_priv_free(rdi, qp);
kfree(qp->s_ack_queue);
rdma_destroy_ah_attr(&qp->remote_ah_attr);
@@ -1721,7 +1759,7 @@ int rvt_post_recv(struct ib_qp *ibqp, const struct ib_recv_wr *wr,
const struct ib_recv_wr **bad_wr)
{
struct rvt_qp *qp = ibqp_to_rvtqp(ibqp);
- struct rvt_rwq *wq = qp->r_rq.wq;
+ struct rvt_krwq *wq = qp->r_rq.kwq;
unsigned long flags;
int qp_err_flush = (ib_rvt_state_ops[qp->state] & RVT_FLUSH_RECV) &&
!qp->ibqp.srq;
@@ -1746,7 +1784,7 @@ int rvt_post_recv(struct ib_qp *ibqp, const struct ib_recv_wr *wr,
next = wq->head + 1;
if (next >= qp->r_rq.size)
next = 0;
- if (next == wq->tail) {
+ if (next == READ_ONCE(wq->tail)) {
spin_unlock_irqrestore(&qp->r_rq.lock, flags);
*bad_wr = wr;
return -ENOMEM;
@@ -1770,8 +1808,7 @@ int rvt_post_recv(struct ib_qp *ibqp, const struct ib_recv_wr *wr,
* Make sure queue entry is written
* before the head index.
*/
- smp_wmb();
- wq->head = next;
+ smp_store_release(&wq->head, next);
}
spin_unlock_irqrestore(&qp->r_rq.lock, flags);
}
@@ -2141,7 +2178,7 @@ int rvt_post_srq_recv(struct ib_srq *ibsrq, const struct ib_recv_wr *wr,
const struct ib_recv_wr **bad_wr)
{
struct rvt_srq *srq = ibsrq_to_rvtsrq(ibsrq);
- struct rvt_rwq *wq;
+ struct rvt_krwq *wq;
unsigned long flags;
for (; wr; wr = wr->next) {
@@ -2155,11 +2192,11 @@ int rvt_post_srq_recv(struct ib_srq *ibsrq, const struct ib_recv_wr *wr,
}
spin_lock_irqsave(&srq->rq.lock, flags);
- wq = srq->rq.wq;
+ wq = srq->rq.kwq;
next = wq->head + 1;
if (next >= srq->rq.size)
next = 0;
- if (next == wq->tail) {
+ if (next == READ_ONCE(wq->tail)) {
spin_unlock_irqrestore(&srq->rq.lock, flags);
*bad_wr = wr;
return -ENOMEM;
@@ -2171,8 +2208,7 @@ int rvt_post_srq_recv(struct ib_srq *ibsrq, const struct ib_recv_wr *wr,
for (i = 0; i < wr->num_sge; i++)
wqe->sg_list[i] = wr->sg_list[i];
/* Make sure queue entry is written before the head index. */
- smp_wmb();
- wq->head = next;
+ smp_store_release(&wq->head, next);
spin_unlock_irqrestore(&srq->rq.lock, flags);
}
return 0;
@@ -2230,6 +2266,25 @@ bad_lkey:
}
/**
+ * get_rvt_head - get head indices of the circular buffer
+ * @rq: data structure for request queue entry
+ * @ip: the QP
+ *
+ * Return - head index value
+ */
+static inline u32 get_rvt_head(struct rvt_rq *rq, void *ip)
+{
+ u32 head;
+
+ if (ip)
+ head = RDMA_READ_UAPI_ATOMIC(rq->wq->head);
+ else
+ head = rq->kwq->head;
+
+ return head;
+}
+
+/**
* rvt_get_rwqe - copy the next RWQE into the QP's RWQE
* @qp: the QP
* @wr_id_only: update qp->r_wr_id only, not qp->r_sge
@@ -2243,21 +2298,26 @@ int rvt_get_rwqe(struct rvt_qp *qp, bool wr_id_only)
{
unsigned long flags;
struct rvt_rq *rq;
+ struct rvt_krwq *kwq;
struct rvt_rwq *wq;
struct rvt_srq *srq;
struct rvt_rwqe *wqe;
void (*handler)(struct ib_event *, void *);
u32 tail;
+ u32 head;
int ret;
+ void *ip = NULL;
if (qp->ibqp.srq) {
srq = ibsrq_to_rvtsrq(qp->ibqp.srq);
handler = srq->ibsrq.event_handler;
rq = &srq->rq;
+ ip = srq->ip;
} else {
srq = NULL;
handler = NULL;
rq = &qp->r_rq;
+ ip = qp->ip;
}
spin_lock_irqsave(&rq->lock, flags);
@@ -2265,17 +2325,24 @@ int rvt_get_rwqe(struct rvt_qp *qp, bool wr_id_only)
ret = 0;
goto unlock;
}
+ if (ip) {
+ wq = rq->wq;
+ tail = RDMA_READ_UAPI_ATOMIC(wq->tail);
+ } else {
+ kwq = rq->kwq;
+ tail = kwq->tail;
+ }
- wq = rq->wq;
- tail = wq->tail;
/* Validate tail before using it since it is user writable. */
if (tail >= rq->size)
tail = 0;
- if (unlikely(tail == wq->head)) {
+
+ head = get_rvt_head(rq, ip);
+ if (unlikely(tail == head)) {
ret = 0;
goto unlock;
}
- /* Make sure entry is read after head index is read. */
+ /* Make sure entry is read after the count is read. */
smp_rmb();
wqe = rvt_get_rwqe_ptr(rq, tail);
/*
@@ -2285,7 +2352,10 @@ int rvt_get_rwqe(struct rvt_qp *qp, bool wr_id_only)
*/
if (++tail >= rq->size)
tail = 0;
- wq->tail = tail;
+ if (ip)
+ RDMA_WRITE_UAPI_ATOMIC(wq->tail, tail);
+ else
+ kwq->tail = tail;
if (!wr_id_only && !init_sge(qp, wqe)) {
ret = -1;
goto unlock;
@@ -2301,7 +2371,7 @@ int rvt_get_rwqe(struct rvt_qp *qp, bool wr_id_only)
* Validate head pointer value and compute
* the number of remaining WQEs.
*/
- n = wq->head;
+ n = get_rvt_head(rq, ip);
if (n >= rq->size)
n = 0;
if (n < tail)