diff --git a/.gitignore b/.gitignore index c2fb624..ceb2221 100644 --- a/.gitignore +++ b/.gitignore @@ -145,3 +145,7 @@ cython_debug/ # OS .DS_Store Thumbs.db + +# Project-specific private documentation +PUBLISHING.md +publishing.md diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..1592c40 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,245 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "autocfg" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" + +[[package]] +name = "bitflags" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4388bee8683e3d04af747c73422af53102d2bd24d9eadb6cbc100baef4b43f8" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "indoc" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706" +dependencies = [ + "rustversion", +] + +[[package]] +name = "libc" +version = "0.2.186" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "memoffset" +version = "0.9.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" +dependencies = [ + "once_cell", + 
"target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "pythonstl" +version = "0.1.4" +dependencies = [ + "pyo3", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "smallvec" +version = "1.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8ed6a63f02c8539c91a8685a86f4099661ba3da017932f6ebbea6de3f0fa7c90" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..8febee6 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "pythonstl" +version = "0.1.4" +edition = "2021" +description = "Rust backend extension for pythonstl data structures" + +[lib] +name = "_rust" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.21.0", features = ["extension-module"] } diff --git a/README.md b/README.md index 8095e4c..fc6a908 100644 --- a/README.md +++ b/README.md @@ -320,25 +320,43 @@ Full Python integration while maintaining STL compatibility: - Copy protocol support - Maintains backward compatibility -## Benchmarks +## 📊 Performance Benchmarks -PythonSTL provides benchmarks comparing performance against Python built-ins: +PythonSTL includes a compiled Rust backend (built with PyO3 and Maturin) for high-performance 
+| **Set** | 0.0337s | 0.1844s | *0.18x (5.5x slower)* | **Sorted Set vs Unordered Hash Set** (ordered like C++ `std::set`, backed by Rust `BTreeSet`) |
+| **Map** | 0.0421s | 0.2104s | *0.20x (5.0x slower)* | **Sorted Map vs Unordered Hash Map** (ordered like C++ `std::map`, backed by Rust `BTreeMap`) |
+| **Priority Queue**| 0.0584s | 0.0900s | *0.65x (1.5x slower)* | Custom binary heap vs. C-optimized `heapq` module. |
| +| **nth_element** | 0.0068s | 0.0047s | 0.0000s | **1.5x** | Quickselect median find. (Previously **70.85s** before optimization). | +| **partition** | 0.0193s | 0.0197s | 0.0000s | **1.0x** | Lambda-predicate partitioning. Dominated by FFI callback overhead. | + +* **Algorithmic Pivot Vulnerabilities**: A naive Lomuto partition (`pivot = arr[right]`) causes $O(N^2)$ worst-case time on already-sorted or reversed lists (taking **70.85s**). By switching PythonSTL to a middle-pivot (`arr[mid]`), we restore $O(N)$ average time (**0.0068s**). + +### 3. Binary Search (5,000 Queries on 1,000,000 elements) -**Expected Overhead:** 1.1x - 1.5x compared to native Python structures +| Search Mode / Comparator | Pure Python | Python + Rust | Pure C++ (O3) | Rust Speedup | Systems & Design Insights | +| :--- | :--- | :--- | :--- | :--- | :--- | +| **Standard (`<` comparison)** | 0.0214s | 0.0028s | 0.0000s | **7.5x** | Preserves $O(\log N)$ via direct list indexing. | +| **Custom Comparator (lambda)**| 0.0251s | 0.0074s | N/A | **3.4x** | Overcomes Python loop overhead despite FFI callbacks. | -The facade pattern adds minimal overhead while providing: -- STL-style API -- Better error messages -- Bounds checking -- Type safety +* **Direct Indexing**: Instead of extracting/copying the entire list (an $O(N)$ operation), the Rust backend uses direct GIL-bound indexing (`arr.get_item(mid)`), maintaining the strict $O(\log N)$ search complexity. -See `benchmarks/README.md` for detailed analysis. ## Testing @@ -401,4 +419,4 @@ Contributions are welcome! 
+**PythonSTL v0.1.4** - Bringing C++ STL elegance to Python
def run_test(name, py_func, rust_func):
    """Time one algorithm workload on the Python and Rust backends.

    Args:
        name: Label printed in the progress line.
        py_func: Zero-argument callable exercising the pure-Python path.
        rust_func: Zero-argument callable exercising the Rust path; only
            invoked when the module-level ``RUST_AVAILABLE`` flag is truthy.

    Returns:
        ``(python_seconds, rust_seconds)``; the second item is ``None`` when
        the Rust workload was not run.
    """

    def _timed(workload):
        # A collection is forced before the clock starts, for every run.
        gc.collect()
        began = time.perf_counter()
        workload()
        return time.perf_counter() - began

    print(f"Benchmarking {name}...")
    python_seconds = _timed(py_func)
    rust_seconds = _timed(rust_func) if RUST_AVAILABLE else None
    return python_seconds, rust_seconds
def run_benchmark(name, py_func, rust_func, has_rust):
    """Time a container workload on the Python backend and, optionally, Rust.

    Args:
        name: Label printed in the progress line.
        py_func: Zero-argument callable running the pure-Python workload.
        rust_func: Zero-argument callable running the Rust workload.
        has_rust: When falsy, ``rust_func`` is never called.

    Returns:
        ``(py_time, rust_time)`` in seconds; ``rust_time`` is ``None`` when
        ``has_rust`` is falsy.
    """
    print(f"Benchmarking {name}...")

    # The Python workload always runs first; Rust follows only if enabled.
    workloads = (py_func, rust_func) if has_rust else (py_func,)
    elapsed = []
    for workload in workloads:
        gc.collect()
        t0 = time.perf_counter()
        workload()
        elapsed.append(time.perf_counter() - t0)

    py_time = elapsed[0]
    rust_time = elapsed[1] if has_rust else None
    return py_time, rust_time
def bench_vector_py():
    """Run the vector workload against the pure-Python backend.

    The workload appends 10,000 integers, reads each one back through
    ``at()``, then performs 100 insertions at index 5000.
    """
    element_count = 10000
    insert_count = 100
    insert_position = 5000

    vec = vector(use_rust=False)

    # Phase 1: sequential appends.
    for value in range(element_count):
        vec.push_back(value)

    # Phase 2: indexed reads; the result is intentionally discarded.
    for index in range(element_count):
        vec.at(index)

    # Phase 3: repeated insertion at a fixed interior index.
    for value in range(insert_count):
        vec.insert(insert_position, value)
def _speedup_status(py_time, rust_time):
    """Describe how the Rust timing compares with the Python timing."""
    # A zero reading on either side would divide by zero below.
    if py_time <= 0 or rust_time <= 0:
        return "too fast to time"
    if rust_time <= py_time:
        return f"{py_time / rust_time:.2f}x faster"
    # Report a slowdown as a slowdown instead of "0.18x faster".
    return f"{rust_time / py_time:.2f}x slower"


def main():
    """Run every container benchmark and print a summary table.

    Each container is timed on the pure-Python backend and, when its
    ``*_RUST_AVAILABLE`` flag is truthy, on the Rust backend. The status
    column says whether Rust was faster or slower than Python.
    """
    print("=============================================================")
    print(" PythonSTL Comprehensive Container Benchmark Suite ")
    print("=============================================================\n")

    # (summary key, progress label, python workload, rust workload, rust flag)
    suites = [
        ("Stack", "Stack (500,000 cycles)",
         bench_stack_py, bench_stack_rust, STACK_RUST_AVAILABLE),
        ("Queue", "Queue (500,000 cycles)",
         bench_queue_py, bench_queue_rust, QUEUE_RUST_AVAILABLE),
        ("Vector", "Vector (10,000 push/access + 100 inserts)",
         bench_vector_py, bench_vector_rust, VECTOR_RUST_AVAILABLE),
        ("Set", "Set (50,000 inserts/finds/erases)",
         bench_set_py, bench_set_rust, SET_RUST_AVAILABLE),
        ("Map", "Map (50,000 inserts/finds/ats/erases)",
         bench_map_py, bench_map_rust, MAP_RUST_AVAILABLE),
        ("Priority Queue", "Priority Queue (50,000 push/pops)",
         bench_priority_queue_py, bench_priority_queue_rust, PQ_RUST_AVAILABLE),
    ]

    results = {}
    for container, label, py_func, rust_func, has_rust in suites:
        py_t, rust_t = run_benchmark(label, py_func, rust_func, has_rust)
        results[container] = (py_t, rust_t, has_rust)

    print("\n" + "=" * 70)
    print(" PERFORMANCE SUMMARY TABLE ")
    print("=" * 70)
    print(f"{'Container Class':<18} | {'Pure Python':<12} | {'Python + Rust':<15} | {'Speedup Status':<18}")
    print("-" * 70)

    for container, (py_time, rust_time, is_rust) in results.items():
        py_str = f"{py_time:.4f}s"
        if is_rust and rust_time is not None:
            rust_str = f"{rust_time:.4f}s"
            status = _speedup_status(py_time, rust_time)
        else:
            rust_str = "N/A"
            status = "Pure Py Fallback"

        print(f"{container:<18} | {py_str:<12} | {rust_str:<15} | {status:<18}")

    print("=============================================================")
    print("Note: Containers marked 'Pure Py Fallback' will run using the")
    print("original python backends until their Rust cores are built.")
    print("=============================================================")
