# Source code for quickselect.hoare

"""
Based on "quickselect" selection algorithm by Tony Hoare.

Reference:
    https://en.wikipedia.org/wiki/Quickselect
"""

from __future__ import annotations

from operator import gt, lt
from typing import Any, Callable, MutableSequence, Sequence

from .core.utils import SequenceKeyView as _SequenceKeyView
from .hints import Comparator, Domain, Key


def nth_largest(
    sequence: MutableSequence[Domain],
    n: int,
    *,
    key: Key[Domain] | None = None,
) -> Domain:
    """
    Returns n-th largest element
    and partially sorts given sequence while searching.

    +------------+-------------+-----------------+------------------+
    | complexity | best        | average         | worst            |
    +------------+-------------+-----------------+------------------+
    | time       | ``O(size)`` | ``O(size)``     | ``O(size ** 2)`` |
    +------------+-------------+-----------------+------------------+
    | memory     | ``O(1)``    | ``O(log size)`` | ``O(size)``      |
    +------------+-------------+-----------------+------------------+

    where ``size = len(sequence)``.

    :param sequence: sequence to search in
    :param n:
        index of the element to search for in the sequence
        sorted by key in descending order
        (e.g. ``n = 0`` corresponds to the maximum element)
    :param key:
        single argument ordering function,
        if none is specified compares elements themselves
    :returns: n-th largest element of the sequence

    >>> sequence = list(range(-10, 11))
    >>> nth_largest(sequence, 0)
    10
    >>> nth_largest(sequence, 1)
    9
    >>> nth_largest(sequence, 19)
    -9
    >>> nth_largest(sequence, 20)
    -10
    >>> nth_largest(sequence, 0, key=abs)
    10
    >>> nth_largest(sequence, 1, key=abs)
    -10
    >>> nth_largest(sequence, 19, key=abs)
    1
    >>> nth_largest(sequence, 20, key=abs)
    0
    """
    # ``gt`` as the ordering predicate puts greater keys first,
    # so index ``n`` counts from the maximum.
    return select(sequence, n, key=key, comparator=gt)


def nth_smallest(
    sequence: MutableSequence[Domain],
    n: int,
    *,
    key: Key[Domain] | None = None,
) -> Domain:
    """
    Returns n-th smallest element
    and partially sorts given sequence while searching.

    +------------+-------------+-----------------+------------------+
    | complexity | best        | average         | worst            |
    +------------+-------------+-----------------+------------------+
    | time       | ``O(size)`` | ``O(size)``     | ``O(size ** 2)`` |
    +------------+-------------+-----------------+------------------+
    | memory     | ``O(1)``    | ``O(log size)`` | ``O(size)``      |
    +------------+-------------+-----------------+------------------+

    where ``size = len(sequence)``.

    :param sequence: sequence to search in
    :param n:
        index of the element to search for in the sequence
        sorted by key in ascending order
        (e.g. ``n = 0`` corresponds to the minimum element)
    :param key:
        single argument ordering function,
        if none is specified compares elements themselves
    :returns: n-th smallest element of the sequence

    >>> sequence = list(range(-10, 11))
    >>> nth_smallest(sequence, 0)
    -10
    >>> nth_smallest(sequence, 1)
    -9
    >>> nth_smallest(sequence, 19)
    9
    >>> nth_smallest(sequence, 20)
    10
    >>> nth_smallest(sequence, 0, key=abs)
    0
    >>> nth_smallest(sequence, 1, key=abs)
    1
    >>> nth_smallest(sequence, 19, key=abs)
    -10
    >>> nth_smallest(sequence, 20, key=abs)
    10
    """
    # ``lt`` as the ordering predicate puts lesser keys first,
    # so index ``n`` counts from the minimum.
    return select(sequence, n, key=key, comparator=lt)


