Source code for quickselect.floyd_rivest

"""
Based on selection algorithm by Robert W. Floyd and Ronald L. Rivest.

Reference:
    https://en.wikipedia.org/wiki/Floyd%E2%80%93Rivest_algorithm
"""

from __future__ import annotations

import math
from operator import gt, lt
from typing import Any, Callable, MutableSequence, Sequence

from .core.utils import SequenceKeyView as _SequenceKeyView
from .hints import Comparator, Domain, Key


def nth_largest(
    sequence: MutableSequence[Domain],
    n: int,
    *,
    key: Key[Domain] | None = None,
) -> Domain:
    """
    Finds the element that would sit at index ``n``
    if the sequence were sorted by key in descending order.

    The sequence is rearranged in place (partially sorted) by the search.

    +------------+-------------+-----------------+------------------+
    | complexity | best        | average         | worst            |
    +------------+-------------+-----------------+------------------+
    | time       | ``O(size)`` | ``O(size)``     | ``O(size ** 2)`` |
    +------------+-------------+-----------------+------------------+
    | memory     | ``O(1)``    | ``O(log size)`` | ``O(size)``      |
    +------------+-------------+-----------------+------------------+

    where ``size = len(sequence)``.

    :param sequence: sequence to search in, modified in place
    :param n:
        zero-based position of the wanted element in descending order
        (``n = 0`` is the maximum element)
    :param key:
        single argument function producing the value to order by,
        elements are compared directly when it is ``None``
    :returns: n-th largest element of the sequence

    >>> sequence = list(range(-10, 11))
    >>> nth_largest(sequence, 0)
    10
    >>> nth_largest(sequence, 1)
    9
    >>> nth_largest(sequence, 19)
    -9
    >>> nth_largest(sequence, 20)
    -10
    >>> nth_largest(sequence, 0, key=abs)
    -10
    >>> nth_largest(sequence, 1, key=abs)
    -10
    >>> nth_largest(sequence, 19, key=abs)
    -1
    >>> nth_largest(sequence, 20, key=abs)
    0
    """
    # "greater than" as the ordering predicate yields descending order
    descending = gt
    return select(sequence, n, comparator=descending, key=key)
def nth_smallest(
    sequence: MutableSequence[Domain],
    n: int,
    *,
    key: Key[Domain] | None = None,
) -> Domain:
    """
    Finds the element that would sit at index ``n``
    if the sequence were sorted by key in ascending order.

    The sequence is rearranged in place (partially sorted) by the search.

    +------------+-------------+-----------------+------------------+
    | complexity | best        | average         | worst            |
    +------------+-------------+-----------------+------------------+
    | time       | ``O(size)`` | ``O(size)``     | ``O(size ** 2)`` |
    +------------+-------------+-----------------+------------------+
    | memory     | ``O(1)``    | ``O(log size)`` | ``O(size)``      |
    +------------+-------------+-----------------+------------------+

    where ``size = len(sequence)``.

    :param sequence: sequence to search in, modified in place
    :param n:
        zero-based position of the wanted element in ascending order
        (``n = 0`` is the minimum element)
    :param key:
        single argument function producing the value to order by,
        elements are compared directly when it is ``None``
    :returns: n-th smallest element of the sequence

    >>> sequence = list(range(-10, 11))
    >>> nth_smallest(sequence, 0)
    -10
    >>> nth_smallest(sequence, 1)
    -9
    >>> nth_smallest(sequence, 19)
    9
    >>> nth_smallest(sequence, 20)
    10
    >>> nth_smallest(sequence, 0, key=abs)
    0
    >>> nth_smallest(sequence, 1, key=abs)
    -1
    >>> nth_smallest(sequence, 19, key=abs)
    -10
    >>> nth_smallest(sequence, 20, key=abs)
    10
    """
    # "less than" as the ordering predicate yields ascending order
    ascending = lt
    return select(sequence, n, comparator=ascending, key=key)