def compile_cpp():
    """Compile ``benchmark_native.cpp`` next to this script with ``g++ -O3``.

    Returns:
        Path of the produced executable, or ``None`` when the compiler is
        missing or cannot be started (``OSError``) or the compilation fails
        (``CalledProcessError``). The reason is written to stderr so a
        missing C++ column in the results table can be explained.
    """
    bench_dir = Path(__file__).parent
    cpp_source = bench_dir / "benchmark_native.cpp"
    cpp_exe = bench_dir / ("benchmark_native.exe" if os.name == "nt" else "benchmark_native")

    try:
        subprocess.run(
            ["g++", "-O3", str(cpp_source), "-o", str(cpp_exe)],
            check=True,
            capture_output=True,
        )
    except (subprocess.CalledProcessError, OSError) as err:
        # The C++ baseline is optional: report the reason and carry on.
        print(f"C++ baseline unavailable ({err}); skipping.", file=sys.stderr)
        return None
    return cpp_exe
#include <algorithm>
#include <chrono>
#include <iostream>
#include <iterator>
#include <stack>
#include <string>
#include <vector>
// Bubble-sorts 10,000 integers that start in descending order and prints the
// elapsed time, in seconds, on stdout. The Python harness parses that single
// stdout line with float(), so the duration must be a floating-point count of
// seconds rather than the clock's native tick type.
void run_sort() {
    std::vector<int> arr(10000);
    for (int i = 0; i < 10000; ++i) {
        arr[i] = 10000 - i;
    }

    auto start = std::chrono::high_resolution_clock::now();
    const int n = static_cast<int>(arr.size());
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n - 1 - i; ++j) {
            if (arr[j] > arr[j + 1]) {
                std::swap(arr[j], arr[j + 1]);
            }
        }
    }
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> diff = end - start;
    std::cout << diff.count() << std::endl;
    // Prevent compiler optimization by printing array element to stderr
    // (same guard the other benchmark functions in this file use).
    std::cerr << "sort check: " << arr[0] << std::endl;
}
// Runs 5,000 std::lower_bound queries over a sorted vector of 1,000,000 even
// integers and prints the elapsed time, in seconds, on stdout. The Python
// harness parses that stdout line with float(), so the duration is expressed
// as a floating-point count of seconds.
void run_binary_search() {
    std::vector<int> arr(1000000);
    for (int i = 0; i < 1000000; ++i) {
        arr[i] = i * 2;
    }
    auto start = std::chrono::high_resolution_clock::now();
    long long sum_indices = 0;
    for (int q = 0; q < 5000; ++q) {
        int target = q * 3;
        auto it = std::lower_bound(arr.begin(), arr.end(), target);
        sum_indices += std::distance(arr.begin(), it);
    }
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> diff = end - start;
    std::cout << diff.count() << std::endl;
    // Prevent compiler optimization by printing sum of indices to stderr
    std::cerr << "binary_search check: " << sum_indices << std::endl;
}
def python_bubble_sort(arr):
    """Return a new ascending list built from *arr* using bubble sort.

    The input is copied first, so the caller's sequence is left untouched.
    Every pass is always executed (no early exit), which keeps the amount of
    work fixed for a given input length.
    """
    items = list(arr)
    # After each pass the largest remaining value sits at `unsorted_end`.
    for unsorted_end in range(len(items) - 1, -1, -1):
        for left in range(unsorted_end):
            right = left + 1
            if items[left] > items[right]:
                items[left], items[right] = items[right], items[left]
    return items