def select(
    sequence: MutableSequence[Domain],
    n: int,
    *,
    start: int = 0,
    stop: int | None = None,
    key: Key[Domain] | None = None,
    comparator: Comparator,
) -> Domain:
    """
    Partially sorts given sequence and returns n-th element.

    +------------+-------------+-----------------+------------------+
    | complexity | best        | average         | worst            |
    +------------+-------------+-----------------+------------------+
    | time       | ``O(size)`` | ``O(size)``     | ``O(size ** 2)`` |
    +------------+-------------+-----------------+------------------+
    | memory     | ``O(1)``    | ``O(log size)`` | ``O(size)``      |
    +------------+-------------+-----------------+------------------+

    where ``size = len(sequence)``.

    :param sequence: sequence to select from
    :param n:
        index of the element to select,
        has to lie within ``[start, stop]``
    :param start: index to start selection from
    :param stop:
        index to stop selection at (inclusive),
        if none is specified the last index of the sequence is used
    :param key:
        single argument ordering function,
        if none is specified compares elements themselves
    :param comparator: binary predicate that defines the sorting order
    :returns:
        n-th element of the sequence
        with slice partially sorted by key in given order
    :raises IndexError:
        if ``n`` is outside of ``[start, stop]``
        (which includes the case of an empty selection range)

    >>> from operator import gt, lt
    >>> sequence = list(range(-10, 11))
    >>> select(sequence, 0, stop=5, comparator=gt)
    -5
    >>> select(sequence, 0, stop=5, comparator=lt)
    -10
    >>> select(sequence, 20, start=15, comparator=lt)
    10
    >>> select(sequence, 20, start=15, comparator=gt)
    5
    >>> select(sequence, 5, start=5, stop=15, key=abs, comparator=lt)
    0
    >>> select(sequence, 5, start=5, stop=15, key=abs, comparator=gt)
    10
    """
    if stop is None:
        stop = len(sequence) - 1
    # The narrowing loop below only terminates when the target index is
    # inside the searched range: with ``n > stop`` it would spin forever
    # and with a negative ``n`` it would hand back an unrelated element.
    if not start <= n <= stop:
        raise IndexError(
            f"n = {n} is outside of the selection range [{start}, {stop}]"
        )
    # Ordering is decided on ``keys`` while swaps are applied to
    # ``sequence``; without a key function they are the same object.
    keys = sequence if key is None else _SequenceKeyView(sequence, key)
    # NOTE(review): this loop is iterative and keeps only a few indices,
    # so the memory rows of the table above look like an upper bound
    # (possibly inherited from a recursive formulation) -- confirm.
    while True:
        pivot_index = _partition(sequence, keys, start, stop, comparator)
        if pivot_index < n:
            # target lies to the right of the settled position
            start = pivot_index + 1
        elif pivot_index > n:
            # target lies to the left of the settled position
            stop = pivot_index - 1
        else:
            return sequence[n]


def _partition(
    sequence: MutableSequence[Domain],
    keys: Sequence[Any],
    start: int,
    stop: int,
    comparator: Callable[[Domain, Domain], bool],
) -> int:
    """
    Rearranges elements with indices in ``[start, stop]``
    around the key taken from the middle of that range.

    ``keys`` has to observe the swaps performed on ``sequence``
    (the caller passes either ``sequence`` itself or a key view over it).

    :param sequence: sequence whose elements get swapped in place
    :param keys: values the ordering is decided on, indexed like ``sequence``
    :param start: first index of the range (inclusive)
    :param stop: last index of the range (inclusive)
    :param comparator: binary predicate that defines the sorting order
    :returns:
        index holding a key equal to the pivot such that, within the range,
        no key before it goes after the pivot
        and no key after it goes before the pivot
    """
    pivot = keys[(start + stop) // 2]
    while start <= stop:
        # move the left cursor past keys that already go before the pivot
        while comparator(keys[start], pivot):
            start += 1
        # move the right cursor past keys that already go after the pivot
        while comparator(pivot, keys[stop]):
            stop -= 1
        if keys[start] == keys[stop]:
            # both cursors rest on keys equal to the pivot: step over
            # the left one so that runs of duplicates cannot stall the loop
            start += 1
        if start >= stop:
            # cursors met (or crossed): nothing is left to exchange
            break
        sequence[start], sequence[stop] = sequence[stop], sequence[start]
    return stop