diff options
Diffstat (limited to 'drivers')
-rw-r--r-- | drivers/infiniband/core/cm.c | 91 | ||||
-rw-r--r-- | drivers/infiniband/core/ucm.c | 19 |
2 files changed, 97 insertions, 13 deletions
diff --git a/drivers/infiniband/core/cm.c b/drivers/infiniband/core/cm.c index 86fee43502cd..490fd03766db 100644 --- a/drivers/infiniband/core/cm.c +++ b/drivers/infiniband/core/cm.c @@ -32,7 +32,7 @@ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * - * $Id: cm.c 2821 2005-07-08 17:07:28Z sean.hefty $ + * $Id: cm.c 4311 2005-12-05 18:42:01Z sean.hefty $ */ #include <linux/completion.h> @@ -132,6 +132,7 @@ struct cm_id_private { /* todo: use alternate port on send failure */ struct cm_av av; struct cm_av alt_av; + struct ib_cm_compare_data *compare_data; void *private_data; __be64 tid; @@ -357,6 +358,41 @@ static struct cm_id_private * cm_acquire_id(__be32 local_id, __be32 remote_id) return cm_id_priv; } +static void cm_mask_copy(u8 *dst, u8 *src, u8 *mask) +{ + int i; + + for (i = 0; i < IB_CM_COMPARE_SIZE / sizeof(unsigned long); i++) + ((unsigned long *) dst)[i] = ((unsigned long *) src)[i] & + ((unsigned long *) mask)[i]; +} + +static int cm_compare_data(struct ib_cm_compare_data *src_data, + struct ib_cm_compare_data *dst_data) +{ + u8 src[IB_CM_COMPARE_SIZE]; + u8 dst[IB_CM_COMPARE_SIZE]; + + if (!src_data || !dst_data) + return 0; + + cm_mask_copy(src, src_data->data, dst_data->mask); + cm_mask_copy(dst, dst_data->data, src_data->mask); + return memcmp(src, dst, IB_CM_COMPARE_SIZE); +} + +static int cm_compare_private_data(u8 *private_data, + struct ib_cm_compare_data *dst_data) +{ + u8 src[IB_CM_COMPARE_SIZE]; + + if (!dst_data) + return 0; + + cm_mask_copy(src, private_data, dst_data->mask); + return memcmp(src, dst_data->data, IB_CM_COMPARE_SIZE); +} + static struct cm_id_private * cm_insert_listen(struct cm_id_private *cm_id_priv) { struct rb_node **link = &cm.listen_service_table.rb_node; @@ -364,14 +400,18 @@ static struct cm_id_private * cm_insert_listen(struct cm_id_private *cm_id_priv) struct cm_id_private *cur_cm_id_priv; __be64 service_id = cm_id_priv->id.service_id; __be64 service_mask = cm_id_priv->id.service_mask; + int data_cmp; while (*link) { parent = *link; cur_cm_id_priv = rb_entry(parent, struct cm_id_private, service_node); + data_cmp = cm_compare_data(cm_id_priv->compare_data, + cur_cm_id_priv->compare_data); if ((cur_cm_id_priv->id.service_mask & service_id) == (service_mask & cur_cm_id_priv->id.service_id) && - (cm_id_priv->id.device == cur_cm_id_priv->id.device)) + (cm_id_priv->id.device == cur_cm_id_priv->id.device) && + !data_cmp) return cur_cm_id_priv; if (cm_id_priv->id.device < cur_cm_id_priv->id.device) @@ -380,6 +420,10 @@ static struct cm_id_private * cm_insert_listen(struct cm_id_private *cm_id_priv) link = &(*link)->rb_right; else if (service_id < cur_cm_id_priv->id.service_id) link = &(*link)->rb_left; + else if (service_id > cur_cm_id_priv->id.service_id) + link = &(*link)->rb_right; + else if (data_cmp < 0) + link = &(*link)->rb_left; else link = &(*link)->rb_right; } @@ -389,16 +433,20 @@ static struct cm_id_private * cm_insert_listen(struct cm_id_private *cm_id_priv) } static struct cm_id_private * cm_find_listen(struct ib_device *device, - __be64 service_id) + __be64 service_id, + u8 *private_data) { struct rb_node *node = cm.listen_service_table.rb_node; struct cm_id_private *cm_id_priv; + int data_cmp; while (node) { cm_id_priv = rb_entry(node, struct cm_id_private, service_node); + data_cmp = cm_compare_private_data(private_data, + cm_id_priv->compare_data); if ((cm_id_priv->id.service_mask & service_id) == cm_id_priv->id.service_id && - (cm_id_priv->id.device == device)) + (cm_id_priv->id.device == device) && !data_cmp) return cm_id_priv; if (device < cm_id_priv->id.device) @@ -407,6 +455,10 @@ static struct cm_id_private * cm_find_listen(struct ib_device *device, node = node->rb_right; else if (service_id < cm_id_priv->id.service_id) node = node->rb_left; + else if (service_id > cm_id_priv->id.service_id) + node = node->rb_right; + else if (data_cmp < 0) + node = node->rb_left; else node = node->rb_right; } @@ -730,15 +782,14 @@ retest: wait_for_completion(&cm_id_priv->comp); while ((work = cm_dequeue_work(cm_id_priv)) != NULL) cm_free_work(work); - if (cm_id_priv->private_data && cm_id_priv->private_data_len) - kfree(cm_id_priv->private_data); + kfree(cm_id_priv->compare_data); + kfree(cm_id_priv->private_data); kfree(cm_id_priv); } EXPORT_SYMBOL(ib_destroy_cm_id); -int ib_cm_listen(struct ib_cm_id *cm_id, - __be64 service_id, - __be64 service_mask) +int ib_cm_listen(struct ib_cm_id *cm_id, __be64 service_id, __be64 service_mask, + struct ib_cm_compare_data *compare_data) { struct cm_id_private *cm_id_priv, *cur_cm_id_priv; unsigned long flags; @@ -752,7 +803,19 @@ int ib_cm_listen(struct ib_cm_id *cm_id, return -EINVAL; cm_id_priv = container_of(cm_id, struct cm_id_private, id); - BUG_ON(cm_id->state != IB_CM_IDLE); + if (cm_id->state != IB_CM_IDLE) + return -EINVAL; + + if (compare_data) { + cm_id_priv->compare_data = kzalloc(sizeof *compare_data, + GFP_KERNEL); + if (!cm_id_priv->compare_data) + return -ENOMEM; + cm_mask_copy(cm_id_priv->compare_data->data, + compare_data->data, compare_data->mask); + memcpy(cm_id_priv->compare_data->mask, compare_data->mask, + IB_CM_COMPARE_SIZE); + } cm_id->state = IB_CM_LISTEN; @@ -769,6 +832,8 @@ int ib_cm_listen(struct ib_cm_id *cm_id, if (cur_cm_id_priv) { cm_id->state = IB_CM_IDLE; + kfree(cm_id_priv->compare_data); + cm_id_priv->compare_data = NULL; ret = -EBUSY; } return ret; @@ -1241,7 +1306,8 @@ static struct cm_id_private * cm_match_req(struct cm_work *work, /* Find matching listen request. */ listen_cm_id_priv = cm_find_listen(cm_id_priv->id.device, - req_msg->service_id); + req_msg->service_id, + req_msg->private_data); if (!listen_cm_id_priv) { spin_unlock_irqrestore(&cm.lock, flags); cm_issue_rej(work->port, work->mad_recv_wc, @@ -2654,7 +2720,8 @@ static int cm_sidr_req_handler(struct cm_work *work) goto out; /* Duplicate message. */ } cur_cm_id_priv = cm_find_listen(cm_id->device, - sidr_req_msg->service_id); + sidr_req_msg->service_id, + sidr_req_msg->private_data); if (!cur_cm_id_priv) { rb_erase(&cm_id_priv->sidr_id_node, &cm.remote_sidr_table); spin_unlock_irqrestore(&cm.lock, flags); diff --git a/drivers/infiniband/core/ucm.c b/drivers/infiniband/core/ucm.c index b396bf703f80..0136aee0faa7 100644 --- a/drivers/infiniband/core/ucm.c +++ b/drivers/infiniband/core/ucm.c @@ -648,6 +648,17 @@ out: return result; } +static int ucm_validate_listen(__be64 service_id, __be64 service_mask) +{ + service_id &= service_mask; + + if (((service_id & IB_CMA_SERVICE_ID_MASK) == IB_CMA_SERVICE_ID) || + ((service_id & IB_SDP_SERVICE_ID_MASK) == IB_SDP_SERVICE_ID)) + return -EINVAL; + + return 0; +} + static ssize_t ib_ucm_listen(struct ib_ucm_file *file, const char __user *inbuf, int in_len, int out_len) @@ -663,7 +674,13 @@ static ssize_t ib_ucm_listen(struct ib_ucm_file *file, if (IS_ERR(ctx)) return PTR_ERR(ctx); - result = ib_cm_listen(ctx->cm_id, cmd.service_id, cmd.service_mask); + result = ucm_validate_listen(cmd.service_id, cmd.service_mask); + if (result) + goto out; + + result = ib_cm_listen(ctx->cm_id, cmd.service_id, cmd.service_mask, + NULL); +out: ib_ucm_ctx_put(ctx); return result; } |