def select(
    sequence: MutableSequence[Domain],
    n: int,
    *,
    start: int = 0,
    stop: int | None = None,
    key: Key[Domain] | None = None,
    comparator: Comparator,
) -> Domain:
    """
    Rearranges the ``start..stop`` part of the sequence in place
    and returns the element that ends up at index ``n``.

    +------------+-------------+-----------------+------------------+
    | complexity | best        | average         | worst            |
    +------------+-------------+-----------------+------------------+
    | time       | ``O(size)`` | ``O(size)``     | ``O(size ** 2)`` |
    +------------+-------------+-----------------+------------------+
    | memory     | ``O(1)``    | ``O(log size)`` | ``O(size)``      |
    +------------+-------------+-----------------+------------------+

    where ``size = len(sequence)``.

    :param sequence: sequence to select from, modified in place
    :param n: index of the element to select
    :param start: first index of the processed part of the sequence
    :param stop:
        last index of the processed part of the sequence (inclusive),
        the last index of the sequence is used when it is ``None``
    :param key:
        single argument function producing the value to order by,
        elements are compared directly when it is ``None``
    :param comparator:
        binary predicate that defines the sorting order
        (e.g. ``operator.lt`` for ascending, ``operator.gt`` for descending)
    :returns:
        n-th element of the sequence
        with slice partially sorted by key in given order

    >>> from operator import gt, lt
    >>> sequence = list(range(-10, 11))
    >>> select(sequence, 0, stop=5, comparator=gt)
    -5
    >>> select(sequence, 0, stop=5, comparator=lt)
    -10
    >>> select(sequence, 20, start=15, comparator=lt)
    10
    >>> select(sequence, 20, start=15, comparator=gt)
    5
    >>> select(sequence, 5, start=5, stop=15, key=abs, comparator=lt)
    0
    >>> select(sequence, 5, start=5, stop=15, key=abs, comparator=gt)
    10
    """
    # the upper bound is inclusive, so the default is the last valid index
    last = len(sequence) - 1 if stop is None else stop
    if key is None:
        # without a key function elements serve as their own keys
        keys = sequence
    else:
        keys = _SequenceKeyView(sequence, key)
    _presort(sequence, keys, n, start, last, comparator)
    return sequence[n]
def _presort(
    sequence: MutableSequence[Domain],
    keys: Sequence[Any],
    n: int,
    start: int,
    stop: int,
    comparator: Comparator,
) -> None:
    """
    Rearranges ``sequence[start], ..., sequence[stop]`` in place
    (Floyd-Rivest selection) until the element at index ``n``
    is the one that belongs there in the order given by ``comparator``.

    :param sequence: sequence to rearrange, modified in place
    :param keys:
        sequence of ordering keys indexed like ``sequence``;
        it is re-read after elements are swapped, so it has to reflect
        the current state of ``sequence`` (``select`` passes either
        ``sequence`` itself or a key view built over it)
    :param n: index of the element being selected
    :param start: first index of the processed range
    :param stop: last index of the processed range (inclusive)
    :param comparator: binary predicate that defines the sorting order
    """
    while start < stop:
        if stop - start > 600:
            # Sampling step of the Floyd-Rivest algorithm: for big ranges
            # first select inside a small window around ``n``, so that
            # the element left at ``n`` is a good pivot for the whole range.
            range_size = stop - start + 1
            i = n - start + 1
            s = 0.5 * range_size ** (2 / 3)
            sigma = (
                0.5
                * math.sqrt(
                    math.log(range_size) * s * (range_size - s) / range_size
                )
                * (-1 if i < range_size / 2 else 1)
            )
            _presort(
                sequence,
                keys,
                n,
                max(start, math.floor(n - i * s / range_size + sigma)),
                min(
                    stop,
                    math.floor(n + (range_size - i) * s / range_size + sigma),
                ),
                comparator,
            )
        # The pivot has to be read *after* the sampling step: the recursive
        # call rearranges the window and usually replaces ``sequence[n]``.
        # Using a value read earlier would overwrite the element currently
        # stored at ``n`` and duplicate the stale one.
        candidate, pivot = sequence[n], keys[n]
        # move the pivot element to the front of the range
        sequence[start], sequence[n] = candidate, sequence[start]
        if comparator(pivot, keys[stop]):
            # make sure both ends hold sentinels for the partition scans
            sequence[start], sequence[stop] = sequence[stop], candidate
        pivot_index = _partition(
            sequence, keys, pivot, start, stop, comparator
        )
        # put the pivot element into its final position
        if keys[start] == pivot:
            sequence[start], sequence[pivot_index] = (
                sequence[pivot_index],
                sequence[start],
            )
        else:
            pivot_index += 1
            sequence[stop], sequence[pivot_index] = (
                sequence[pivot_index],
                sequence[stop],
            )
        # continue with the side that contains ``n``;
        # both conditions hold when the pivot landed exactly at ``n``,
        # which empties the range and ends the loop
        if pivot_index <= n:
            start = pivot_index + 1
        if pivot_index >= n:
            stop = pivot_index - 1


def _partition(
    sequence: MutableSequence[Domain],
    keys: Sequence[Any],
    pivot: Any,
    start: int,
    stop: int,
    comparator: Comparator,
) -> int:
    """
    Partitions ``sequence[start], ..., sequence[stop]`` in place
    around the ``pivot`` key by scanning from both ends.

    The scans have no bounds checks: the caller is expected to place
    elements at both ends of the range that stop them.

    :param sequence: sequence to rearrange, modified in place
    :param keys: sequence of ordering keys indexed like ``sequence``
    :param pivot: key to partition around
    :param start: first index of the processed range
    :param stop: last index of the processed range (inclusive)
    :param comparator: binary predicate that defines the sorting order
    :returns: final position of the scan coming from the right end
    """
    while start < stop:
        sequence[start], sequence[stop] = sequence[stop], sequence[start]
        start += 1
        stop -= 1
        # skip elements that are already on the correct side of the pivot
        while comparator(keys[start], pivot):
            start += 1
        while comparator(pivot, keys[stop]):
            stop -= 1
    return stop