diff options
Diffstat (limited to 'lib/sort.c')
-rw-r--r-- | lib/sort.c | 69 |
1 files changed, 53 insertions, 16 deletions
diff --git a/lib/sort.c b/lib/sort.c index cf408aec3733..b399bf10d675 100644 --- a/lib/sort.c +++ b/lib/sort.c @@ -51,7 +51,7 @@ static bool is_aligned(const void *base, size_t size, unsigned char align) * which basically all CPUs have, to minimize loop overhead computations. * * For some reason, on x86 gcc 7.3.0 adds a redundant test of n at the - * bottom of the loop, even though the zero flag is stil valid from the + * bottom of the loop, even though the zero flag is still valid from the * subtract (since the intervening mov instructions don't alter the flags). * Gcc 8.1.0 doesn't have that problem. */ @@ -117,23 +117,32 @@ static void swap_bytes(void *a, void *b, size_t n) } while (n); } -typedef void (*swap_func_t)(void *a, void *b, int size); - /* * The values are arbitrary as long as they can't be confused with * a pointer, but small integers make for the smallest compare * instructions. */ -#define SWAP_WORDS_64 (swap_func_t)0 -#define SWAP_WORDS_32 (swap_func_t)1 -#define SWAP_BYTES (swap_func_t)2 +#define SWAP_WORDS_64 (swap_r_func_t)0 +#define SWAP_WORDS_32 (swap_r_func_t)1 +#define SWAP_BYTES (swap_r_func_t)2 +#define SWAP_WRAPPER (swap_r_func_t)3 + +struct wrapper { + cmp_func_t cmp; + swap_func_t swap; +}; /* * The function pointer is last to make tail calls most efficient if the * compiler decides not to inline this function. */ -static void do_swap(void *a, void *b, size_t size, swap_func_t swap_func) +static void do_swap(void *a, void *b, size_t size, swap_r_func_t swap_func, const void *priv) { + if (swap_func == SWAP_WRAPPER) { + ((const struct wrapper *)priv)->swap(a, b, (int)size); + return; + } + if (swap_func == SWAP_WORDS_64) swap_words_64(a, b, size); else if (swap_func == SWAP_WORDS_32) @@ -141,7 +150,16 @@ static void do_swap(void *a, void *b, size_t size, swap_func_t swap_func) else if (swap_func == SWAP_BYTES) swap_bytes(a, b, size); else - swap_func(a, b, (int)size); + swap_func(a, b, (int)size, priv); +} + +#define _CMP_WRAPPER ((cmp_r_func_t)0L) + +static int do_cmp(const void *a, const void *b, cmp_r_func_t cmp, const void *priv) +{ + if (cmp == _CMP_WRAPPER) + return ((const struct wrapper *)priv)->cmp(a, b); + return cmp(a, b, priv); } /** @@ -171,12 +189,13 @@ static size_t parent(size_t i, unsigned int lsbit, size_t size) } /** - * sort - sort an array of elements + * sort_r - sort an array of elements * @base: pointer to data to sort * @num: number of elements * @size: size of each element * @cmp_func: pointer to comparison function * @swap_func: pointer to swap function or NULL + * @priv: third argument passed to comparison function * * This function does a heapsort on the given array. You may provide * a swap_func function if you need to do something more than a memory @@ -188,9 +207,10 @@ static size_t parent(size_t i, unsigned int lsbit, size_t size) * O(n*n) worst-case behavior and extra memory requirements that make * it less suitable for kernel use. */ -void sort(void *base, size_t num, size_t size, - int (*cmp_func)(const void *, const void *), - void (*swap_func)(void *, void *, int size)) +void sort_r(void *base, size_t num, size_t size, + cmp_r_func_t cmp_func, + swap_r_func_t swap_func, + const void *priv) { /* pre-scale counters for performance */ size_t n = num * size, a = (num/2) * size; @@ -199,6 +219,10 @@ void sort(void *base, size_t num, size_t size, if (!a) /* num < 2 || size == 0 */ return; + /* called from 'sort' without swap function, let's pick the default */ + if (swap_func == SWAP_WRAPPER && !((struct wrapper *)priv)->swap) + swap_func = NULL; + if (!swap_func) { if (is_aligned(base, size, 8)) swap_func = SWAP_WORDS_64; @@ -221,7 +245,7 @@ void sort(void *base, size_t num, size_t size, if (a) /* Building heap: sift down --a */ a -= size; else if (n -= size) /* Sorting: Extract root to --n */ - do_swap(base, base + n, size, swap_func); + do_swap(base, base + n, size, swap_func, priv); else /* Sort complete */ break; @@ -238,18 +262,31 @@ void sort(void *base, size_t num, size_t size, * average, 3/4 worst-case.) */ for (b = a; c = 2*b + size, (d = c + size) < n;) - b = cmp_func(base + c, base + d) >= 0 ? c : d; + b = do_cmp(base + c, base + d, cmp_func, priv) >= 0 ? c : d; if (d == n) /* Special case last leaf with no sibling */ b = c; /* Now backtrack from "b" to the correct location for "a" */ - while (b != a && cmp_func(base + a, base + b) >= 0) + while (b != a && do_cmp(base + a, base + b, cmp_func, priv) >= 0) b = parent(b, lsbit, size); c = b; /* Where "a" belongs */ while (b != a) { /* Shift it into place */ b = parent(b, lsbit, size); - do_swap(base + b, base + c, size, swap_func); + do_swap(base + b, base + c, size, swap_func, priv); } } } +EXPORT_SYMBOL(sort_r); + +void sort(void *base, size_t num, size_t size, + cmp_func_t cmp_func, + swap_func_t swap_func) +{ + struct wrapper w = { + .cmp = cmp_func, + .swap = swap_func, + }; + + return sort_r(base, num, size, _CMP_WRAPPER, SWAP_WRAPPER, &w); +} EXPORT_SYMBOL(sort); |