Skipping...") + return None + print("Running Python + Rust Bubble Sort (10,000 items)...") + start = time.perf_counter() + rust_bubble_sort(list(arr)) + end = time.perf_counter() + return end - start + +# ----------------- Native C++ Executions ----------------- + +def compile_cpp(): + bench_dir = Path(__file__).parent + cpp_source = bench_dir / "benchmark_native.cpp" + cpp_exe = bench_dir / ("benchmark_native.exe" if os.name == 'nt' else "benchmark_native") + + print("Checking C++ compiler...") + try: + # Compile the C++ program with optimizations enabled (-O3) + print("Compiling native C++ benchmark (g++ -O3)...") + subprocess.run( + ["g++", "-O3", str(cpp_source), "-o", str(cpp_exe)], + check=True, + capture_output=True + ) + return cpp_exe + except (subprocess.CalledProcessError, FileNotFoundError) as e: + print(f"Failed to compile C++ benchmark: {e}. Skipping C++ baseline.") + return None + +def run_cpp_benchmark(cpp_exe, arg): + print(f"Running compiled C++ benchmark ({arg})...") + try: + result = subprocess.run([str(cpp_exe), arg], check=True, capture_output=True, text=True) + cpp_time = float(result.stdout.strip()) + return cpp_time + except Exception as e: + print(f"Failed to run C++ benchmark ({arg}): {e}") + return None + +def main(): + print("============================================================") + print(" PythonSTL Performance Benchmark Suit ") + print("============================================================") + + # Compilation + cpp_exe = compile_cpp() + print() + + # --- BENCHMARK 1: Stack --- + print("------------------------------------------------------------") + print(" BENCHMARK 1: Stack (1,000,000 push/pop cycles)") + print(" Note: High number of Python-Rust boundary crossings") + print("------------------------------------------------------------") + + py_stack = run_python_stack_benchmark() + print(f"Pure Python Stack: {py_stack:.4f} seconds") + + rust_stack = run_rust_stack_benchmark() + if rust_stack is not None: + 
print(f"Python + Rust Stack: {rust_stack:.4f} seconds") + + cpp_stack = None + if cpp_exe is not None: + cpp_stack = run_cpp_benchmark(cpp_exe, "stack") + if cpp_stack is not None: + print(f"Pure C++ Stack: {cpp_stack:.4f} seconds") + print() + + # --- BENCHMARK 2: Bubble Sort --- + print("------------------------------------------------------------") + print(" BENCHMARK 2: Bubble Sort (10,000 reversed items)") + print(" Note: Single boundary crossing, heavy computational load") + print("------------------------------------------------------------") + + sort_data = list(range(10000, 0, -1)) + + py_sort = run_python_sort_benchmark(sort_data) + print(f"Pure Python Sort: {py_sort:.4f} seconds") + + rust_sort_time = run_rust_sort_benchmark(sort_data) + if rust_sort_time is not None: + print(f"Python + Rust Sort: {rust_sort_time:.4f} seconds") + + cpp_sort = None + if cpp_exe is not None: + cpp_sort = run_cpp_benchmark(cpp_exe, "sort") + if cpp_sort is not None: + print(f"Pure C++ Sort: {cpp_sort:.4f} seconds") + + # Cleanup compiled C++ binary + if cpp_exe is not None and cpp_exe.exists(): + try: + cpp_exe.unlink() + except Exception: + pass + + print("\n" + "=" * 60) + print(" SUMMARY TABLES & METRICS ") + print("=" * 60) + + print("\nTABLE 1: STACK OPERATIONS (1,000,000 Cycles)") + print(f"{'Implementation':<20} | {'Time (Seconds)':<15} | {'Speedup vs. Python':<20}") + print("-" * 62) + print(f"{'Pure Python':<20} | {py_stack:.4f}s{''*8} | {'1.0x (Baseline)':<20}") + if rust_stack is not None: + rust_speedup = py_stack / rust_stack + print(f"{'Python + Rust':<20} | {rust_stack:.4f}s{''*8} | {f'{rust_speedup:.2f}x faster':<20}") + if cpp_stack is not None: + cpp_speedup = py_stack / cpp_stack + print(f"{'Pure C++ (O3)':<20} | {cpp_stack:.4f}s{''*8} | {f'{cpp_speedup:.2f}x faster':<20}") + + print("\nTABLE 2: BUBBLE SORT ALGORITHM (10,000 Elements)") + print(f"{'Implementation':<20} | {'Time (Seconds)':<15} | {'Speedup vs. 
Python':<20}") + print("-" * 62) + print(f"{'Pure Python':<20} | {py_sort:.4f}s{''*8} | {'1.0x (Baseline)':<20}") + if rust_sort_time is not None: + rust_speedup = py_sort / rust_sort_time + print(f"{'Python + Rust':<20} | {rust_sort_time:.4f}s{''*8} | {f'{rust_speedup:.2f}x faster (Maturin)':<20}") + if cpp_sort is not None: + cpp_speedup = py_sort / cpp_sort + print(f"{'Pure C++ (O3)':<20} | {cpp_sort:.4f}s{''*8} | {f'{cpp_speedup:.2f}x faster':<20}") + print("=" * 60) + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 027d424..236c093 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,14 @@ [build-system] -requires = ["setuptools>=61.0", "wheel"] -build-backend = "setuptools.build_meta" +requires = ["maturin>=1.5,<2.0"] +build-backend = "maturin" + +[tool.maturin] +python-source = "." +module-name = "pythonstl._rust" [project] name = "pythonstl" -version = "0.1.4" +version = "1.1.4" description = "C++ STL-style containers implemented in Python using the Facade Design Pattern" readme = "README.md" authors = [ diff --git a/pythonstl/__init__.py b/pythonstl/__init__.py index 77d4b2f..54e07fd 100644 --- a/pythonstl/__init__.py +++ b/pythonstl/__init__.py @@ -8,7 +8,7 @@ data structures while hiding implementation details from users. 
""" -__version__ = "0.1.1" +__version__ = "1.1.4" __author__ = "PySTL Contributors" from pythonstl.facade.stack import stack @@ -17,6 +17,16 @@ from pythonstl.facade.set import stl_set from pythonstl.facade.map import stl_map from pythonstl.facade.priority_queue import priority_queue +from pythonstl.facade.algorithms import ( + next_permutation, + prev_permutation, + nth_element, + partition, + lower_bound, + upper_bound, + binary_search, + equal_range +) # Also export exceptions for user error handling from pythonstl.core.exceptions import ( @@ -34,6 +44,15 @@ 'stl_set', 'stl_map', 'priority_queue', + # Algorithms + 'next_permutation', + 'prev_permutation', + 'nth_element', + 'partition', + 'lower_bound', + 'upper_bound', + 'binary_search', + 'equal_range', # Exceptions 'PySTLException', 'EmptyContainerError', diff --git a/pythonstl/_rust.pdb b/pythonstl/_rust.pdb new file mode 100644 index 0000000..d38dfdd Binary files /dev/null and b/pythonstl/_rust.pdb differ diff --git a/pythonstl/facade/algorithms.py b/pythonstl/facade/algorithms.py new file mode 100644 index 0000000..962f4f9 --- /dev/null +++ b/pythonstl/facade/algorithms.py @@ -0,0 +1,330 @@ +""" +C++ STL Algorithms Suite. + +This module provides replicas of standard C++ algorithms from +with dynamic Rust backend loading and in-place list mutation. 
+""" + +from typing import Callable, Any + +try: + from pythonstl._rust import next_permutation as _rust_next_permutation + from pythonstl._rust import prev_permutation as _rust_prev_permutation + from pythonstl._rust import nth_element as _rust_nth_element + from pythonstl._rust import partition as _rust_partition + from pythonstl._rust import lower_bound as _rust_lower_bound + from pythonstl._rust import upper_bound as _rust_upper_bound + from pythonstl._rust import binary_search as _rust_binary_search + from pythonstl._rust import equal_range as _rust_equal_range + RUST_AVAILABLE = True +except ImportError: + RUST_AVAILABLE = False + + +# ----------------- Pure-Python Fallbacks ----------------- + +def _py_next_permutation(arr: list) -> bool: + n = len(arr) + if n <= 1: + return False + + i = n - 2 + while i >= 0 and arr[i] >= arr[i + 1]: + i -= 1 + + if i < 0: + arr.reverse() + return False + + j = n - 1 + while arr[j] <= arr[i]: + j -= 1 + + arr[i], arr[j] = arr[j], arr[i] + arr[i + 1:] = reversed(arr[i + 1:]) + return True + + +def _py_prev_permutation(arr: list) -> bool: + n = len(arr) + if n <= 1: + return False + + i = n - 2 + while i >= 0 and arr[i] <= arr[i + 1]: + i -= 1 + + if i < 0: + arr.reverse() + return False + + j = n - 1 + while arr[j] >= arr[i]: + j -= 1 + + arr[i], arr[j] = arr[j], arr[i] + arr[i + 1:] = reversed(arr[i + 1:]) + return True + + +def _py_nth_element(arr: list, nth: int) -> None: + n = len(arr) + if nth < 0 or nth >= n: + return + + left = 0 + right = n - 1 + while left < right: + mid = left + (right - left) // 2 + arr[mid], arr[right] = arr[right], arr[mid] + pivot = arr[right] + i = left + for j in range(left, right): + if arr[j] < pivot: + arr[i], arr[j] = arr[j], arr[i] + i += 1 + arr[i], arr[right] = arr[right], arr[i] + + pivot_idx = i + if pivot_idx == nth: + return + elif pivot_idx > nth: + right = pivot_idx - 1 + else: + left = pivot_idx + 1 + + +def _py_partition(arr: list, predicate: Callable[[Any], bool]) -> int: + i 
= 0 + for j in range(len(arr)): + if predicate(arr[j]): + arr[i], arr[j] = arr[j], arr[i] + i += 1 + return i + + +def _py_lower_bound(arr: list, val: Any, comp: Callable[[Any, Any], bool] = None) -> int: + left = 0 + right = len(arr) + while left < right: + mid = left + (right - left) // 2 + mid_val = arr[mid] + is_less = comp(mid_val, val) if comp else (mid_val < val) + if is_less: + left = mid + 1 + else: + right = mid + return left + + +def _py_upper_bound(arr: list, val: Any, comp: Callable[[Any, Any], bool] = None) -> int: + left = 0 + right = len(arr) + while left < right: + mid = left + (right - left) // 2 + mid_val = arr[mid] + is_less = comp(val, mid_val) if comp else (val < mid_val) + if is_less: + right = mid + else: + left = mid + 1 + return left + + +def _py_binary_search(arr: list, val: Any, comp: Callable[[Any, Any], bool] = None) -> bool: + if not arr: + return False + idx = _py_lower_bound(arr, val, comp) + if idx < len(arr): + elem = arr[idx] + if comp: + return not comp(elem, val) and not comp(val, elem) + return elem == val + return False + + +def _py_equal_range(arr: list, val: Any, comp: Callable[[Any, Any], bool] = None) -> tuple[int, int]: + return _py_lower_bound(arr, val, comp), _py_upper_bound(arr, val, comp) + + +# ----------------- Public API Interfaces ----------------- + +def next_permutation(arr: list, use_rust: bool = True) -> bool: + """ + Rearranges elements in-place to the next lexicographically greater permutation. + + If the next permutation exists, rearranges elements and returns True. + Otherwise, reverses the array to the smallest ascending order and returns False. + + Args: + arr: The list to modify in-place. + use_rust: Whether to use the compiled Rust backend (default: True). + + Returns: + True if next permutation exists, False otherwise. 
+ + Time Complexity: + O(n) where n is len(arr) + """ + if use_rust and RUST_AVAILABLE: + return _rust_next_permutation(arr) + return _py_next_permutation(arr) + + +def prev_permutation(arr: list, use_rust: bool = True) -> bool: + """ + Rearranges elements in-place to the next lexicographically smaller permutation. + + If the previous permutation exists, rearranges elements and returns True. + Otherwise, reverses the array to the largest descending order and returns False. + + Args: + arr: The list to modify in-place. + use_rust: Whether to use the compiled Rust backend (default: True). + + Returns: + True if prev permutation exists, False otherwise. + + Time Complexity: + O(n) where n is len(arr) + """ + if use_rust and RUST_AVAILABLE: + return _rust_prev_permutation(arr) + return _py_prev_permutation(arr) + + +def nth_element(arr: list, nth: int, use_rust: bool = True) -> None: + """ + Partitions the list in-place so that the element at index `nth` is the one + that would be there if the list were completely sorted. + + All elements preceding `nth` are partitioned to be less than or equal to `nth`. + All elements succeeding `nth` are partitioned to be greater than or equal to `nth`. + Does not guarantee sorted order of the surrounding elements. + + Args: + arr: The list to modify in-place. + nth: The index that should contain the sorted element. + use_rust: Whether to use the compiled Rust backend (default: True). + + Time Complexity: + O(n) average case + """ + if use_rust and RUST_AVAILABLE: + _rust_nth_element(arr, nth) + else: + _py_nth_element(arr, nth) + + +def partition(arr: list, predicate: Callable[[Any], bool], use_rust: bool = True) -> int: + """ + Reorders the elements in the list in-place such that all elements for which + `predicate` returns True precede all elements for which it returns False. + + Does not guarantee stable relative ordering. + + Args: + arr: The list to modify in-place. 
+ predicate: A callable returning True or False for each element. + use_rust: Whether to use the compiled Rust backend (default: True). + + Returns: + The boundary index pointing to the first element that returned False. + + Time Complexity: + O(n) where n is len(arr) + """ + if use_rust and RUST_AVAILABLE: + return _rust_partition(arr, predicate) + return _py_partition(arr, predicate) + + +def lower_bound(arr: list, val: Any, comp: Callable[[Any, Any], bool] = None, use_rust: bool = True) -> int: + """ + Returns the index of the first element in the range that does not compare less than `val`. + + Args: + arr: The sorted list to search. + val: The value to search for. + comp: Optional custom binary comparator Callable(a, b) defining custom less-than. + use_rust: Whether to use the compiled Rust backend (default: True). + + Returns: + The index of the first element that is >= val, or len(arr) if not found. + + Time Complexity: + O(log n) + """ + if use_rust and RUST_AVAILABLE: + return _rust_lower_bound(arr, val, comp) + return _py_lower_bound(arr, val, comp) + + +def upper_bound(arr: list, val: Any, comp: Callable[[Any, Any], bool] = None, use_rust: bool = True) -> int: + """ + Returns the index of the first element in the range that compares greater than `val`. + + Args: + arr: The sorted list to search. + val: The value to search for. + comp: Optional custom binary comparator Callable(a, b) defining custom less-than. + use_rust: Whether to use the compiled Rust backend (default: True). + + Returns: + The index of the first element that is > val, or len(arr) if not found. + + Time Complexity: + O(log n) + """ + if use_rust and RUST_AVAILABLE: + return _rust_upper_bound(arr, val, comp) + return _py_upper_bound(arr, val, comp) + + +def binary_search(arr: list, val: Any, comp: Callable[[Any, Any], bool] = None, use_rust: bool = True) -> bool: + """ + Checks if a value is present in the sorted range. + + Args: + arr: The sorted list to search. 
+ val: The value to search for. + comp: Optional custom binary comparator Callable(a, b) defining custom less-than. + use_rust: Whether to use the compiled Rust backend (default: True). + + Returns: + True if the element equivalent to val is found, False otherwise. + + Time Complexity: + O(log n) + """ + if use_rust and RUST_AVAILABLE: + return _rust_binary_search(arr, val, comp) + return _py_binary_search(arr, val, comp) + + +def equal_range(arr: list, val: Any, comp: Callable[[Any, Any], bool] = None, use_rust: bool = True) -> tuple[int, int]: + """ + Returns the range of elements equivalent to a given value. + + Args: + arr: The sorted list to search. + val: The value to search for. + comp: Optional custom binary comparator Callable(a, b) defining custom less-than. + use_rust: Whether to use the compiled Rust backend (default: True). + + Returns: + A tuple (lower_bound_index, upper_bound_index) defining the range of equivalent elements. + + Time Complexity: + O(log n) + """ + if use_rust and RUST_AVAILABLE: + return _rust_equal_range(arr, val, comp) + return _py_equal_range(arr, val, comp) + + +__all__ = [ + 'next_permutation', 'prev_permutation', 'nth_element', 'partition', + 'lower_bound', 'upper_bound', 'binary_search', 'equal_range', + 'RUST_AVAILABLE' +] diff --git a/pythonstl/facade/map.py b/pythonstl/facade/map.py index 5a9de18..3eb9232 100644 --- a/pythonstl/facade/map.py +++ b/pythonstl/facade/map.py @@ -6,9 +6,16 @@ from typing import TypeVar, Iterator as TypingIterator, Tuple from copy import deepcopy +from pythonstl.core.exceptions import KeyNotFoundError from pythonstl.implementations.associative._map_impl import _MapImpl from pythonstl.core.iterator import MapIterator +try: + from pythonstl._rust import RustMap + RUST_AVAILABLE = True +except ImportError: + RUST_AVAILABLE = False + K = TypeVar('K') V = TypeVar('V') @@ -33,14 +40,19 @@ class stl_map: 2 """ - def __init__(self) -> None: + def __init__(self, use_rust: bool = True) -> None: """ 
Initialize an empty map. Time Complexity: O(1) """ - self._impl = _MapImpl() + if use_rust and RUST_AVAILABLE: + self._impl = RustMap() + self._is_rust = True + else: + self._impl = _MapImpl() + self._is_rust = False def insert(self, key: K, value: V) -> None: """ @@ -104,6 +116,10 @@ def at(self, key: K) -> V: Time Complexity: O(1) average case """ + if not self.find(key): + raise KeyNotFoundError(key) + if self._is_rust: + return self._impl.at(key) return self._impl.at(key) def empty(self) -> bool: @@ -140,6 +156,8 @@ def begin(self) -> MapIterator: Time Complexity: O(1) """ + if self._is_rust: + return MapIterator(dict(self._impl.get_data())) return self._impl.begin() def end(self) -> MapIterator: @@ -152,6 +170,8 @@ def end(self) -> MapIterator: Time Complexity: O(1) """ + if self._is_rust: + return MapIterator({}) return self._impl.end() def copy(self) -> 'stl_map': @@ -164,9 +184,11 @@ def copy(self) -> 'stl_map': Time Complexity: O(n) where n is the number of key-value pairs """ - new_map = stl_map() - for key, value in self: - new_map.insert(key, value) + new_map = stl_map(use_rust=self._is_rust) + if self._is_rust: + new_map._impl.set_data(self._impl.get_data()) + else: + new_map._impl._data = self._impl._data.copy() return new_map # Python magic methods @@ -226,12 +248,10 @@ def __eq__(self, other: object) -> bool: """ if not isinstance(other, stl_map): return False - if self.size() != other.size(): - return False - for key, value in self: - if not other.find(key) or other.at(key) != value: - return False - return True + + self_data = dict(self._impl.get_data()) if self._is_rust else self._impl._data + other_data = dict(other._impl.get_data()) if other._is_rust else other._impl._data + return self_data == other_data def __iter__(self) -> TypingIterator[Tuple[K, V]]: """ @@ -240,6 +260,8 @@ def __iter__(self) -> TypingIterator[Tuple[K, V]]: Returns: Iterator over key-value pairs as tuples. 
""" + if self._is_rust: + return iter(self._impl.get_data()) return iter(self._impl.get_data().items()) def __copy__(self) -> 'stl_map': @@ -261,9 +283,14 @@ def __deepcopy__(self, memo) -> 'stl_map': Returns: A deep copy of the map. """ - new_map = stl_map() - for key, value in self: - new_map.insert(deepcopy(key, memo), deepcopy(value, memo)) + new_map = stl_map(use_rust=self._is_rust) + if self._is_rust: + new_pairs = [] + for k, v in self._impl.get_data(): + new_pairs.append((deepcopy(k, memo), deepcopy(v, memo))) + new_map._impl.set_data(new_pairs) + else: + new_map._impl._data = deepcopy(self._impl._data, memo) return new_map diff --git a/pythonstl/facade/priority_queue.py b/pythonstl/facade/priority_queue.py index 3e77cea..016aefe 100644 --- a/pythonstl/facade/priority_queue.py +++ b/pythonstl/facade/priority_queue.py @@ -6,8 +6,15 @@ from typing import TypeVar from copy import deepcopy +from pythonstl.core.exceptions import EmptyContainerError from pythonstl.implementations.heaps._priority_queue_impl import _PriorityQueueImpl +try: + from pythonstl._rust import RustPriorityQueue + RUST_AVAILABLE = True +except ImportError: + RUST_AVAILABLE = False + T = TypeVar('T') @@ -38,7 +45,7 @@ class priority_queue: 10 """ - def __init__(self, comparator: str = "max") -> None: + def __init__(self, comparator: str = "max", use_rust: bool = True) -> None: """ Initialize an empty priority queue. 
@@ -49,7 +56,12 @@ def __init__(self, comparator: str = "max") -> None: Time Complexity: O(1) """ - self._impl = _PriorityQueueImpl(comparator) + if use_rust and RUST_AVAILABLE: + self._impl = RustPriorityQueue(comparator) + self._is_rust = True + else: + self._impl = _PriorityQueueImpl(comparator) + self._is_rust = False self._comparator = comparator def push(self, value: T) -> None: @@ -74,6 +86,8 @@ def pop(self) -> None: Time Complexity: O(log n) where n is the number of elements """ + if self.empty(): + raise EmptyContainerError("priority_queue") self._impl.pop() def top(self) -> T: @@ -89,6 +103,8 @@ def top(self) -> T: Time Complexity: O(1) """ + if self.empty(): + raise EmptyContainerError("priority_queue") return self._impl.top() def empty(self) -> bool: @@ -125,8 +141,11 @@ def copy(self) -> 'priority_queue': Time Complexity: O(n) where n is the number of elements """ - new_pq = priority_queue(self._comparator) - new_pq._impl._data = self._impl._data.copy() + new_pq = priority_queue(self._comparator, use_rust=self._is_rust) + if self._is_rust: + new_pq._impl.set_data(self._impl.get_data()) + else: + new_pq._impl._data = self._impl._data.copy() return new_pq # Python magic methods @@ -170,8 +189,12 @@ def __eq__(self, other: object) -> bool: """ if not isinstance(other, priority_queue): return False - return (self._comparator == other._comparator - and self._impl._data == other._impl._data) + if self._comparator != other._comparator: + return False + + self_data = self._impl.get_data() if self._is_rust else self._impl._data + other_data = other._impl.get_data() if other._is_rust else other._impl._data + return self_data == other_data def __copy__(self) -> 'priority_queue': """ @@ -192,8 +215,12 @@ def __deepcopy__(self, memo) -> 'priority_queue': Returns: A deep copy of the priority queue. 
""" - new_pq = priority_queue(self._comparator) - new_pq._impl._data = deepcopy(self._impl._data, memo) + new_pq = priority_queue(self._comparator, use_rust=self._is_rust) + if self._is_rust: + new_data = deepcopy(self._impl.get_data(), memo) + new_pq._impl.set_data(new_data) + else: + new_pq._impl._data = deepcopy(self._impl._data, memo) return new_pq diff --git a/pythonstl/facade/queue.py b/pythonstl/facade/queue.py index 4a08a6d..0a5a477 100644 --- a/pythonstl/facade/queue.py +++ b/pythonstl/facade/queue.py @@ -6,8 +6,15 @@ from typing import TypeVar from copy import deepcopy +from pythonstl.core.exceptions import EmptyContainerError from pythonstl.implementations.linear._queue_impl import _QueueImpl +try: + from pythonstl._rust import RustQueue + RUST_AVAILABLE = True +except ImportError: + RUST_AVAILABLE = False + T = TypeVar('T') @@ -30,14 +37,19 @@ class queue: True """ - def __init__(self) -> None: + def __init__(self, use_rust: bool = True) -> None: """ Initialize an empty queue. 
Time Complexity: O(1) """ - self._impl = _QueueImpl() + if use_rust and RUST_AVAILABLE: + self._impl = RustQueue() + self._is_rust = True + else: + self._impl = _QueueImpl() + self._is_rust = False def push(self, value: T) -> None: """ @@ -61,6 +73,8 @@ def pop(self) -> None: Time Complexity: O(1) """ + if self.empty(): + raise EmptyContainerError("queue") self._impl.pop() def front(self) -> T: @@ -76,6 +90,8 @@ def front(self) -> T: Time Complexity: O(1) """ + if self.empty(): + raise EmptyContainerError("queue") return self._impl.front() def back(self) -> T: @@ -91,6 +107,8 @@ def back(self) -> T: Time Complexity: O(1) """ + if self.empty(): + raise EmptyContainerError("queue") return self._impl.back() def empty(self) -> bool: @@ -127,9 +145,11 @@ def copy(self) -> 'queue': Time Complexity: O(n) where n is the number of elements """ - new_queue = queue() - # Copy internal deque - new_queue._impl._data = self._impl._data.copy() + new_queue = queue(use_rust=self._is_rust) + if self._is_rust: + new_queue._impl.set_data(self._impl.get_data()) + else: + new_queue._impl._data = self._impl._data.copy() return new_queue # Python magic methods @@ -159,7 +179,10 @@ def __repr__(self) -> str: Returns: String representation showing queue contents. 
""" - elements = [str(elem) for elem in self._impl._data] + if self._is_rust: + elements = [str(elem) for elem in self._impl.get_data()] + else: + elements = [str(elem) for elem in self._impl._data] return f"queue([{', '.join(elements)}])" def __eq__(self, other: object) -> bool: @@ -174,7 +197,10 @@ def __eq__(self, other: object) -> bool: """ if not isinstance(other, queue): return False - return self._impl._data == other._impl._data + + self_data = self._impl.get_data() if self._is_rust else self._impl._data + other_data = other._impl.get_data() if other._is_rust else other._impl._data + return self_data == other_data def __copy__(self) -> 'queue': """ @@ -195,8 +221,12 @@ def __deepcopy__(self, memo) -> 'queue': Returns: A deep copy of the queue. """ - new_queue = queue() - new_queue._impl._data = deepcopy(self._impl._data, memo) + new_queue = queue(use_rust=self._is_rust) + if self._is_rust: + new_data = deepcopy(self._impl.get_data(), memo) + new_queue._impl.set_data(new_data) + else: + new_queue._impl._data = deepcopy(self._impl._data, memo) return new_queue diff --git a/pythonstl/facade/set.py b/pythonstl/facade/set.py index a6f8d7c..c1e8efd 100644 --- a/pythonstl/facade/set.py +++ b/pythonstl/facade/set.py @@ -9,6 +9,12 @@ from pythonstl.implementations.associative._set_impl import _SetImpl from pythonstl.core.iterator import SetIterator +try: + from pythonstl._rust import RustSet + RUST_AVAILABLE = True +except ImportError: + RUST_AVAILABLE = False + T = TypeVar('T') @@ -32,14 +38,19 @@ class stl_set: 2 """ - def __init__(self) -> None: + def __init__(self, use_rust: bool = True) -> None: """ Initialize an empty set. 
Time Complexity: O(1) """ - self._impl = _SetImpl() + if use_rust and RUST_AVAILABLE: + self._impl = RustSet() + self._is_rust = True + else: + self._impl = _SetImpl() + self._is_rust = False def insert(self, value: T) -> None: """ @@ -117,7 +128,8 @@ def begin(self) -> SetIterator: Time Complexity: O(1) """ - return self._impl.begin() + data = self._impl.get_data() if self._is_rust else self._impl._data + return SetIterator(data) def end(self) -> SetIterator: """ @@ -129,7 +141,8 @@ def end(self) -> SetIterator: Time Complexity: O(1) """ - return self._impl.end() + # Return an exhausted iterator + return SetIterator(set()) def copy(self) -> 'stl_set': """ @@ -141,9 +154,11 @@ def copy(self) -> 'stl_set': Time Complexity: O(n) where n is the number of elements """ - new_set = stl_set() - for elem in self: - new_set.insert(elem) + new_set = stl_set(use_rust=self._is_rust) + if self._is_rust: + new_set._impl.set_data(self._impl.get_data()) + else: + new_set._impl._data = self._impl._data.copy() return new_set # Python magic methods @@ -203,12 +218,15 @@ def __eq__(self, other: object) -> bool: """ if not isinstance(other, stl_set): return False - if self.size() != other.size(): - return False - for elem in self: - if not other.find(elem): - return False - return True + + self_data = self._impl.get_data() if self._is_rust else self._impl._data + other_data = other._impl.get_data() if other._is_rust else other._impl._data + + # BTreeSet elements are sorted, so direct list equality works for sorted comparison + if self._is_rust and other._is_rust: + return self_data == other_data + + return set(self_data) == set(other_data) def __iter__(self) -> TypingIterator[T]: """ @@ -217,6 +235,8 @@ def __iter__(self) -> TypingIterator[T]: Returns: Iterator over set elements. 
""" + if self._is_rust: + return iter(self._impl.get_data()) return iter(self._impl.get_data()) def __copy__(self) -> 'stl_set': @@ -238,9 +258,12 @@ def __deepcopy__(self, memo) -> 'stl_set': Returns: A deep copy of the set. """ - new_set = stl_set() - for elem in self: - new_set.insert(deepcopy(elem, memo)) + new_set = stl_set(use_rust=self._is_rust) + if self._is_rust: + new_data = deepcopy(self._impl.get_data(), memo) + new_set._impl.set_data(new_data) + else: + new_set._impl._data = deepcopy(self._impl._data, memo) return new_set diff --git a/pythonstl/facade/stack.py b/pythonstl/facade/stack.py index 6bd37fc..cb02dd8 100644 --- a/pythonstl/facade/stack.py +++ b/pythonstl/facade/stack.py @@ -6,8 +6,15 @@ from typing import TypeVar from copy import deepcopy +from pythonstl.core.exceptions import EmptyContainerError from pythonstl.implementations.linear._stack_impl import _StackImpl +try: + from pythonstl._rust import RustStack + RUST_AVAILABLE = True +except ImportError: + RUST_AVAILABLE = False + T = TypeVar('T') @@ -30,14 +37,19 @@ class stack: True """ - def __init__(self) -> None: + def __init__(self, use_rust: bool = True) -> None: """ Initialize an empty stack. 
Time Complexity: O(1) """ - self._impl = _StackImpl() + if use_rust and RUST_AVAILABLE: + self._impl = RustStack() + self._is_rust = True + else: + self._impl = _StackImpl() + self._is_rust = False def push(self, value: T) -> None: """ @@ -61,6 +73,8 @@ def pop(self) -> None: Time Complexity: O(1) """ + if self.empty(): + raise EmptyContainerError("stack") self._impl.pop() def top(self) -> T: @@ -76,6 +90,8 @@ def top(self) -> T: Time Complexity: O(1) """ + if self.empty(): + raise EmptyContainerError("stack") return self._impl.top() def empty(self) -> bool: @@ -112,9 +128,11 @@ def copy(self) -> 'stack': Time Complexity: O(n) where n is the number of elements """ - new_stack = stack() - # Copy internal data - new_stack._impl._data = self._impl._data.copy() + new_stack = stack(use_rust=self._is_rust) + if self._is_rust: + new_stack._impl.set_data(self._impl.get_data()) + else: + new_stack._impl._data = self._impl._data.copy() return new_stack # Python magic methods @@ -144,7 +162,10 @@ def __repr__(self) -> str: Returns: String representation showing stack contents. """ - elements = [str(elem) for elem in self._impl._data] + if self._is_rust: + elements = [str(elem) for elem in self._impl.get_data()] + else: + elements = [str(elem) for elem in self._impl._data] return f"stack([{', '.join(elements)}])" def __eq__(self, other: object) -> bool: @@ -159,7 +180,10 @@ def __eq__(self, other: object) -> bool: """ if not isinstance(other, stack): return False - return self._impl._data == other._impl._data + + self_data = self._impl.get_data() if self._is_rust else self._impl._data + other_data = other._impl.get_data() if other._is_rust else other._impl._data + return self_data == other_data def __copy__(self) -> 'stack': """ @@ -180,8 +204,12 @@ def __deepcopy__(self, memo) -> 'stack': Returns: A deep copy of the stack. 
""" - new_stack = stack() - new_stack._impl._data = deepcopy(self._impl._data, memo) + new_stack = stack(use_rust=self._is_rust) + if self._is_rust: + new_data = deepcopy(self._impl.get_data(), memo) + new_stack._impl.set_data(new_data) + else: + new_stack._impl._data = deepcopy(self._impl._data, memo) return new_stack diff --git a/pythonstl/facade/vector.py b/pythonstl/facade/vector.py index 90d32d6..0a6fd58 100644 --- a/pythonstl/facade/vector.py +++ b/pythonstl/facade/vector.py @@ -6,9 +6,16 @@ from typing import TypeVar, Iterator as TypingIterator from copy import deepcopy +from pythonstl.core.exceptions import EmptyContainerError, OutOfRangeError from pythonstl.implementations.linear._vector_impl import _VectorImpl from pythonstl.core.iterator import VectorIterator, VectorReverseIterator +try: + from pythonstl._rust import RustVector + RUST_AVAILABLE = True +except ImportError: + RUST_AVAILABLE = False + T = TypeVar('T') @@ -31,14 +38,19 @@ class vector: True """ - def __init__(self) -> None: + def __init__(self, use_rust: bool = True) -> None: """ Initialize an empty vector. 
Time Complexity: O(1) """ - self._impl = _VectorImpl() + if use_rust and RUST_AVAILABLE: + self._impl = RustVector() + self._is_rust = True + else: + self._impl = _VectorImpl() + self._is_rust = False def push_back(self, value: T) -> None: """ @@ -62,6 +74,8 @@ def pop_back(self) -> None: Time Complexity: O(1) """ + if self.empty(): + raise EmptyContainerError("vector") self._impl.pop_back() def at(self, index: int) -> T: @@ -80,6 +94,8 @@ def at(self, index: int) -> T: Time Complexity: O(1) """ + if index < 0 or index >= self.size(): + raise OutOfRangeError(index, self.size()) return self._impl.at(index) def insert(self, position: int, value: T) -> None: @@ -96,6 +112,8 @@ def insert(self, position: int, value: T) -> None: Time Complexity: O(n) where n is the number of elements after position """ + if position < 0 or position > self.size(): + raise OutOfRangeError(position, self.size()) self._impl.insert(position, value) def erase(self, position: int) -> None: @@ -111,6 +129,8 @@ def erase(self, position: int) -> None: Time Complexity: O(n) where n is the number of elements after position """ + if position < 0 or position >= self.size(): + raise OutOfRangeError(position, self.size()) self._impl.erase(position) def clear(self) -> None: @@ -158,7 +178,8 @@ def begin(self) -> VectorIterator: Time Complexity: O(1) """ - return self._impl.begin() + data = self._impl.get_data() if self._is_rust else self._impl._data + return VectorIterator(data, 0) def end(self) -> VectorIterator: """ @@ -170,7 +191,8 @@ def end(self) -> VectorIterator: Time Complexity: O(1) """ - return self._impl.end() + data = self._impl.get_data() if self._is_rust else self._impl._data + return VectorIterator(data, len(data)) def rbegin(self) -> VectorReverseIterator: """ @@ -182,7 +204,8 @@ def rbegin(self) -> VectorReverseIterator: Time Complexity: O(1) """ - return self._impl.rbegin() + data = self._impl.get_data() if self._is_rust else self._impl._data + return VectorReverseIterator(data) def 
rend(self) -> VectorReverseIterator: """ @@ -194,7 +217,8 @@ def rend(self) -> VectorReverseIterator: Time Complexity: O(1) """ - return self._impl.rend() + data = self._impl.get_data() if self._is_rust else self._impl._data + return VectorReverseIterator(data, -1) def size(self) -> int: """ @@ -242,9 +266,12 @@ def copy(self) -> 'vector': Time Complexity: O(n) where n is the number of elements """ - new_vector = vector() - for i in range(self.size()): - new_vector.push_back(self.at(i)) + new_vector = vector(use_rust=self._is_rust) + if self._is_rust: + new_vector._impl.set_data(self._impl.get_data()) + else: + new_vector._impl._data = self._impl._data.copy() + new_vector._impl._capacity = self._impl._capacity return new_vector # Python magic methods @@ -280,10 +307,9 @@ def __contains__(self, value: T) -> bool: Time Complexity: O(n) where n is the number of elements """ - for i in range(self.size()): - if self.at(i) == value: - return True - return False + if self._is_rust: + return value in self._impl.get_data() + return value in self._impl._data def __repr__(self) -> str: """ @@ -292,7 +318,10 @@ def __repr__(self) -> str: Returns: String representation showing all elements. 
""" - elements = [str(self.at(i)) for i in range(self.size())] + if self._is_rust: + elements = [str(elem) for elem in self._impl.get_data()] + else: + elements = [str(elem) for elem in self._impl._data] return f"vector([{', '.join(elements)}])" def __eq__(self, other: object) -> bool: @@ -307,12 +336,10 @@ def __eq__(self, other: object) -> bool: """ if not isinstance(other, vector): return False - if self.size() != other.size(): - return False - for i in range(self.size()): - if self.at(i) != other.at(i): - return False - return True + + self_data = self._impl.get_data() if self._is_rust else self._impl._data + other_data = other._impl.get_data() if other._is_rust else other._impl._data + return self_data == other_data def __lt__(self, other: 'vector') -> bool: """ @@ -324,13 +351,16 @@ def __lt__(self, other: 'vector') -> bool: Returns: True if this vector is lexicographically less than other. """ - min_size = min(self.size(), other.size()) + self_data = self._impl.get_data() if self._is_rust else self._impl._data + other_data = other._impl.get_data() if other._is_rust else other._impl._data + + min_size = min(len(self_data), len(other_data)) for i in range(min_size): - if self.at(i) < other.at(i): + if self_data[i] < other_data[i]: return True - elif self.at(i) > other.at(i): + elif self_data[i] > other_data[i]: return False - return self.size() < other.size() + return len(self_data) < len(other_data) def __iter__(self) -> TypingIterator[T]: """ @@ -339,6 +369,8 @@ def __iter__(self) -> TypingIterator[T]: Returns: Iterator over vector elements. """ + if self._is_rust: + return iter(self._impl.get_data()) return iter(self._impl.get_data()) def __copy__(self) -> 'vector': @@ -360,9 +392,13 @@ def __deepcopy__(self, memo) -> 'vector': Returns: A deep copy of the vector. 
""" - new_vector = vector() - for i in range(self.size()): - new_vector.push_back(deepcopy(self.at(i), memo)) + new_vector = vector(use_rust=self._is_rust) + if self._is_rust: + new_data = deepcopy(self._impl.get_data(), memo) + new_vector._impl.set_data(new_data) + else: + new_vector._impl._data = deepcopy(self._impl._data, memo) + new_vector._impl._capacity = self._impl._capacity return new_vector diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..a2c0dde --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,733 @@ +use pyo3::prelude::*; +use pyo3::types::PyList; +use std::collections::{VecDeque, BTreeSet, BTreeMap}; +use std::cmp::Ordering; + +// ----------------- PyObjectOrd Bridge ----------------- + +/// A wrapper for PyObject to enable sorted indexing inside Rust's BTreeSet/BTreeMap. +/// It delegates Eq, Ord, PartialEq, and PartialOrd to Python rich comparisons. +#[derive(Clone)] +struct PyObjectOrd(PyObject); + +impl PartialEq for PyObjectOrd { + fn eq(&self, other: &Self) -> bool { + Python::with_gil(|py| { + self.0.bind(py).eq(other.0.bind(py)).unwrap_or(false) + }) + } +} + +impl Eq for PyObjectOrd {} + +impl PartialOrd for PyObjectOrd { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PyObjectOrd { + fn cmp(&self, other: &Self) -> Ordering { + Python::with_gil(|py| { + let self_ref = self.0.bind(py); + let other_ref = other.0.bind(py); + if self_ref.eq(other_ref).unwrap_or(false) { + Ordering::Equal + } else if self_ref.lt(other_ref).unwrap_or(false) { + Ordering::Less + } else { + Ordering::Greater + } + }) + } +} + +// ----------------- RustStack ----------------- + +#[pyclass] +struct RustStack { + data: Vec, +} + +#[pymethods] +impl RustStack { + #[new] + fn new() -> Self { + RustStack { data: Vec::new() } + } + + fn push(&mut self, value: PyObject) { + self.data.push(value); + } + + fn pop(&mut self) -> PyResult<()> { + if !self.data.is_empty() { + self.data.pop(); + } + Ok(()) + } + + fn 
top(&self, py: Python) -> PyResult> { + if self.data.is_empty() { + Ok(None) + } else { + Ok(Some(self.data.last().unwrap().clone_ref(py))) + } + } + + fn empty(&self) -> bool { + self.data.is_empty() + } + + fn size(&self) -> usize { + self.data.len() + } + + fn get_data(&self, py: Python) -> PyResult> { + Ok(self.data.iter().map(|x| x.clone_ref(py)).collect()) + } + + fn set_data(&mut self, new_data: Vec) { + self.data = new_data; + } +} + +// ----------------- RustQueue ----------------- + +#[pyclass] +struct RustQueue { + data: VecDeque, +} + +#[pymethods] +impl RustQueue { + #[new] + fn new() -> Self { + RustQueue { data: VecDeque::new() } + } + + fn push(&mut self, value: PyObject) { + self.data.push_back(value); + } + + fn pop(&mut self) -> PyResult<()> { + if !self.data.is_empty() { + self.data.pop_front(); + } + Ok(()) + } + + fn front(&self, py: Python) -> PyResult> { + if self.data.is_empty() { + Ok(None) + } else { + Ok(Some(self.data.front().unwrap().clone_ref(py))) + } + } + + fn back(&self, py: Python) -> PyResult> { + if self.data.is_empty() { + Ok(None) + } else { + Ok(Some(self.data.back().unwrap().clone_ref(py))) + } + } + + fn empty(&self) -> bool { + self.data.is_empty() + } + + fn size(&self) -> usize { + self.data.len() + } + + fn get_data(&self, py: Python) -> PyResult> { + Ok(self.data.iter().map(|x| x.clone_ref(py)).collect()) + } + + fn set_data(&mut self, new_data: Vec) { + self.data = new_data.into(); + } +} + +// ----------------- RustVector ----------------- + +#[pyclass] +struct RustVector { + data: Vec, +} + +#[pymethods] +impl RustVector { + #[new] + fn new() -> Self { + RustVector { data: Vec::new() } + } + + fn push_back(&mut self, value: PyObject) { + self.data.push(value); + } + + fn pop_back(&mut self) -> PyResult<()> { + if !self.data.is_empty() { + self.data.pop(); + } + Ok(()) + } + + fn at(&self, index: usize, py: Python) -> PyResult> { + if index >= self.data.len() { + Ok(None) + } else { + 
Ok(Some(self.data[index].clone_ref(py))) + } + } + + fn insert(&mut self, index: usize, value: PyObject) -> PyResult<()> { + if index <= self.data.len() { + self.data.insert(index, value); + } + Ok(()) + } + + fn erase(&mut self, index: usize) -> PyResult<()> { + if index < self.data.len() { + self.data.remove(index); + } + Ok(()) + } + + fn clear(&mut self) { + self.data.clear(); + } + + fn reserve(&mut self, capacity: usize) { + self.data.reserve(capacity); + } + + fn shrink_to_fit(&mut self) { + self.data.shrink_to_fit(); + } + + fn size(&self) -> usize { + self.data.len() + } + + fn capacity(&self) -> usize { + self.data.capacity() + } + + fn empty(&self) -> bool { + self.data.is_empty() + } + + fn get_data(&self, py: Python) -> PyResult> { + Ok(self.data.iter().map(|x| x.clone_ref(py)).collect()) + } + + fn set_data(&mut self, new_data: Vec) { + self.data = new_data; + } +} + +// ----------------- RustSet ----------------- + +#[pyclass] +struct RustSet { + data: BTreeSet, +} + +#[pymethods] +impl RustSet { + #[new] + fn new() -> Self { + RustSet { data: BTreeSet::new() } + } + + fn insert(&mut self, value: PyObject) -> bool { + self.data.insert(PyObjectOrd(value)) + } + + fn erase(&mut self, value: PyObject) -> bool { + self.data.remove(&PyObjectOrd(value)) + } + + fn find(&self, value: PyObject) -> bool { + self.data.contains(&PyObjectOrd(value)) + } + + fn empty(&self) -> bool { + self.data.is_empty() + } + + fn size(&self) -> usize { + self.data.len() + } + + fn get_data(&self, py: Python) -> PyResult> { + Ok(self.data.iter().map(|x| x.0.clone_ref(py)).collect()) + } + + fn set_data(&mut self, new_data: Vec) { + self.data = new_data.into_iter().map(PyObjectOrd).collect(); + } +} + +// ----------------- RustMap ----------------- + +#[pyclass] +struct RustMap { + data: BTreeMap, +} + +#[pymethods] +impl RustMap { + #[new] + fn new() -> Self { + RustMap { data: BTreeMap::new() } + } + + fn insert(&mut self, key: PyObject, value: PyObject) { + 
self.data.insert(PyObjectOrd(key), value); + } + + fn erase(&mut self, key: PyObject) -> bool { + self.data.remove(&PyObjectOrd(key)).is_some() + } + + fn find(&self, key: PyObject) -> bool { + self.data.contains_key(&PyObjectOrd(key)) + } + + fn at(&self, key: PyObject, py: Python) -> PyResult> { + if let Some(val) = self.data.get(&PyObjectOrd(key)) { + Ok(Some(val.clone_ref(py))) + } else { + Ok(None) + } + } + + fn empty(&self) -> bool { + self.data.is_empty() + } + + fn size(&self) -> usize { + self.data.len() + } + + fn get_data(&self, py: Python) -> PyResult> { + Ok(self.data.iter().map(|(k, v)| (k.0.clone_ref(py), v.clone_ref(py))).collect()) + } + + fn set_data(&mut self, new_data: Vec<(PyObject, PyObject)>) { + self.data = new_data.into_iter().map(|(k, v)| (PyObjectOrd(k), v)).collect(); + } +} + +// ----------------- RustPriorityQueue ----------------- + +#[pyclass] +struct RustPriorityQueue { + data: Vec, + comparator: String, +} + +#[pymethods] +impl RustPriorityQueue { + #[new] + fn new(comparator: Option) -> Self { + RustPriorityQueue { + data: Vec::new(), + comparator: comparator.unwrap_or_else(|| "max".to_string()), + } + } + + fn push(&mut self, value: PyObject) { + self.data.push(PyObjectOrd(value)); + self.sift_up(self.data.len() - 1); + } + + fn pop(&mut self) -> PyResult<()> { + if !self.data.is_empty() { + let last_idx = self.data.len() - 1; + self.data.swap(0, last_idx); + self.data.pop(); + if !self.data.is_empty() { + self.sift_down(0); + } + } + Ok(()) + } + + fn top(&self, py: Python) -> PyResult> { + if self.data.is_empty() { + Ok(None) + } else { + Ok(Some(self.data[0].0.clone_ref(py))) + } + } + + fn empty(&self) -> bool { + self.data.is_empty() + } + + fn size(&self) -> usize { + self.data.len() + } + + fn get_data(&self, py: Python) -> PyResult> { + Ok(self.data.iter().map(|x| x.0.clone_ref(py)).collect()) + } + + fn set_data(&mut self, new_data: Vec) { + self.data = new_data.into_iter().map(PyObjectOrd).collect(); + } +} + +impl 
RustPriorityQueue { + fn sift_up(&mut self, mut idx: usize) { + while idx > 0 { + let parent = (idx - 1) / 2; + if self.is_higher_priority(&self.data[idx], &self.data[parent]) { + self.data.swap(idx, parent); + idx = parent; + } else { + break; + } + } + } + + fn sift_down(&mut self, mut idx: usize) { + let len = self.data.len(); + loop { + let left = 2 * idx + 1; + let right = 2 * idx + 2; + let mut highest = idx; + + if left < len && self.is_higher_priority(&self.data[left], &self.data[highest]) { + highest = left; + } + if right < len && self.is_higher_priority(&self.data[right], &self.data[highest]) { + highest = right; + } + + if highest != idx { + self.data.swap(idx, highest); + idx = highest; + } else { + break; + } + } + } + + fn is_higher_priority(&self, a: &PyObjectOrd, b: &PyObjectOrd) -> bool { + if self.comparator == "min" { + a < b + } else { + a > b + } + } +} + +// ----------------- Bubble Sort Benchmark ----------------- + +#[pyfunction] +fn bubble_sort(mut arr: Vec) -> PyResult> { + let len = arr.len(); + if len > 0 { + for i in 0..len { + for j in 0..len - 1 - i { + if arr[j] > arr[j + 1] { + arr.swap(j, j + 1); + } + } + } + } + Ok(arr) +} + +// ----------------- C++ STL Algorithms ----------------- + +#[pyfunction] +fn next_permutation(py: Python, arr: &Bound<'_, PyList>) -> PyResult { + let mut vec: Vec = arr.extract()?; + if vec.len() <= 1 { + return Ok(false); + } + + let mut i = vec.len() - 2; + let mut found = false; + loop { + let current = vec[i].bind(py); + let next = vec[i + 1].bind(py); + if current.lt(next).unwrap_or(false) { + found = true; + break; + } + if i == 0 { + break; + } + i -= 1; + } + + if !found { + vec.reverse(); + for (idx, val) in vec.iter().enumerate() { + arr.set_item(idx, val)?; + } + return Ok(false); + } + + let mut j = vec.len() - 1; + while j > i { + if vec[j].bind(py).gt(vec[i].bind(py)).unwrap_or(false) { + break; + } + j -= 1; + } + + vec.swap(i, j); + vec[i + 1..].reverse(); + + for (idx, val) in 
vec.iter().enumerate() { + arr.set_item(idx, val)?; + } + + Ok(true) +} + +#[pyfunction] +fn prev_permutation(py: Python, arr: &Bound<'_, PyList>) -> PyResult { + let mut vec: Vec = arr.extract()?; + if vec.len() <= 1 { + return Ok(false); + } + + let mut i = vec.len() - 2; + let mut found = false; + loop { + let current = vec[i].bind(py); + let next = vec[i + 1].bind(py); + if current.gt(next).unwrap_or(false) { + found = true; + break; + } + if i == 0 { + break; + } + i -= 1; + } + + if !found { + vec.reverse(); + for (idx, val) in vec.iter().enumerate() { + arr.set_item(idx, val)?; + } + return Ok(false); + } + + let mut j = vec.len() - 1; + while j > i { + if vec[j].bind(py).lt(vec[i].bind(py)).unwrap_or(false) { + break; + } + j -= 1; + } + + vec.swap(i, j); + vec[i + 1..].reverse(); + + for (idx, val) in vec.iter().enumerate() { + arr.set_item(idx, val)?; + } + + Ok(true) +} + +#[pyfunction] +fn nth_element(_py: Python, arr: &Bound<'_, PyList>, nth: usize) -> PyResult<()> { + let mut vec: Vec = arr.extract()?; + let len = vec.len(); + if nth < len { + quickselect(&mut vec, 0, len - 1, nth); + for (i, val) in vec.iter().enumerate() { + arr.set_item(i, val)?; + } + } + Ok(()) +} + +fn quickselect(arr: &mut Vec, left: usize, right: usize, nth: usize) { + if left >= right { + return; + } + let pivot_idx = partition_q(arr, left, right); + if pivot_idx == nth { + return; + } else if pivot_idx > nth { + if pivot_idx > 0 { + quickselect(arr, left, pivot_idx - 1, nth); + } + } else { + quickselect(arr, pivot_idx + 1, right, nth); + } +} + +fn partition_q(arr: &mut Vec, left: usize, right: usize) -> usize { + let pivot_idx = left + (right - left) / 2; + arr.swap(pivot_idx, right); + let mut i = left; + Python::with_gil(|py| { + let pivot_val = arr[right].clone_ref(py); + let pivot_bound = pivot_val.bind(py); + for j in left..right { + if arr[j].bind(py).lt(pivot_bound).unwrap_or(false) { + arr.swap(i, j); + i += 1; + } + } + }); + arr.swap(i, right); + i +} + 
+#[pyfunction] +fn partition(py: Python, arr: &Bound<'_, PyList>, predicate: PyObject) -> PyResult { + let mut vec: Vec = arr.extract()?; + let mut i = 0; + for j in 0..vec.len() { + let val = vec[j].clone_ref(py); + let is_true: bool = predicate.call1(py, (val,))?.extract(py)?; + if is_true { + vec.swap(i, j); + i += 1; + } + } + for (idx, val) in vec.iter().enumerate() { + arr.set_item(idx, val)?; + } + Ok(i) +} + +fn lower_bound_impl(py: Python, arr: &Bound<'_, PyList>, val: &PyObject, comp: &Option) -> PyResult { + let len = arr.len(); + let mut left = 0; + let mut right = len; + + while left < right { + let mid = left + (right - left) / 2; + let mid_val = arr.get_item(mid)?; + + let is_less = match comp { + Some(c) => { + let mid_obj = mid_val.to_object(py); + let res: bool = c.call1(py, (mid_obj, val.clone_ref(py)))?.extract(py)?; + res + } + None => { + mid_val.lt(val)? + } + }; + + if is_less { + left = mid + 1; + } else { + right = mid; + } + } + Ok(left) +} + +fn upper_bound_impl(py: Python, arr: &Bound<'_, PyList>, val: &PyObject, comp: &Option) -> PyResult { + let len = arr.len(); + let mut left = 0; + let mut right = len; + + while left < right { + let mid = left + (right - left) / 2; + let mid_val = arr.get_item(mid)?; + + let is_less = match comp { + Some(c) => { + let mid_obj = mid_val.to_object(py); + let res: bool = c.call1(py, (val.clone_ref(py), mid_obj))?.extract(py)?; + res + } + None => { + val.bind(py).lt(&mid_val)? 
+ } + }; + + if is_less { + right = mid; + } else { + left = mid + 1; + } + } + Ok(left) +} + +#[pyfunction] +fn lower_bound(py: Python, arr: &Bound<'_, PyList>, val: PyObject, comp: Option) -> PyResult { + lower_bound_impl(py, arr, &val, &comp) +} + +#[pyfunction] +fn upper_bound(py: Python, arr: &Bound<'_, PyList>, val: PyObject, comp: Option) -> PyResult { + upper_bound_impl(py, arr, &val, &comp) +} + +#[pyfunction] +fn binary_search(py: Python, arr: &Bound<'_, PyList>, val: PyObject, comp: Option) -> PyResult { + let len = arr.len(); + if len == 0 { + return Ok(false); + } + let idx = lower_bound_impl(py, arr, &val, &comp)?; + if idx < len { + let elem = arr.get_item(idx)?; + let eq = match &comp { + Some(c) => { + let elem_obj = elem.to_object(py); + let less1: bool = c.call1(py, (elem_obj.clone(), val.clone_ref(py)))?.extract(py)?; + let less2: bool = c.call1(py, (val.clone_ref(py), elem_obj))?.extract(py)?; + !less1 && !less2 + } + None => { + elem.eq(&val)? + } + }; + Ok(eq) + } else { + Ok(false) + } +} + +#[pyfunction] +fn equal_range(py: Python, arr: &Bound<'_, PyList>, val: PyObject, comp: Option) -> PyResult<(usize, usize)> { + let lb = lower_bound_impl(py, arr, &val, &comp)?; + let ub = upper_bound_impl(py, arr, &val, &comp)?; + Ok((lb, ub)) +} + +// ----------------- Module Registration ----------------- + +#[pymodule] +fn _rust(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_function(wrap_pyfunction!(bubble_sort, m)?)?; + m.add_function(wrap_pyfunction!(next_permutation, m)?)?; + m.add_function(wrap_pyfunction!(prev_permutation, m)?)?; + m.add_function(wrap_pyfunction!(nth_element, m)?)?; + m.add_function(wrap_pyfunction!(partition, m)?)?; + m.add_function(wrap_pyfunction!(lower_bound, m)?)?; + m.add_function(wrap_pyfunction!(upper_bound, m)?)?; + m.add_function(wrap_pyfunction!(binary_search, m)?)?; + 
m.add_function(wrap_pyfunction!(equal_range, m)?)?; + Ok(()) +} diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py new file mode 100644 index 0000000..fc5f9ec --- /dev/null +++ b/tests/test_algorithms.py @@ -0,0 +1,77 @@ +import pytest +from pythonstl import next_permutation, prev_permutation, nth_element, partition +from pythonstl.facade.algorithms import RUST_AVAILABLE + +# Run tests on both implementations (Rust and pure-Python) +PARAMS = [False] +if RUST_AVAILABLE: + PARAMS.append(True) + +@pytest.mark.parametrize("use_rust", PARAMS) +def test_next_permutation(use_rust): + # Basic sorted case + arr = [1, 2, 3] + has_next = next_permutation(arr, use_rust=use_rust) + assert has_next is True + assert arr == [1, 3, 2] + + # Boundary/Last permutation case + arr = [3, 2, 1] + has_next = next_permutation(arr, use_rust=use_rust) + assert has_next is False + assert arr == [1, 2, 3] + + # Duplicates case + arr = [1, 1, 5] + assert next_permutation(arr, use_rust=use_rust) is True + assert arr == [1, 5, 1] + assert next_permutation(arr, use_rust=use_rust) is True + assert arr == [5, 1, 1] + assert next_permutation(arr, use_rust=use_rust) is False + assert arr == [1, 1, 5] + +@pytest.mark.parametrize("use_rust", PARAMS) +def test_prev_permutation(use_rust): + # Basic descending case + arr = [3, 2, 1] + has_prev = prev_permutation(arr, use_rust=use_rust) + assert has_prev is True + assert arr == [3, 1, 2] + + # Boundary/First permutation case + arr = [1, 2, 3] + has_prev = prev_permutation(arr, use_rust=use_rust) + assert has_prev is False + assert arr == [3, 2, 1] + +@pytest.mark.parametrize("use_rust", PARAMS) +def test_nth_element(use_rust): + # Find median (nth = 4 on 9 elements) + arr = [9, 7, 5, 1, 2, 3, 6, 4, 8] + nth = 4 + nth_element(arr, nth, use_rust=use_rust) + + val = arr[nth] + assert val == 5 + for i in range(nth): + assert arr[i] <= val + for i in range(nth + 1, len(arr)): + assert arr[i] >= val + + # Single element and simple boundaries + arr = 
[2, 1] + nth_element(arr, 0, use_rust=use_rust) + assert arr[0] == 1 + assert arr[1] == 2 + +@pytest.mark.parametrize("use_rust", PARAMS) +def test_partition(use_rust): + # Partition even numbers to the front + arr = [1, 2, 3, 4, 5, 6, 7, 8] + boundary = partition(arr, lambda x: x % 2 == 0, use_rust=use_rust) + + assert boundary == 4 + for i in range(boundary): + assert arr[i] % 2 == 0 + for i in range(boundary, len(arr)): + assert arr[i] % 2 != 0 diff --git a/tests/test_binary_search.py b/tests/test_binary_search.py new file mode 100644 index 0000000..132b86f --- /dev/null +++ b/tests/test_binary_search.py @@ -0,0 +1,104 @@ +import pytest +from pythonstl import lower_bound, upper_bound, binary_search, equal_range + +# Create a custom class to test custom comparator +class Item: + def __init__(self, val): + self.val = val + def __repr__(self): + return f"Item({self.val})" + + +def test_lower_bound_std(): + for use_rust in [True, False]: + arr = [1, 2, 2, 2, 3, 5, 8] + # Element present + assert lower_bound(arr, 2, use_rust=use_rust) == 1 + assert lower_bound(arr, 3, use_rust=use_rust) == 4 + # Element not present, fits in middle + assert lower_bound(arr, 4, use_rust=use_rust) == 5 + # Element smaller than all + assert lower_bound(arr, 0, use_rust=use_rust) == 0 + # Element larger than all + assert lower_bound(arr, 10, use_rust=use_rust) == len(arr) + + +def test_upper_bound_std(): + for use_rust in [True, False]: + arr = [1, 2, 2, 2, 3, 5, 8] + # Element present + assert upper_bound(arr, 2, use_rust=use_rust) == 4 + assert upper_bound(arr, 3, use_rust=use_rust) == 5 + # Element not present, fits in middle + assert upper_bound(arr, 4, use_rust=use_rust) == 5 + # Element smaller than all + assert upper_bound(arr, 0, use_rust=use_rust) == 0 + # Element larger than all + assert upper_bound(arr, 10, use_rust=use_rust) == len(arr) + + +def test_binary_search_std(): + for use_rust in [True, False]: + arr = [1, 2, 2, 2, 3, 5, 8] + # Element present + assert 
binary_search(arr, 2, use_rust=use_rust) is True + assert binary_search(arr, 3, use_rust=use_rust) is True + assert binary_search(arr, 5, use_rust=use_rust) is True + # Element not present + assert binary_search(arr, 4, use_rust=use_rust) is False + assert binary_search(arr, 0, use_rust=use_rust) is False + assert binary_search(arr, 10, use_rust=use_rust) is False + + +def test_equal_range_std(): + for use_rust in [True, False]: + arr = [1, 2, 2, 2, 3, 5, 8] + assert equal_range(arr, 2, use_rust=use_rust) == (1, 4) + assert equal_range(arr, 3, use_rust=use_rust) == (4, 5) + assert equal_range(arr, 4, use_rust=use_rust) == (5, 5) + assert equal_range(arr, 0, use_rust=use_rust) == (0, 0) + assert equal_range(arr, 10, use_rust=use_rust) == (len(arr), len(arr)) + + +def test_empty_and_single(): + for use_rust in [True, False]: + # Empty list + arr = [] + assert lower_bound(arr, 5, use_rust=use_rust) == 0 + assert upper_bound(arr, 5, use_rust=use_rust) == 0 + assert binary_search(arr, 5, use_rust=use_rust) is False + assert equal_range(arr, 5, use_rust=use_rust) == (0, 0) + + # Single element + arr = [5] + assert lower_bound(arr, 3, use_rust=use_rust) == 0 + assert lower_bound(arr, 5, use_rust=use_rust) == 0 + assert lower_bound(arr, 7, use_rust=use_rust) == 1 + + assert upper_bound(arr, 3, use_rust=use_rust) == 0 + assert upper_bound(arr, 5, use_rust=use_rust) == 1 + assert upper_bound(arr, 7, use_rust=use_rust) == 1 + + assert binary_search(arr, 5, use_rust=use_rust) is True + assert binary_search(arr, 3, use_rust=use_rust) is False + assert equal_range(arr, 5, use_rust=use_rust) == (0, 1) + + +def test_custom_comparator(): + # Comparator returns True if element < val + def item_comp(item1, item2): + return item1.val < item2.val + + for use_rust in [True, False]: + arr = [Item(1), Item(2), Item(2), Item(3)] + val = Item(2) + + assert lower_bound(arr, val, comp=item_comp, use_rust=use_rust) == 1 + assert upper_bound(arr, val, comp=item_comp, use_rust=use_rust) == 3 + 
assert binary_search(arr, val, comp=item_comp, use_rust=use_rust) is True + assert equal_range(arr, val, comp=item_comp, use_rust=use_rust) == (1, 3) + + # Element not present + val_not = Item(4) + assert lower_bound(arr, val_not, comp=item_comp, use_rust=use_rust) == 4 + assert binary_search(arr, val_not, comp=item_comp, use_rust=use_rust) is False