diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Kconfig.kasan | 22 | ||||
-rw-r--r-- | lib/assoc_array.c | 8 | ||||
-rw-r--r-- | lib/crc32.c | 4 | ||||
-rw-r--r-- | lib/int_sqrt.c | 2 | ||||
-rw-r--r-- | lib/objagg.c | 583 | ||||
-rw-r--r-- | lib/rhashtable.c | 2 | ||||
-rw-r--r-- | lib/sbitmap.c | 13 | ||||
-rw-r--r-- | lib/test_bpf.c | 2 | ||||
-rw-r--r-- | lib/test_kmod.c | 2 | ||||
-rw-r--r-- | lib/test_objagg.c | 199 | ||||
-rw-r--r-- | lib/test_rhashtable.c | 36 | ||||
-rw-r--r-- | lib/test_xarray.c | 57 | ||||
-rw-r--r-- | lib/xarray.c | 92 |
13 files changed, 916 insertions, 106 deletions
diff --git a/lib/Kconfig.kasan b/lib/Kconfig.kasan index d8c474b6691e..9737059ec58b 100644 --- a/lib/Kconfig.kasan +++ b/lib/Kconfig.kasan @@ -113,6 +113,28 @@ config KASAN_INLINE endchoice +config KASAN_STACK_ENABLE + bool "Enable stack instrumentation (unsafe)" if CC_IS_CLANG && !COMPILE_TEST + default !(CLANG_VERSION < 90000) + depends on KASAN + help + The LLVM stack address sanitizer has a know problem that + causes excessive stack usage in a lot of functions, see + https://bugs.llvm.org/show_bug.cgi?id=38809 + Disabling asan-stack makes it safe to run kernels build + with clang-8 with KASAN enabled, though it loses some of + the functionality. + This feature is always disabled when compile-testing with clang-8 + or earlier to avoid cluttering the output in stack overflow + warnings, but clang-8 users can still enable it for builds without + CONFIG_COMPILE_TEST. On gcc and later clang versions it is + assumed to always be safe to use and enabled by default. + +config KASAN_STACK + int + default 1 if KASAN_STACK_ENABLE || CC_IS_GCC + default 0 + config KASAN_S390_4_LEVEL_PAGING bool "KASan: use 4-level paging" depends on KASAN && S390 diff --git a/lib/assoc_array.c b/lib/assoc_array.c index c6659cb37033..59875eb278ea 100644 --- a/lib/assoc_array.c +++ b/lib/assoc_array.c @@ -768,9 +768,11 @@ all_leaves_cluster_together: new_s0->index_key[i] = ops->get_key_chunk(index_key, i * ASSOC_ARRAY_KEY_CHUNK_SIZE); - blank = ULONG_MAX << (level & ASSOC_ARRAY_KEY_CHUNK_MASK); - pr_devel("blank off [%zu] %d: %lx\n", keylen - 1, level, blank); - new_s0->index_key[keylen - 1] &= ~blank; + if (level & ASSOC_ARRAY_KEY_CHUNK_MASK) { + blank = ULONG_MAX << (level & ASSOC_ARRAY_KEY_CHUNK_MASK); + pr_devel("blank off [%zu] %d: %lx\n", keylen - 1, level, blank); + new_s0->index_key[keylen - 1] &= ~blank; + } /* This now reduces to a node splitting exercise for which we'll need * to regenerate the disparity table. diff --git a/lib/crc32.c b/lib/crc32.c index 45b1d67a1767..4a20455d1f61 100644 --- a/lib/crc32.c +++ b/lib/crc32.c @@ -206,8 +206,8 @@ u32 __pure __weak __crc32c_le(u32 crc, unsigned char const *p, size_t len) EXPORT_SYMBOL(crc32_le); EXPORT_SYMBOL(__crc32c_le); -u32 crc32_le_base(u32, unsigned char const *, size_t) __alias(crc32_le); -u32 __crc32c_le_base(u32, unsigned char const *, size_t) __alias(__crc32c_le); +u32 __pure crc32_le_base(u32, unsigned char const *, size_t) __alias(crc32_le); +u32 __pure __crc32c_le_base(u32, unsigned char const *, size_t) __alias(__crc32c_le); /* * This multiplies the polynomials x and y modulo the given modulus. diff --git a/lib/int_sqrt.c b/lib/int_sqrt.c index 14436f4ca6bd..30e0f9770f88 100644 --- a/lib/int_sqrt.c +++ b/lib/int_sqrt.c @@ -52,7 +52,7 @@ u32 int_sqrt64(u64 x) if (x <= ULONG_MAX) return int_sqrt((unsigned long) x); - m = 1ULL << (fls64(x) & ~1ULL); + m = 1ULL << ((fls64(x) - 1) & ~1ULL); while (m != 0) { b = y + m; y >>= 1; diff --git a/lib/objagg.c b/lib/objagg.c index c9b457a91153..576be22e86de 100644 --- a/lib/objagg.c +++ b/lib/objagg.c @@ -4,6 +4,7 @@ #include <linux/module.h> #include <linux/slab.h> #include <linux/rhashtable.h> +#include <linux/idr.h> #include <linux/list.h> #include <linux/sort.h> #include <linux/objagg.h> @@ -11,6 +12,34 @@ #define CREATE_TRACE_POINTS #include <trace/events/objagg.h> +struct objagg_hints { + struct rhashtable node_ht; + struct rhashtable_params ht_params; + struct list_head node_list; + unsigned int node_count; + unsigned int root_count; + unsigned int refcount; + const struct objagg_ops *ops; +}; + +struct objagg_hints_node { + struct rhash_head ht_node; /* member of objagg_hints->node_ht */ + struct list_head list; /* member of objagg_hints->node_list */ + struct objagg_hints_node *parent; + unsigned int root_id; + struct objagg_obj_stats_info stats_info; + unsigned long obj[0]; +}; + +static struct objagg_hints_node * +objagg_hints_lookup(struct objagg_hints *objagg_hints, void *obj) +{ + if (!objagg_hints) + return NULL; + return rhashtable_lookup_fast(&objagg_hints->node_ht, obj, + objagg_hints->ht_params); +} + struct objagg { const struct objagg_ops *ops; void *priv; @@ -18,6 +47,8 @@ struct objagg { struct rhashtable_params ht_params; struct list_head obj_list; unsigned int obj_count; + struct ida root_ida; + struct objagg_hints *hints; }; struct objagg_obj { @@ -30,6 +61,7 @@ struct objagg_obj { void *delta_priv; /* user delta private */ void *root_priv; /* user root private */ }; + unsigned int root_id; unsigned int refcount; /* counts number of users of this object * including nested objects */ @@ -130,7 +162,8 @@ static struct objagg_obj *objagg_obj_lookup(struct objagg *objagg, void *obj) static int objagg_obj_parent_assign(struct objagg *objagg, struct objagg_obj *objagg_obj, - struct objagg_obj *parent) + struct objagg_obj *parent, + bool take_parent_ref) { void *delta_priv; @@ -144,7 +177,8 @@ static int objagg_obj_parent_assign(struct objagg *objagg, */ objagg_obj->parent = parent; objagg_obj->delta_priv = delta_priv; - objagg_obj_ref_inc(objagg_obj->parent); + if (take_parent_ref) + objagg_obj_ref_inc(objagg_obj->parent); trace_objagg_obj_parent_assign(objagg, objagg_obj, parent, parent->refcount); @@ -164,7 +198,7 @@ static int objagg_obj_parent_lookup_assign(struct objagg *objagg, if (!objagg_obj_is_root(objagg_obj_cur)) continue; err = objagg_obj_parent_assign(objagg, objagg_obj, - objagg_obj_cur); + objagg_obj_cur, true); if (!err) return 0; } @@ -184,16 +218,68 @@ static void objagg_obj_parent_unassign(struct objagg *objagg, __objagg_obj_put(objagg, objagg_obj->parent); } +static int objagg_obj_root_id_alloc(struct objagg *objagg, + struct objagg_obj *objagg_obj, + struct objagg_hints_node *hnode) +{ + unsigned int min, max; + int root_id; + + /* In case there are no hints available, the root id is invalid. */ + if (!objagg->hints) { + objagg_obj->root_id = OBJAGG_OBJ_ROOT_ID_INVALID; + return 0; + } + + if (hnode) { + min = hnode->root_id; + max = hnode->root_id; + } else { + /* For objects with no hint, start after the last + * hinted root_id. + */ + min = objagg->hints->root_count; + max = ~0; + } + + root_id = ida_alloc_range(&objagg->root_ida, min, max, GFP_KERNEL); + + if (root_id < 0) + return root_id; + objagg_obj->root_id = root_id; + return 0; +} + +static void objagg_obj_root_id_free(struct objagg *objagg, + struct objagg_obj *objagg_obj) +{ + if (!objagg->hints) + return; + ida_free(&objagg->root_ida, objagg_obj->root_id); +} + static int objagg_obj_root_create(struct objagg *objagg, - struct objagg_obj *objagg_obj) + struct objagg_obj *objagg_obj, + struct objagg_hints_node *hnode) { - objagg_obj->root_priv = objagg->ops->root_create(objagg->priv, - objagg_obj->obj); - if (IS_ERR(objagg_obj->root_priv)) - return PTR_ERR(objagg_obj->root_priv); + int err; + err = objagg_obj_root_id_alloc(objagg, objagg_obj, hnode); + if (err) + return err; + objagg_obj->root_priv = objagg->ops->root_create(objagg->priv, + objagg_obj->obj, + objagg_obj->root_id); + if (IS_ERR(objagg_obj->root_priv)) { + err = PTR_ERR(objagg_obj->root_priv); + goto err_root_create; + } trace_objagg_obj_root_create(objagg, objagg_obj); return 0; + +err_root_create: + objagg_obj_root_id_free(objagg, objagg_obj); + return err; } static void objagg_obj_root_destroy(struct objagg *objagg, @@ -201,19 +287,69 @@ static void objagg_obj_root_destroy(struct objagg *objagg, { trace_objagg_obj_root_destroy(objagg, objagg_obj); objagg->ops->root_destroy(objagg->priv, objagg_obj->root_priv); + objagg_obj_root_id_free(objagg, objagg_obj); +} + +static struct objagg_obj *__objagg_obj_get(struct objagg *objagg, void *obj); + +static int objagg_obj_init_with_hints(struct objagg *objagg, + struct objagg_obj *objagg_obj, + bool *hint_found) +{ + struct objagg_hints_node *hnode; + struct objagg_obj *parent; + int err; + + hnode = objagg_hints_lookup(objagg->hints, objagg_obj->obj); + if (!hnode) { + *hint_found = false; + return 0; + } + *hint_found = true; + + if (!hnode->parent) + return objagg_obj_root_create(objagg, objagg_obj, hnode); + + parent = __objagg_obj_get(objagg, hnode->parent->obj); + if (IS_ERR(parent)) + return PTR_ERR(parent); + + err = objagg_obj_parent_assign(objagg, objagg_obj, parent, false); + if (err) { + *hint_found = false; + err = 0; + goto err_parent_assign; + } + + return 0; + +err_parent_assign: + objagg_obj_put(objagg, parent); + return err; } static int objagg_obj_init(struct objagg *objagg, struct objagg_obj *objagg_obj) { + bool hint_found; int err; + /* First, try to use hints if they are available and + * if they provide result. + */ + err = objagg_obj_init_with_hints(objagg, objagg_obj, &hint_found); + if (err) + return err; + + if (hint_found) + return 0; + /* Try to find if the object can be aggregated under an existing one. */ err = objagg_obj_parent_lookup_assign(objagg, objagg_obj); if (!err) return 0; /* If aggregation is not possible, make the object a root. */ - return objagg_obj_root_create(objagg, objagg_obj); + return objagg_obj_root_create(objagg, objagg_obj, NULL); } static void objagg_obj_fini(struct objagg *objagg, @@ -349,8 +485,9 @@ EXPORT_SYMBOL(objagg_obj_put); /** * objagg_create - creates a new objagg instance - * @ops: user-specific callbacks - * @priv: pointer to a private data passed to the ops + * @ops: user-specific callbacks + * @objagg_hints: hints, can be NULL + * @priv: pointer to a private data passed to the ops * * Note: all locking must be provided by the caller. * @@ -374,18 +511,25 @@ EXPORT_SYMBOL(objagg_obj_put); * Returns a pointer to newly created objagg instance in case of success, * otherwise it returns pointer error using ERR_PTR macro. */ -struct objagg *objagg_create(const struct objagg_ops *ops, void *priv) +struct objagg *objagg_create(const struct objagg_ops *ops, + struct objagg_hints *objagg_hints, void *priv) { struct objagg *objagg; int err; if (WARN_ON(!ops || !ops->root_create || !ops->root_destroy || - !ops->delta_create || !ops->delta_destroy)) + !ops->delta_check || !ops->delta_create || + !ops->delta_destroy)) return ERR_PTR(-EINVAL); + objagg = kzalloc(sizeof(*objagg), GFP_KERNEL); if (!objagg) return ERR_PTR(-ENOMEM); objagg->ops = ops; + if (objagg_hints) { + objagg->hints = objagg_hints; + objagg_hints->refcount++; + } objagg->priv = priv; INIT_LIST_HEAD(&objagg->obj_list); @@ -397,6 +541,8 @@ struct objagg *objagg_create(const struct objagg_ops *ops, void *priv) if (err) goto err_rhashtable_init; + ida_init(&objagg->root_ida); + trace_objagg_create(objagg); return objagg; @@ -415,8 +561,11 @@ EXPORT_SYMBOL(objagg_create); void objagg_destroy(struct objagg *objagg) { trace_objagg_destroy(objagg); + ida_destroy(&objagg->root_ida); WARN_ON(!list_empty(&objagg->obj_list)); rhashtable_destroy(&objagg->obj_ht); + if (objagg->hints) + objagg_hints_put(objagg->hints); kfree(objagg); } EXPORT_SYMBOL(objagg_destroy); @@ -472,6 +621,8 @@ const struct objagg_stats *objagg_stats_get(struct objagg *objagg) objagg_stats->stats_info[i].objagg_obj = objagg_obj; objagg_stats->stats_info[i].is_root = objagg_obj_is_root(objagg_obj); + if (objagg_stats->stats_info[i].is_root) + objagg_stats->root_count++; i++; } objagg_stats->stats_info_count = i; @@ -485,7 +636,7 @@ const struct objagg_stats *objagg_stats_get(struct objagg *objagg) EXPORT_SYMBOL(objagg_stats_get); /** - * objagg_stats_puts - puts stats of the objagg instance + * objagg_stats_put - puts stats of the objagg instance * @objagg_stats: objagg instance stats * * Note: all locking must be provided by the caller. @@ -496,6 +647,410 @@ void objagg_stats_put(const struct objagg_stats *objagg_stats) } EXPORT_SYMBOL(objagg_stats_put); +static struct objagg_hints_node * +objagg_hints_node_create(struct objagg_hints *objagg_hints, + struct objagg_obj *objagg_obj, size_t obj_size, + struct objagg_hints_node *parent_hnode) +{ + unsigned int user_count = objagg_obj->stats.user_count; + struct objagg_hints_node *hnode; + int err; + + hnode = kzalloc(sizeof(*hnode) + obj_size, GFP_KERNEL); + if (!hnode) + return ERR_PTR(-ENOMEM); + memcpy(hnode->obj, &objagg_obj->obj, obj_size); + hnode->stats_info.stats.user_count = user_count; + hnode->stats_info.stats.delta_user_count = user_count; + if (parent_hnode) { + parent_hnode->stats_info.stats.delta_user_count += user_count; + } else { + hnode->root_id = objagg_hints->root_count++; + hnode->stats_info.is_root = true; + } + hnode->stats_info.objagg_obj = objagg_obj; + + err = rhashtable_insert_fast(&objagg_hints->node_ht, &hnode->ht_node, + objagg_hints->ht_params); + if (err) + goto err_ht_insert; + + list_add(&hnode->list, &objagg_hints->node_list); + hnode->parent = parent_hnode; + objagg_hints->node_count++; + + return hnode; + +err_ht_insert: + kfree(hnode); + return ERR_PTR(err); +} + +static void objagg_hints_flush(struct objagg_hints *objagg_hints) +{ + struct objagg_hints_node *hnode, *tmp; + + list_for_each_entry_safe(hnode, tmp, &objagg_hints->node_list, list) { + list_del(&hnode->list); + rhashtable_remove_fast(&objagg_hints->node_ht, &hnode->ht_node, + objagg_hints->ht_params); + kfree(hnode); + } +} + +struct objagg_tmp_node { + struct objagg_obj *objagg_obj; + bool crossed_out; +}; + +struct objagg_tmp_graph { + struct objagg_tmp_node *nodes; + unsigned long nodes_count; + unsigned long *edges; +}; + +static int objagg_tmp_graph_edge_index(struct objagg_tmp_graph *graph, + int parent_index, int index) +{ + return index * graph->nodes_count + parent_index; +} + +static void objagg_tmp_graph_edge_set(struct objagg_tmp_graph *graph, + int parent_index, int index) +{ + int edge_index = objagg_tmp_graph_edge_index(graph, index, + parent_index); + + __set_bit(edge_index, graph->edges); +} + +static bool objagg_tmp_graph_is_edge(struct objagg_tmp_graph *graph, + int parent_index, int index) +{ + int edge_index = objagg_tmp_graph_edge_index(graph, index, + parent_index); + + return test_bit(edge_index, graph->edges); +} + +static unsigned int objagg_tmp_graph_node_weight(struct objagg_tmp_graph *graph, + unsigned int index) +{ + struct objagg_tmp_node *node = &graph->nodes[index]; + unsigned int weight = node->objagg_obj->stats.user_count; + int j; + + /* Node weight is sum of node users and all other nodes users + * that this node can represent with delta. + */ + + for (j = 0; j < graph->nodes_count; j++) { + if (!objagg_tmp_graph_is_edge(graph, index, j)) + continue; + node = &graph->nodes[j]; + if (node->crossed_out) + continue; + weight += node->objagg_obj->stats.user_count; + } + return weight; +} + +static int objagg_tmp_graph_node_max_weight(struct objagg_tmp_graph *graph) +{ + struct objagg_tmp_node *node; + unsigned int max_weight = 0; + unsigned int weight; + int max_index = -1; + int i; + + for (i = 0; i < graph->nodes_count; i++) { + node = &graph->nodes[i]; + if (node->crossed_out) + continue; + weight = objagg_tmp_graph_node_weight(graph, i); + if (weight >= max_weight) { + max_weight = weight; + max_index = i; + } + } + return max_index; +} + +static struct objagg_tmp_graph *objagg_tmp_graph_create(struct objagg *objagg) +{ + unsigned int nodes_count = objagg->obj_count; + struct objagg_tmp_graph *graph; + struct objagg_tmp_node *node; + struct objagg_tmp_node *pnode; + struct objagg_obj *objagg_obj; + size_t alloc_size; + int i, j; + + graph = kzalloc(sizeof(*graph), GFP_KERNEL); + if (!graph) + return NULL; + + graph->nodes = kcalloc(nodes_count, sizeof(*graph->nodes), GFP_KERNEL); + if (!graph->nodes) + goto err_nodes_alloc; + graph->nodes_count = nodes_count; + + alloc_size = BITS_TO_LONGS(nodes_count * nodes_count) * + sizeof(unsigned long); + graph->edges = kzalloc(alloc_size, GFP_KERNEL); + if (!graph->edges) + goto err_edges_alloc; + + i = 0; + list_for_each_entry(objagg_obj, &objagg->obj_list, list) { + node = &graph->nodes[i++]; + node->objagg_obj = objagg_obj; + } + + /* Assemble a temporary graph. Insert edge X->Y in case Y can be + * in delta of X. + */ + for (i = 0; i < nodes_count; i++) { + for (j = 0; j < nodes_count; j++) { + if (i == j) + continue; + pnode = &graph->nodes[i]; + node = &graph->nodes[j]; + if (objagg->ops->delta_check(objagg->priv, + pnode->objagg_obj->obj, + node->objagg_obj->obj)) { + objagg_tmp_graph_edge_set(graph, i, j); + + } + } + } + return graph; + +err_edges_alloc: + kfree(graph->nodes); +err_nodes_alloc: + kfree(graph); + return NULL; +} + +static void objagg_tmp_graph_destroy(struct objagg_tmp_graph *graph) +{ + kfree(graph->edges); + kfree(graph->nodes); + kfree(graph); +} + +static int +objagg_opt_simple_greedy_fillup_hints(struct objagg_hints *objagg_hints, + struct objagg *objagg) +{ + struct objagg_hints_node *hnode, *parent_hnode; + struct objagg_tmp_graph *graph; + struct objagg_tmp_node *node; + int index; + int j; + int err; + + graph = objagg_tmp_graph_create(objagg); + if (!graph) + return -ENOMEM; + + /* Find the nodes from the ones that can accommodate most users + * and cross them out of the graph. Save them to the hint list. + */ + while ((index = objagg_tmp_graph_node_max_weight(graph)) != -1) { + node = &graph->nodes[index]; + node->crossed_out = true; + hnode = objagg_hints_node_create(objagg_hints, + node->objagg_obj, + objagg->ops->obj_size, + NULL); + if (IS_ERR(hnode)) { + err = PTR_ERR(hnode); + goto out; + } + parent_hnode = hnode; + for (j = 0; j < graph->nodes_count; j++) { + if (!objagg_tmp_graph_is_edge(graph, index, j)) + continue; + node = &graph->nodes[j]; + if (node->crossed_out) + continue; + node->crossed_out = true; + hnode = objagg_hints_node_create(objagg_hints, + node->objagg_obj, + objagg->ops->obj_size, + parent_hnode); + if (IS_ERR(hnode)) { + err = PTR_ERR(hnode); + goto out; + } + } + } + + err = 0; +out: + objagg_tmp_graph_destroy(graph); + return err; +} + +struct objagg_opt_algo { + int (*fillup_hints)(struct objagg_hints *objagg_hints, + struct objagg *objagg); +}; + +static const struct objagg_opt_algo objagg_opt_simple_greedy = { + .fillup_hints = objagg_opt_simple_greedy_fillup_hints, +}; + + +static const struct objagg_opt_algo *objagg_opt_algos[] = { + [OBJAGG_OPT_ALGO_SIMPLE_GREEDY] = &objagg_opt_simple_greedy, +}; + +static int objagg_hints_obj_cmp(struct rhashtable_compare_arg *arg, + const void *obj) +{ + struct rhashtable *ht = arg->ht; + struct objagg_hints *objagg_hints = + container_of(ht, struct objagg_hints, node_ht); + const struct objagg_ops *ops = objagg_hints->ops; + const char *ptr = obj; + + ptr += ht->p.key_offset; + return ops->hints_obj_cmp ? ops->hints_obj_cmp(ptr, arg->key) : + memcmp(ptr, arg->key, ht->p.key_len); +} + +/** + * objagg_hints_get - obtains hints instance + * @objagg: objagg instance + * @opt_algo_type: type of hints finding algorithm + * + * Note: all locking must be provided by the caller. + * + * According to the algo type, the existing objects of objagg instance + * are going to be went-through to assemble an optimal tree. We call this + * tree hints. These hints can be later on used for creation of + * a new objagg instance. There, the future object creations are going + * to be consulted with these hints in order to find out, where exactly + * the new object should be put as a root or delta. + * + * Returns a pointer to hints instance in case of success, + * otherwise it returns pointer error using ERR_PTR macro. + */ +struct objagg_hints *objagg_hints_get(struct objagg *objagg, + enum objagg_opt_algo_type opt_algo_type) +{ + const struct objagg_opt_algo *algo = objagg_opt_algos[opt_algo_type]; + struct objagg_hints *objagg_hints; + int err; + + objagg_hints = kzalloc(sizeof(*objagg_hints), GFP_KERNEL); + if (!objagg_hints) + return ERR_PTR(-ENOMEM); + + objagg_hints->ops = objagg->ops; + objagg_hints->refcount = 1; + + INIT_LIST_HEAD(&objagg_hints->node_list); + + objagg_hints->ht_params.key_len = objagg->ops->obj_size; + objagg_hints->ht_params.key_offset = + offsetof(struct objagg_hints_node, obj); + objagg_hints->ht_params.head_offset = + offsetof(struct objagg_hints_node, ht_node); + objagg_hints->ht_params.obj_cmpfn = objagg_hints_obj_cmp; + + err = rhashtable_init(&objagg_hints->node_ht, &objagg_hints->ht_params); + if (err) + goto err_rhashtable_init; + + err = algo->fillup_hints(objagg_hints, objagg); + if (err) + goto err_fillup_hints; + + if (WARN_ON(objagg_hints->node_count != objagg->obj_count)) { + err = -EINVAL; + goto err_node_count_check; + } + + return objagg_hints; + +err_node_count_check: +err_fillup_hints: + objagg_hints_flush(objagg_hints); + rhashtable_destroy(&objagg_hints->node_ht); +err_rhashtable_init: + kfree(objagg_hints); + return ERR_PTR(err); +} +EXPORT_SYMBOL(objagg_hints_get); + +/** + * objagg_hints_put - puts hints instance + * @objagg_hints: objagg hints instance + * + * Note: all locking must be provided by the caller. + */ +void objagg_hints_put(struct objagg_hints *objagg_hints) +{ + if (--objagg_hints->refcount) + return; + objagg_hints_flush(objagg_hints); + rhashtable_destroy(&objagg_hints->node_ht); + kfree(objagg_hints); +} +EXPORT_SYMBOL(objagg_hints_put); + +/** + * objagg_hints_stats_get - obtains stats of the hints instance + * @objagg_hints: hints instance + * + * Note: all locking must be provided by the caller. + * + * The returned structure contains statistics of all objects + * currently in use, ordered by following rules: + * 1) Root objects are always on lower indexes than the rest. + * 2) Objects with higher delta user count are always on lower + * indexes. + * 3) In case multiple objects have the same delta user count, + * the objects are ordered by user count. + * + * Returns a pointer to stats instance in case of success, + * otherwise it returns pointer error using ERR_PTR macro. + */ +const struct objagg_stats * +objagg_hints_stats_get(struct objagg_hints *objagg_hints) +{ + struct objagg_stats *objagg_stats; + struct objagg_hints_node *hnode; + int i; + + objagg_stats = kzalloc(struct_size(objagg_stats, stats_info, + objagg_hints->node_count), + GFP_KERNEL); + if (!objagg_stats) + return ERR_PTR(-ENOMEM); + + i = 0; + list_for_each_entry(hnode, &objagg_hints->node_list, list) { + memcpy(&objagg_stats->stats_info[i], &hnode->stats_info, + sizeof(objagg_stats->stats_info[0])); + if (objagg_stats->stats_info[i].is_root) + objagg_stats->root_count++; + i++; + } + objagg_stats->stats_info_count = i; + + sort(objagg_stats->stats_info, objagg_stats->stats_info_count, + sizeof(struct objagg_obj_stats_info), + objagg_stats_info_sort_cmp_func, NULL); + + return objagg_stats; +} +EXPORT_SYMBOL(objagg_hints_stats_get); + MODULE_LICENSE("Dual BSD/GPL"); MODULE_AUTHOR("Jiri Pirko <jiri@mellanox.com>"); MODULE_DESCRIPTION("Object aggregation manager"); diff --git a/lib/rhashtable.c b/lib/rhashtable.c index 852ffa5160f1..0a105d4af166 100644 --- a/lib/rhashtable.c +++ b/lib/rhashtable.c @@ -682,7 +682,7 @@ EXPORT_SYMBOL_GPL(rhashtable_walk_enter); * rhashtable_walk_exit - Free an iterator * @iter: Hash table Iterator * - * This function frees resources allocated by rhashtable_walk_init. + * This function frees resources allocated by rhashtable_walk_enter. */ void rhashtable_walk_exit(struct rhashtable_iter *iter) { diff --git a/lib/sbitmap.c b/lib/sbitmap.c index 65c2d06250a6..5b382c1244ed 100644 --- a/lib/sbitmap.c +++ b/lib/sbitmap.c @@ -26,14 +26,10 @@ static inline bool sbitmap_deferred_clear(struct sbitmap *sb, int index) { unsigned long mask, val; - unsigned long __maybe_unused flags; bool ret = false; + unsigned long flags; - /* Silence bogus lockdep warning */ -#if defined(CONFIG_LOCKDEP) - local_irq_save(flags); -#endif - spin_lock(&sb->map[index].swap_lock); + spin_lock_irqsave(&sb->map[index].swap_lock, flags); if (!sb->map[index].cleared) goto out_unlock; @@ -54,10 +50,7 @@ static inline bool sbitmap_deferred_clear(struct sbitmap *sb, int index) ret = true; out_unlock: - spin_unlock(&sb->map[index].swap_lock); -#if defined(CONFIG_LOCKDEP) - local_irq_restore(flags); -#endif + spin_unlock_irqrestore(&sb->map[index].swap_lock, flags); return ret; } diff --git a/lib/test_bpf.c b/lib/test_bpf.c index f3e570722a7e..0845f635f404 100644 --- a/lib/test_bpf.c +++ b/lib/test_bpf.c @@ -6668,12 +6668,14 @@ static int __run_one(const struct bpf_prog *fp, const void *data, u64 start, finish; int ret = 0, i; + preempt_disable(); start = ktime_get_ns(); for (i = 0; i < runs; i++) ret = BPF_PROG_RUN(fp, data); finish = ktime_get_ns(); + preempt_enable(); *duration = finish - start; do_div(*duration, runs); diff --git a/lib/test_kmod.c b/lib/test_kmod.c index d82d022111e0..9cf77628fc91 100644 --- a/lib/test_kmod.c +++ b/lib/test_kmod.c @@ -632,7 +632,7 @@ static void __kmod_config_free(struct test_config *config) config->test_driver = NULL; kfree_const(config->test_fs); - config->test_driver = NULL; + config->test_fs = NULL; } static void kmod_config_free(struct kmod_test_device *test_dev) diff --git a/lib/test_objagg.c b/lib/test_objagg.c index ab57144bb0cd..72c1abfa154d 100644 --- a/lib/test_objagg.c +++ b/lib/test_objagg.c @@ -87,6 +87,15 @@ static void world_obj_put(struct world *world, struct objagg *objagg, #define MAX_KEY_ID_DIFF 5 +static bool delta_check(void *priv, const void *parent_obj, const void *obj) +{ + const struct tokey *parent_key = parent_obj; + const struct tokey *key = obj; + int diff = key->id - parent_key->id; + + return diff >= 0 && diff <= MAX_KEY_ID_DIFF; +} + static void *delta_create(void *priv, void *parent_obj, void *obj) { struct tokey *parent_key = parent_obj; @@ -95,7 +104,7 @@ static void *delta_create(void *priv, void *parent_obj, void *obj) int diff = key->id - parent_key->id; struct delta *delta; - if (diff < 0 || diff > MAX_KEY_ID_DIFF) + if (!delta_check(priv, parent_obj, obj)) return ERR_PTR(-EINVAL); delta = kzalloc(sizeof(*delta), GFP_KERNEL); @@ -115,7 +124,7 @@ static void delta_destroy(void *priv, void *delta_priv) kfree(delta); } -static void *root_create(void *priv, void *obj) +static void *root_create(void *priv, void *obj, unsigned int id) { struct world *world = priv; struct tokey *key = obj; @@ -268,6 +277,12 @@ stats_put: return err; } +static bool delta_check_dummy(void *priv, const void *parent_obj, + const void *obj) +{ + return false; +} + static void *delta_create_dummy(void *priv, void *parent_obj, void *obj) { return ERR_PTR(-EOPNOTSUPP); @@ -279,6 +294,7 @@ static void delta_destroy_dummy(void *priv, void *delta_priv) static const struct objagg_ops nodelta_ops = { .obj_size = sizeof(struct tokey), + .delta_check = delta_check_dummy, .delta_create = delta_create_dummy, .delta_destroy = delta_destroy_dummy, .root_create = root_create, @@ -292,7 +308,7 @@ static int test_nodelta(void) int i; int err; - objagg = objagg_create(&nodelta_ops, &world); + objagg = objagg_create(&nodelta_ops, NULL, &world); if (IS_ERR(objagg)) return PTR_ERR(objagg); @@ -357,6 +373,7 @@ err_stats_second_zero: static const struct objagg_ops delta_ops = { .obj_size = sizeof(struct tokey), + .delta_check = delta_check, .delta_create = delta_create, .delta_destroy = delta_destroy, .root_create = root_create, @@ -728,8 +745,10 @@ static int check_expect_stats(struct objagg *objagg, int err; stats = objagg_stats_get(objagg); - if (IS_ERR(stats)) + if (IS_ERR(stats)) { + *errmsg = "objagg_stats_get() failed."; return PTR_ERR(stats); + } err = __check_expect_stats(stats, expect_stats, errmsg); objagg_stats_put(stats); return err; @@ -769,7 +788,6 @@ static int test_delta_action_item(struct world *world, if (err) goto errout; - errmsg = NULL; err = check_expect_stats(objagg, &action_item->expect_stats, &errmsg); if (err) { pr_err("Key %u: Stats: %s\n", action_item->key_id, errmsg); @@ -793,7 +811,7 @@ static int test_delta(void) int i; int err; - objagg = objagg_create(&delta_ops, &world); + objagg = objagg_create(&delta_ops, NULL, &world); if (IS_ERR(objagg)) return PTR_ERR(objagg); @@ -815,6 +833,170 @@ err_do_action_item: return err; } +struct hints_case { + const unsigned int *key_ids; + size_t key_ids_count; + struct expect_stats expect_stats; + struct expect_stats expect_stats_hints; +}; + +static const unsigned int hints_case_key_ids[] = { + 1, 7, 3, 5, 3, 1, 30, 8, 8, 5, 6, 8, +}; + +static const struct hints_case hints_case = { + .key_ids = hints_case_key_ids, + .key_ids_count = ARRAY_SIZE(hints_case_key_ids), + .expect_stats = + EXPECT_STATS(7, ROOT(1, 2, 7), ROOT(7, 1, 4), ROOT(30, 1, 1), + DELTA(8, 3), DELTA(3, 2), + DELTA(5, 2), DELTA(6, 1)), + .expect_stats_hints = + EXPECT_STATS(7, ROOT(3, 2, 9), ROOT(1, 2, 2), ROOT(30, 1, 1), + DELTA(8, 3), DELTA(5, 2), + DELTA(6, 1), DELTA(7, 1)), +}; + +static void __pr_debug_stats(const struct objagg_stats *stats) +{ + int i; + + for (i = 0; i < stats->stats_info_count; i++) + pr_debug("Stat index %d key %u: u %d, d %d, %s\n", i, + obj_to_key_id(stats->stats_info[i].objagg_obj), + stats->stats_info[i].stats.user_count, + stats->stats_info[i].stats.delta_user_count, + stats->stats_info[i].is_root ? "root" : "noroot"); +} + +static void pr_debug_stats(struct objagg *objagg) +{ + const struct objagg_stats *stats; + + stats = objagg_stats_get(objagg); + if (IS_ERR(stats)) + return; + __pr_debug_stats(stats); + objagg_stats_put(stats); +} + +static void pr_debug_hints_stats(struct objagg_hints *objagg_hints) +{ + const struct objagg_stats *stats; + + stats = objagg_hints_stats_get(objagg_hints); + if (IS_ERR(stats)) + return; + __pr_debug_stats(stats); + objagg_stats_put(stats); +} + +static int check_expect_hints_stats(struct objagg_hints *objagg_hints, + const struct expect_stats *expect_stats, + const char **errmsg) +{ + const struct objagg_stats *stats; + int err; + + stats = objagg_hints_stats_get(objagg_hints); + if (IS_ERR(stats)) + return PTR_ERR(stats); + err = __check_expect_stats(stats, expect_stats, errmsg); + objagg_stats_put(stats); + return err; +} + +static int test_hints_case(const struct hints_case *hints_case) +{ + struct objagg_obj *objagg_obj; + struct objagg_hints *hints; + struct world world2 = {}; + struct world world = {}; + struct objagg *objagg2; + struct objagg *objagg; + const char *errmsg; + int i; + int err; + + objagg = objagg_create(&delta_ops, NULL, &world); + if (IS_ERR(objagg)) + return PTR_ERR(objagg); + + for (i = 0; i < hints_case->key_ids_count; i++) { + objagg_obj = world_obj_get(&world, objagg, + hints_case->key_ids[i]); + if (IS_ERR(objagg_obj)) { + err = PTR_ERR(objagg_obj); + goto err_world_obj_get; + } + } + + pr_debug_stats(objagg); + err = check_expect_stats(objagg, &hints_case->expect_stats, &errmsg); + if (err) { + pr_err("Stats: %s\n", errmsg); + goto err_check_expect_stats; + } + + hints = objagg_hints_get(objagg, OBJAGG_OPT_ALGO_SIMPLE_GREEDY); + if (IS_ERR(hints)) { + err = PTR_ERR(hints); + goto err_hints_get; + } + + pr_debug_hints_stats(hints); + err = check_expect_hints_stats(hints, &hints_case->expect_stats_hints, + &errmsg); + if (err) { + pr_err("Hints stats: %s\n", errmsg); + goto err_check_expect_hints_stats; + } + + objagg2 = objagg_create(&delta_ops, hints, &world2); + if (IS_ERR(objagg2)) + return PTR_ERR(objagg2); + + for (i = 0; i < hints_case->key_ids_count; i++) { + objagg_obj = world_obj_get(&world2, objagg2, + hints_case->key_ids[i]); + if (IS_ERR(objagg_obj)) { + err = PTR_ERR(objagg_obj); + goto err_world2_obj_get; + } + } + + pr_debug_stats(objagg2); + err = check_expect_stats(objagg2, &hints_case->expect_stats_hints, + &errmsg); + if (err) { + pr_err("Stats2: %s\n", errmsg); + goto err_check_expect_stats2; + } + + err = 0; + +err_check_expect_stats2: +err_world2_obj_get: + for (i--; i >= 0; i--) + world_obj_put(&world2, objagg, hints_case->key_ids[i]); + objagg_hints_put(hints); + objagg_destroy(objagg2); + i = hints_case->key_ids_count; +err_check_expect_hints_stats: +err_hints_get: +err_check_expect_stats: +err_world_obj_get: + for (i--; i >= 0; i--) + world_obj_put(&world, objagg, hints_case->key_ids[i]); + + objagg_destroy(objagg); + return err; +} +static int test_hints(void) +{ + return test_hints_case(&hints_case); +} + static int __init test_objagg_init(void) { int err; @@ -822,7 +1004,10 @@ static int __init test_objagg_init(void) err = test_nodelta(); if (err) return err; - return test_delta(); + err = test_delta(); + if (err) + return err; + return test_hints(); } static void __exit test_objagg_exit(void) diff --git a/lib/test_rhashtable.c b/lib/test_rhashtable.c index 6a8ac7626797..3bd2e91bfc29 100644 --- a/lib/test_rhashtable.c +++ b/lib/test_rhashtable.c @@ -177,16 +177,11 @@ static int __init test_rht_lookup(struct rhashtable *ht, struct test_obj *array, static void test_bucket_stats(struct rhashtable *ht, unsigned int entries) { - unsigned int err, total = 0, chain_len = 0; + unsigned int total = 0, chain_len = 0; struct rhashtable_iter hti; struct rhash_head *pos; - err = rhashtable_walk_init(ht, &hti, GFP_KERNEL); - if (err) { - pr_warn("Test failed: allocation error"); - return; - } - + rhashtable_walk_enter(ht, &hti); rhashtable_walk_start(&hti); while ((pos = rhashtable_walk_next(&hti))) { @@ -395,7 +390,7 @@ static int __init test_rhltable(unsigned int entries) if (WARN(err, "cannot remove element at slot %d", i)) continue; } else { - if (WARN(err != -ENOENT, "removed non-existant element %d, error %d not %d", + if (WARN(err != -ENOENT, "removed non-existent element %d, error %d not %d", i, err, -ENOENT)) continue; } @@ -440,7 +435,7 @@ static int __init test_rhltable(unsigned int entries) if (WARN(err, "cannot remove element at slot %d", i)) continue; } else { - if (WARN(err != -ENOENT, "removed non-existant element, error %d not %d", + if (WARN(err != -ENOENT, "removed non-existent element, error %d not %d", err, -ENOENT)) continue; } @@ -541,38 +536,45 @@ static unsigned int __init print_ht(struct rhltable *rhlt) static int __init test_insert_dup(struct test_obj_rhl *rhl_test_objects, int cnt, bool slow) { - struct rhltable rhlt; + struct rhltable *rhlt; unsigned int i, ret; const char *key; int err = 0; - err = rhltable_init(&rhlt, &test_rht_params_dup); - if (WARN_ON(err)) + rhlt = kmalloc(sizeof(*rhlt), GFP_KERNEL); + if (WARN_ON(!rhlt)) + return -EINVAL; + + err = rhltable_init(rhlt, &test_rht_params_dup); + if (WARN_ON(err)) { + kfree(rhlt); return err; + } for (i = 0; i < cnt; i++) { rhl_test_objects[i].value.tid = i; - key = rht_obj(&rhlt.ht, &rhl_test_objects[i].list_node.rhead); + key = rht_obj(&rhlt->ht, &rhl_test_objects[i].list_node.rhead); key += test_rht_params_dup.key_offset; if (slow) { - err = PTR_ERR(rhashtable_insert_slow(&rhlt.ht, key, + err = PTR_ERR(rhashtable_insert_slow(&rhlt->ht, key, &rhl_test_objects[i].list_node.rhead)); if (err == -EAGAIN) err = 0; } else - err = rhltable_insert(&rhlt, + err = rhltable_insert(rhlt, &rhl_test_objects[i].list_node, test_rht_params_dup); if (WARN(err, "error %d on element %d/%d (%s)\n", err, i, cnt, slow? "slow" : "fast")) goto skip_print; } - ret = print_ht(&rhlt); + ret = print_ht(rhlt); WARN(ret != cnt, "missing rhltable elements (%d != %d, %s)\n", ret, cnt, slow? "slow" : "fast"); skip_print: - rhltable_destroy(&rhlt); + rhltable_destroy(rhlt); + kfree(rhlt); return 0; } diff --git a/lib/test_xarray.c b/lib/test_xarray.c index 4676c0a1eeca..c596a957f764 100644 --- a/lib/test_xarray.c +++ b/lib/test_xarray.c @@ -199,7 +199,7 @@ static noinline void check_xa_mark_1(struct xarray *xa, unsigned long index) XA_BUG_ON(xa, xa_store_index(xa, index + 1, GFP_KERNEL)); xa_set_mark(xa, index + 1, XA_MARK_0); XA_BUG_ON(xa, xa_store_index(xa, index + 2, GFP_KERNEL)); - xa_set_mark(xa, index + 2, XA_MARK_1); + xa_set_mark(xa, index + 2, XA_MARK_2); XA_BUG_ON(xa, xa_store_index(xa, next, GFP_KERNEL)); xa_store_order(xa, index, order, xa_mk_index(index), GFP_KERNEL); @@ -209,8 +209,8 @@ static noinline void check_xa_mark_1(struct xarray *xa, unsigned long index) void *entry; XA_BUG_ON(xa, !xa_get_mark(xa, i, XA_MARK_0)); - XA_BUG_ON(xa, !xa_get_mark(xa, i, XA_MARK_1)); - XA_BUG_ON(xa, xa_get_mark(xa, i, XA_MARK_2)); + XA_BUG_ON(xa, xa_get_mark(xa, i, XA_MARK_1)); + XA_BUG_ON(xa, !xa_get_mark(xa, i, XA_MARK_2)); /* We should see two elements in the array */ rcu_read_lock(); @@ -357,7 +357,7 @@ static noinline void check_cmpxchg(struct xarray *xa) static noinline void check_reserve(struct xarray *xa) { void *entry; - unsigned long index = 0; + unsigned long index; /* An array with a reserved entry is not empty */ XA_BUG_ON(xa, !xa_empty(xa)); @@ -382,10 +382,12 @@ static noinline void check_reserve(struct xarray *xa) xa_erase_index(xa, 12345678); XA_BUG_ON(xa, !xa_empty(xa)); - /* And so does xa_insert */ + /* But xa_insert does not */ xa_reserve(xa, 12345678, GFP_KERNEL); - XA_BUG_ON(xa, xa_insert(xa, 12345678, xa_mk_value(12345678), 0) != 0); - xa_erase_index(xa, 12345678); + XA_BUG_ON(xa, xa_insert(xa, 12345678, xa_mk_value(12345678), 0) != + -EEXIST); + XA_BUG_ON(xa, xa_empty(xa)); + XA_BUG_ON(xa, xa_erase(xa, 12345678) != NULL); XA_BUG_ON(xa, !xa_empty(xa)); /* Can iterate through a reserved entry */ @@ -393,7 +395,7 @@ static noinline void check_reserve(struct xarray *xa) xa_reserve(xa, 6, GFP_KERNEL); xa_store_index(xa, 7, GFP_KERNEL); - xa_for_each(xa, entry, index, ULONG_MAX, XA_PRESENT) { + xa_for_each(xa, index, entry) { XA_BUG_ON(xa, index != 5 && index != 7); } xa_destroy(xa); @@ -812,17 +814,16 @@ static noinline void check_find_1(struct xarray *xa) static noinline void check_find_2(struct xarray *xa) { void *entry; - unsigned long i, j, index = 0; + unsigned long i, j, index; - xa_for_each(xa, entry, index, ULONG_MAX, XA_PRESENT) { + xa_for_each(xa, index, entry) { XA_BUG_ON(xa, true); } for (i = 0; i < 1024; i++) { xa_store_index(xa, index, GFP_KERNEL); j = 0; - index = 0; - xa_for_each(xa, entry, index, ULONG_MAX, XA_PRESENT) { + xa_for_each(xa, index, entry) { XA_BUG_ON(xa, xa_mk_index(index) != entry); XA_BUG_ON(xa, index != j++); } @@ -839,6 +840,7 @@ static noinline void check_find_3(struct xarray *xa) for (i = 0; i < 100; i++) { for (j = 0; j < 100; j++) { + rcu_read_lock(); for (k = 0; k < 100; k++) { xas_set(&xas, j); xas_for_each_marked(&xas, entry, k, XA_MARK_0) @@ -847,6 +849,7 @@ static noinline void check_find_3(struct xarray *xa) XA_BUG_ON(xa, xas.xa_node != XAS_RESTART); } + rcu_read_unlock(); } xa_store_index(xa, i, GFP_KERNEL); xa_set_mark(xa, i, XA_MARK_0); @@ -1183,6 +1186,35 @@ static noinline void check_store_range(struct xarray *xa) } } +static void check_align_1(struct xarray *xa, char *name) +{ + int i; + unsigned int id; + unsigned long index; + void *entry; + + for (i = 0; i < 8; i++) { + id = 0; + XA_BUG_ON(xa, xa_alloc(xa, &id, UINT_MAX, name + i, GFP_KERNEL) + != 0); + XA_BUG_ON(xa, id != i); + } + xa_for_each(xa, index, entry) + XA_BUG_ON(xa, xa_is_err(entry)); + xa_destroy(xa); +} + +static noinline void check_align(struct xarray *xa) +{ + char name[] = "Motorola 68000"; + + check_align_1(xa, name); + check_align_1(xa, name + 1); + check_align_1(xa, name + 2); + check_align_1(xa, name + 3); +// check_align_2(xa, name); +} + static LIST_HEAD(shadow_nodes); static void test_update_node(struct xa_node *node) @@ -1332,6 +1364,7 @@ static int xarray_checks(void) check_create_range(&array); check_store_range(&array); check_store_iter(&array); + check_align(&xa0); check_workingset(&array, 0); check_workingset(&array, 64); diff --git a/lib/xarray.c b/lib/xarray.c index 5f3f9311de89..81c3171ddde9 100644 --- a/lib/xarray.c +++ b/lib/xarray.c @@ -232,6 +232,8 @@ void *xas_load(struct xa_state *xas) if (xas->xa_shift > node->shift) break; entry = xas_descend(xas, node); + if (node->shift == 0) + break; } return entry; } @@ -506,7 +508,7 @@ static void xas_free_nodes(struct xa_state *xas, struct xa_node *top) for (;;) { void *entry = xa_entry_locked(xas->xa, node, offset); - if (xa_is_node(entry)) { + if (node->shift && xa_is_node(entry)) { node = xa_to_node(entry); offset = 0; continue; @@ -604,6 +606,7 @@ static int xas_expand(struct xa_state *xas, void *head) /* * xas_create() - Create a slot to store an entry in. * @xas: XArray operation state. + * @allow_root: %true if we can store the entry in the root directly * * Most users will not need to call this function directly, as it is called * by xas_store(). It is useful for doing conditional store operations @@ -613,7 +616,7 @@ static int xas_expand(struct xa_state *xas, void *head) * If the slot was newly created, returns %NULL. If it failed to create the * slot, returns %NULL and indicates the error in @xas. */ -static void *xas_create(struct xa_state *xas) +static void *xas_create(struct xa_state *xas, bool allow_root) { struct xarray *xa = xas->xa; void *entry; @@ -628,6 +631,8 @@ static void *xas_create(struct xa_state *xas) shift = xas_expand(xas, entry); if (shift < 0) return NULL; + if (!shift && !allow_root) + shift = XA_CHUNK_SHIFT; entry = xa_head_locked(xa); slot = &xa->xa_head; } else if (xas_error(xas)) { @@ -687,7 +692,7 @@ void xas_create_range(struct xa_state *xas) xas->xa_sibs = 0; for (;;) { - xas_create(xas); + xas_create(xas, true); if (xas_error(xas)) goto restore; if (xas->xa_index <= (index | XA_CHUNK_MASK)) @@ -754,7 +759,7 @@ void *xas_store(struct xa_state *xas, void *entry) bool value = xa_is_value(entry); if (entry) - first = xas_create(xas); + first = xas_create(xas, !xa_is_node(entry)); else first = xas_load(xas); @@ -1251,35 +1256,6 @@ void *xas_find_conflict(struct xa_state *xas) EXPORT_SYMBOL_GPL(xas_find_conflict); /** - * xa_init_flags() - Initialise an empty XArray with flags. - * @xa: XArray. - * @flags: XA_FLAG values. - * - * If you need to initialise an XArray with special flags (eg you need - * to take the lock from interrupt context), use this function instead - * of xa_init(). - * - * Context: Any context. - */ -void xa_init_flags(struct xarray *xa, gfp_t flags) -{ - unsigned int lock_type; - static struct lock_class_key xa_lock_irq; - static struct lock_class_key xa_lock_bh; - - spin_lock_init(&xa->xa_lock); - xa->xa_flags = flags; - xa->xa_head = NULL; - - lock_type = xa_lock_type(xa); - if (lock_type == XA_LOCK_IRQ) - lockdep_set_class(&xa->xa_lock, &xa_lock_irq); - else if (lock_type == XA_LOCK_BH) - lockdep_set_class(&xa->xa_lock, &xa_lock_bh); -} -EXPORT_SYMBOL(xa_init_flags); - -/** * xa_load() - Load an entry from an XArray. * @xa: XArray. * @index: index into array. @@ -1308,7 +1284,6 @@ static void *xas_result(struct xa_state *xas, void *curr) { if (xa_is_zero(curr)) return NULL; - XA_NODE_BUG_ON(xas->xa_node, xa_is_internal(curr)); if (xas_error(xas)) curr = xas->xa_node; return curr; @@ -1378,7 +1353,7 @@ void *__xa_store(struct xarray *xa, unsigned long index, void *entry, gfp_t gfp) XA_STATE(xas, xa, index); void *curr; - if (WARN_ON_ONCE(xa_is_internal(entry))) + if (WARN_ON_ONCE(xa_is_advanced(entry))) return XA_ERROR(-EINVAL); if (xa_track_free(xa) && !entry) entry = XA_ZERO_ENTRY; @@ -1444,7 +1419,7 @@ void *__xa_cmpxchg(struct xarray *xa, unsigned long index, XA_STATE(xas, xa, index); void *curr; - if (WARN_ON_ONCE(xa_is_internal(entry))) + if (WARN_ON_ONCE(xa_is_advanced(entry))) return XA_ERROR(-EINVAL); if (xa_track_free(xa) && !entry) entry = XA_ZERO_ENTRY; @@ -1465,6 +1440,47 @@ void *__xa_cmpxchg(struct xarray *xa, unsigned long index, EXPORT_SYMBOL(__xa_cmpxchg); /** + * __xa_insert() - Store this entry in the XArray if no entry is present. + * @xa: XArray. + * @index: Index into array. + * @entry: New entry. + * @gfp: Memory allocation flags. + * + * Inserting a NULL entry will store a reserved entry (like xa_reserve()) + * if no entry is present. Inserting will fail if a reserved entry is + * present, even though loading from this index will return NULL. + * + * Context: Any context. Expects xa_lock to be held on entry. May + * release and reacquire xa_lock if @gfp flags permit. + * Return: 0 if the store succeeded. -EEXIST if another entry was present. + * -ENOMEM if memory could not be allocated. + */ +int __xa_insert(struct xarray *xa, unsigned long index, void *entry, gfp_t gfp) +{ + XA_STATE(xas, xa, index); + void *curr; + + if (WARN_ON_ONCE(xa_is_advanced(entry))) + return -EINVAL; + if (!entry) + entry = XA_ZERO_ENTRY; + + do { + curr = xas_load(&xas); + if (!curr) { + xas_store(&xas, entry); + if (xa_track_free(xa)) + xas_clear_mark(&xas, XA_FREE_MARK); + } else { + xas_set_err(&xas, -EEXIST); + } + } while (__xas_nomem(&xas, gfp)); + + return xas_error(&xas); +} +EXPORT_SYMBOL(__xa_insert); + +/** * __xa_reserve() - Reserve this index in the XArray. * @xa: XArray. * @index: Index into array. @@ -1567,7 +1583,7 @@ void *xa_store_range(struct xarray *xa, unsigned long first, if (last + 1) order = __ffs(last + 1); xas_set_order(&xas, last, order); - xas_create(&xas); + xas_create(&xas, true); if (xas_error(&xas)) goto unlock; } @@ -1609,7 +1625,7 @@ int __xa_alloc(struct xarray *xa, u32 *id, u32 max, void *entry, gfp_t gfp) XA_STATE(xas, xa, 0); int err; - if (WARN_ON_ONCE(xa_is_internal(entry))) + if (WARN_ON_ONCE(xa_is_advanced(entry))) return -EINVAL; if (WARN_ON_ONCE(!xa_track_free(xa))) return -EINVAL; |