diff --git a/README.md b/README.md
index 6e446ac..e305c70 100644
--- a/README.md
+++ b/README.md
@@ -320,20 +320,20 @@ Full Python integration while maintaining STL compatibility:
- Copy protocol support
- Maintains backward compatibility
-## 📊 Performance Benchmarks
+## Performance Benchmarks
PythonSTL includes a compiled Rust backend (built with PyO3 and Maturin) for high-performance operations, alongside pure-Python fallbacks. Below are the actual performance comparison results against pure-Python and native C++ (compiled with `g++ -O3`).
-### 1. Containers Performance (50,000 Operations)
+### 1. Containers Performance Benchmarks (3-Way Comparison)
-| Container Class | Pure Python | Python + Rust | Speedup Status | Design / Algorithmic Trade-off |
-| :--- | :--- | :--- | :--- | :--- |
-| **Stack** | 0.2324s | 0.1581s | **1.47x faster** | Linear stack operations. Limited by FFI call overhead. |
-| **Queue** | 0.2428s | 0.1608s | **1.51x faster** | FIFO operations. Limited by FFI call overhead. |
-| **Vector** | 0.0041s | 0.0034s | **1.20x faster** | Push_back & random access indices. Limited by FFI. |
-| **Set** | 0.0216s | 0.1111s | *0.19x faster* | **Sorted Set vs Unordered Hash Set** (replicates C++ B-Tree structure) |
-| **Map** | 0.0389s | 0.1959s | *0.20x faster* | **Sorted Map vs Unordered Hash Map** (replicates C++ B-Tree structure) |
-| **Priority Queue**| 0.0764s | 0.0959s | *0.80x faster* | Custom binary heap vs. C-optimized `heapq` module. |
+| Container Class | Pure Python (STL) | Python + Rust (STL) | Native Built-in | Rust Speedup | Design / Algorithmic Trade-off |
+| :--- | :--- | :--- | :--- | :--- | :--- |
+| **Stack** | 0.2441s | 0.2178s | 0.0667s | **1.12x faster** | Linear stack operations. Limited by FFI call overhead. |
+| **Queue** | 0.2445s | 0.2078s | 0.0520s | **1.18x faster** | FIFO operations. Limited by FFI call overhead. |
+| **Vector** | 0.0065s | 0.0038s | 0.0015s | **1.70x faster** | Push_back & random access indices. Limited by FFI. |
+| **Set** | 0.1572s | 0.0197s | 0.0014s | **8.00x faster** | AVL Tree (Python) vs. BTree (Rust) vs. Unordered Hash Set (Native). |
+| **Map** | 0.1632s | 0.0347s | 0.0020s | **4.70x faster** | AVL Tree (Python) vs. BTree (Rust) vs. Unordered Hash Map (Native). |
+| **Priority Queue**| 0.0238s | 0.0371s | 0.0054s | *0.64x (slower)* | Custom binary heap vs. C-optimized `heapq` module. |
* **Sorted Trees vs. Hash Tables**: Python's native `set` and `dict` are highly optimized $O(1)$ hash tables written in C. PythonSTL sets/maps replicate C++'s `std::set`/`std::map` using sorted trees (`BTreeSet`/`BTreeMap`), which run in $O(\log N)$ and sort keys.
* **FFI overhead**: Storing arbitrary Python objects in Rust requires acquiring the GIL and calling back into the Python VM for comparisons, creating high FFI boundaries.
@@ -388,7 +388,7 @@ pytest tests/
pytest tests/ --cov=pythonstl --cov-report=html
```
-## 🛠️ Development
+## Development
### Setup
@@ -411,15 +411,22 @@ flake8 pythonstl/
pytest && mypy pythonstl/ && flake8 pythonstl/
```
-## Note
-➡️ The goal is NOT to replace Python built-ins.
-➡️ The goal is to provide: 1) Conceptual clarity 2) STL familiarity for C++ developers 3) A structured learning bridge for DSA
+## Myths & Common Misconceptions
+
+### Myth 1: "This library has no actual use because online platforms (LeetCode, Codeforces) don't support external libraries."
+* **Reality:** The premise is true: you cannot import external packages during live programming contests. But `PythonSTL` is **not** meant for live contests. It is a local prototyping, learning, and transition tool. C++ developers moving to Python can use it locally to adapt their mental model of STL data structures to Python's syntax, and to practice data-structure design during mock interviews.
+
+### Myth 2: "Python's native structures are better and faster, so the Rust backend is unnecessary over-engineering."
+* **Reality:** While Python's built-ins are great, Python actually lacks native equivalents for some core STL behaviors:
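+ - **No Sorted Set/Map:** Python's built-in `set` and `dict` are hash-based: `set` has no defined order, and `dict` preserves insertion order rather than key order. To maintain a sorted collection with built-ins alone, you'd have to re-sort after updates ($\mathcal{O}(n \log n)$ per sort) or insert into a sorted list with `bisect.insort` ($\mathcal{O}(n)$ per insertion). `PythonSTL`'s Rust backend provides $\mathcal{O}(\log n)$ sorted containers backed by `BTreeSet` and `BTreeMap` (and the pure-Python fallback uses a balanced AVL Tree).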
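+ - **No Customizable Priority Queue:** Python's `heapq` is a min-heap by default, and `heappush`/`heappop` accept no key or comparator argument, so max-heaps and custom orderings need workarounds such as negated keys or wrapper objects. `PythonSTL` provides max/min heaps and custom sorting keys out of the box.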
+ - **Engineering Showcase:** The Rust backend built via Maturin and PyO3 demonstrates a hybrid performance architecture. In real-world projects (like Polars, Pydantic, or cryptography libraries), performance-critical loops are written in compiled languages and bound to Python. This library serves as an educational blueprint for that pattern.
-## 📝 License
+## License
MIT License - see LICENSE file for details.
-## 🤝 Contributing
+## Contributing
Contributions are welcome! Please:
1. Fork the repository
@@ -433,5 +440,6 @@ Contributions are welcome! Please:
- GitHub: [@AnshMNSoni](https://github.com/AnshMNSoni)
- Issues: [GitHub Issues](https://github.com/AnshMNSoni/PythonSTL/issues)
+- LinkedIn: [@anshmnsoni](https://linkedin.com/in/anshmnsoni)
-**PythonSTL v1.1.6** - Bringing C++ STL elegance to Python
\ No newline at end of file
+**PythonSTL v1.1.7** - Bringing C++ STL elegance to Python
\ No newline at end of file
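Reviewer note on the README's Myth 2: the gaps it describes in the built-ins can be illustrated with the standard library alone. The sketch below is not PythonSTL code; `MaxHeap` and `insert_sorted` are hypothetical names chosen for this note. It shows the usual workarounds: negated keys for a max-heap on top of `heapq`, and `bisect` for keeping a plain list sorted.

```python
import bisect
import heapq


class MaxHeap:
    """Max-heap built on heapq by storing negated numeric keys."""

    def __init__(self):
        self._items = []

    def push(self, value):
        heapq.heappush(self._items, -value)

    def top(self):
        return -self._items[0]

    def pop(self):
        return -heapq.heappop(self._items)

    def __len__(self):
        return len(self._items)


def insert_sorted(items, value):
    """Insert value into an already-sorted list, skipping duplicates (set-like)."""
    idx = bisect.bisect_left(items, value)
    if idx == len(items) or items[idx] != value:
        items.insert(idx, value)  # O(n) element shift, unlike a tree's O(log n)


heap = MaxHeap()
for x in (3, 10, 7):
    heap.push(x)
largest = heap.pop()  # removes the largest value pushed so far

keys = []
for x in (30, 10, 20, 10):
    insert_sorted(keys, x)  # keys stays sorted and duplicate-free
```

The negation trick only works for numeric keys, and the sorted list pays a linear shift on every insertion; those are the costs a dedicated sorted container or configurable heap is meant to avoid.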
diff --git a/benchmarks/benchmark_all_structures.py b/benchmarks/benchmark_all_structures.py
index ab4df55..e7428f8 100644
--- a/benchmarks/benchmark_all_structures.py
+++ b/benchmarks/benchmark_all_structures.py
@@ -1,6 +1,12 @@
+"""
+Performance benchmark comparing the Rust backend, pure-Python fallbacks, and native Python built-in data structures.
+"""
+
import time
import sys
import gc
+import heapq
+from collections import deque
from pathlib import Path
# Add project root to path to run directly from development folder
@@ -14,15 +20,17 @@
from pythonstl.facade.map import RUST_AVAILABLE as MAP_RUST_AVAILABLE
from pythonstl.facade.priority_queue import RUST_AVAILABLE as PQ_RUST_AVAILABLE
-def run_benchmark(name, py_func, rust_func, has_rust):
+
+def run_benchmark(name, py_func, rust_func, has_rust, native_func):
print(f"Benchmarking {name}...")
- # Run Python benchmark
+ # 1. Run Pure Python benchmark
gc.collect()
start = time.perf_counter()
py_func()
py_time = time.perf_counter() - start
+ # 2. Run Rust benchmark
rust_time = None
if has_rust:
gc.collect()
@@ -30,10 +38,18 @@ def run_benchmark(name, py_func, rust_func, has_rust):
rust_func()
rust_time = time.perf_counter() - start
- return py_time, rust_time
+ # 3. Run Native Python built-in benchmark
+ gc.collect()
+ start = time.perf_counter()
+ native_func()
+ native_time = time.perf_counter() - start
+
+ return py_time, rust_time, native_time
-# ----------------- Benchmark Workloads -----------------
+# ==========================================
+# 1. STACK WORKLOADS
+# ==========================================
def bench_stack_py():
s = stack(use_rust=False)
for i in range(500000):
@@ -41,6 +57,7 @@ def bench_stack_py():
for _ in range(500000):
s.pop()
+
def bench_stack_rust():
s = stack(use_rust=True)
for i in range(500000):
@@ -48,6 +65,18 @@ def bench_stack_rust():
for _ in range(500000):
s.pop()
+
+def bench_stack_native():
+ s = []
+ for i in range(500000):
+ s.append(i)
+ for _ in range(500000):
+ s.pop()
+
+
+# ==========================================
+# 2. QUEUE WORKLOADS
+# ==========================================
def bench_queue_py():
q = queue(use_rust=False)
for i in range(500000):
@@ -55,6 +84,7 @@ def bench_queue_py():
for _ in range(500000):
q.pop()
+
def bench_queue_rust():
q = queue(use_rust=True)
for i in range(500000):
@@ -62,107 +92,154 @@ def bench_queue_rust():
for _ in range(500000):
q.pop()
+
+def bench_queue_native():
+ q = deque()
+ for i in range(500000):
+ q.append(i)
+ for _ in range(500000):
+ q.popleft()
+
+
+# ==========================================
+# 3. VECTOR WORKLOADS
+# ==========================================
def bench_vector_py():
v = vector(use_rust=False)
- # 1. Push back 10,000 items
for i in range(10000):
v.push_back(i)
- # 2. Access via at()
for i in range(10000):
_ = v.at(i)
- # 3. In-place insertions
for i in range(100):
v.insert(5000, i)
+
def bench_vector_rust():
v = vector(use_rust=True)
- # 1. Push back 10,000 items
for i in range(10000):
v.push_back(i)
- # 2. Access via at()
for i in range(10000):
_ = v.at(i)
- # 3. In-place insertions
for i in range(100):
v.insert(5000, i)
+
+def bench_vector_native():
+ v = []
+ for i in range(10000):
+ v.append(i)
+ for i in range(10000):
+ _ = v[i]
+ for i in range(100):
+ v.insert(5000, i)
+
+
+# ==========================================
+# 4. SET WORKLOADS
+# ==========================================
def bench_set_py():
s = stl_set(use_rust=False)
- # 1. Insertions (50,000)
- for i in range(50000):
+ for i in range(10000):
s.insert(i)
- # 2. Lookup/Find (50,000)
- for i in range(50000):
+ for i in range(10000):
_ = s.find(i)
- # 3. Erasures (50,000)
- for i in range(50000):
+ for i in range(10000):
s.erase(i)
+
def bench_set_rust():
s = stl_set(use_rust=True)
- # 1. Insertions (50,000)
- for i in range(50000):
+ for i in range(10000):
s.insert(i)
- # 2. Lookup/Find (50,000)
- for i in range(50000):
+ for i in range(10000):
_ = s.find(i)
- # 3. Erasures (50,000)
- for i in range(50000):
+ for i in range(10000):
s.erase(i)
+
+def bench_set_native():
+ s = set()
+ for i in range(10000):
+ s.add(i)
+ for i in range(10000):
+ _ = i in s
+ for i in range(10000):
+ s.discard(i)
+
+
+# ==========================================
+# 5. MAP WORKLOADS
+# ==========================================
def bench_map_py():
m = stl_map(use_rust=False)
- # 1. Insertions (50,000)
- for i in range(50000):
+ for i in range(10000):
m.insert(i, i * 2)
- # 2. Lookup/Find (50,000)
- for i in range(50000):
+ for i in range(10000):
_ = m.find(i)
- # 3. Access via at()
- for i in range(50000):
+ for i in range(10000):
_ = m.at(i)
- # 4. Erasures (50,000)
- for i in range(50000):
+ for i in range(10000):
m.erase(i)
+
def bench_map_rust():
m = stl_map(use_rust=True)
- # 1. Insertions (50,000)
- for i in range(50000):
+ for i in range(10000):
m.insert(i, i * 2)
- # 2. Lookup/Find (50,000)
- for i in range(50000):
+ for i in range(10000):
_ = m.find(i)
- # 3. Access via at()
- for i in range(50000):
+ for i in range(10000):
_ = m.at(i)
- # 4. Erasures (50,000)
- for i in range(50000):
+ for i in range(10000):
m.erase(i)
+
+def bench_map_native():
+ m = {}
+ for i in range(10000):
+ m[i] = i * 2
+ for i in range(10000):
+ _ = i in m
+ for i in range(10000):
+ _ = m[i]
+ for i in range(10000):
+ m.pop(i, None)
+
+
+# ==========================================
+# 6. PRIORITY QUEUE WORKLOADS
+# ==========================================
def bench_priority_queue_py():
pq = priority_queue(use_rust=False)
- # 1. Pushes (50,000)
- for i in range(50000):
+ for i in range(20000):
pq.push(i)
- # 2. Pops (50,000)
- for _ in range(50000):
+ for _ in range(20000):
_ = pq.top()
pq.pop()
+
def bench_priority_queue_rust():
pq = priority_queue(use_rust=True)
- # 1. Pushes (50,000)
- for i in range(50000):
+ for i in range(20000):
pq.push(i)
- # 2. Pops (50,000)
- for _ in range(50000):
+ for _ in range(20000):
_ = pq.top()
pq.pop()
-# ----------------- Main Execution -----------------
+def bench_priority_queue_native():
+ pq = []
+ for i in range(20000):
+ heapq.heappush(pq, i)
+ for _ in range(20000):
+ if pq:
+ _ = pq[0]
+ heapq.heappop(pq)
+
+# ==========================================
+# MAIN EXECUTION
+# ==========================================
def main():
print("=============================================================")
print(" PythonSTL Comprehensive Container Benchmark Suite ")
@@ -171,37 +248,38 @@ def main():
results = {}
# 1. Stack
- py_t, rust_t = run_benchmark("Stack (500,000 cycles)", bench_stack_py, bench_stack_rust, STACK_RUST_AVAILABLE)
- results["Stack"] = (py_t, rust_t, STACK_RUST_AVAILABLE)
+ py_t, rust_t, nat_t = run_benchmark("Stack (500,000 cycles)", bench_stack_py, bench_stack_rust, STACK_RUST_AVAILABLE, bench_stack_native)
+ results["Stack"] = (py_t, rust_t, nat_t, STACK_RUST_AVAILABLE)
# 2. Queue
- py_t, rust_t = run_benchmark("Queue (500,000 cycles)", bench_queue_py, bench_queue_rust, QUEUE_RUST_AVAILABLE)
- results["Queue"] = (py_t, rust_t, QUEUE_RUST_AVAILABLE)
+ py_t, rust_t, nat_t = run_benchmark("Queue (500,000 cycles)", bench_queue_py, bench_queue_rust, QUEUE_RUST_AVAILABLE, bench_queue_native)
+ results["Queue"] = (py_t, rust_t, nat_t, QUEUE_RUST_AVAILABLE)
# 3. Vector
- py_t, rust_t = run_benchmark("Vector (10,000 push/access + 100 inserts)", bench_vector_py, bench_vector_rust, VECTOR_RUST_AVAILABLE)
- results["Vector"] = (py_t, rust_t, VECTOR_RUST_AVAILABLE)
+ py_t, rust_t, nat_t = run_benchmark("Vector (10,000 push/access + 100 inserts)", bench_vector_py, bench_vector_rust, VECTOR_RUST_AVAILABLE, bench_vector_native)
+ results["Vector"] = (py_t, rust_t, nat_t, VECTOR_RUST_AVAILABLE)
# 4. Set
- py_t, rust_t = run_benchmark("Set (50,000 inserts/finds/erases)", bench_set_py, bench_set_rust, SET_RUST_AVAILABLE)
- results["Set"] = (py_t, rust_t, SET_RUST_AVAILABLE)
+ py_t, rust_t, nat_t = run_benchmark("Set (10,000 inserts/finds/erases)", bench_set_py, bench_set_rust, SET_RUST_AVAILABLE, bench_set_native)
+ results["Set"] = (py_t, rust_t, nat_t, SET_RUST_AVAILABLE)
# 5. Map
- py_t, rust_t = run_benchmark("Map (50,000 inserts/finds/ats/erases)", bench_map_py, bench_map_rust, MAP_RUST_AVAILABLE)
- results["Map"] = (py_t, rust_t, MAP_RUST_AVAILABLE)
+ py_t, rust_t, nat_t = run_benchmark("Map (10,000 inserts/finds/ats/erases)", bench_map_py, bench_map_rust, MAP_RUST_AVAILABLE, bench_map_native)
+ results["Map"] = (py_t, rust_t, nat_t, MAP_RUST_AVAILABLE)
# 6. Priority Queue
- py_t, rust_t = run_benchmark("Priority Queue (50,000 push/pops)", bench_priority_queue_py, bench_priority_queue_rust, PQ_RUST_AVAILABLE)
- results["Priority Queue"] = (py_t, rust_t, PQ_RUST_AVAILABLE)
+ py_t, rust_t, nat_t = run_benchmark("Priority Queue (20,000 push/pops)", bench_priority_queue_py, bench_priority_queue_rust, PQ_RUST_AVAILABLE, bench_priority_queue_native)
+ results["Priority Queue"] = (py_t, rust_t, nat_t, PQ_RUST_AVAILABLE)
- print("\n" + "=" * 70)
- print(" PERFORMANCE SUMMARY TABLE ")
- print("=" * 70)
- print(f"{'Container Class':<18} | {'Pure Python':<12} | {'Python + Rust':<15} | {'Speedup Status':<18}")
- print("-" * 70)
+ print("\n" + "=" * 90)
+ print(" PERFORMANCE SUMMARY TABLE ")
+ print("=" * 90)
+ print(f"{'Container Class':<18} | {'Pure Python':<12} | {'Python + Rust':<15} | {'Native Built-in':<17} | {'Rust Speedup vs Py':<18}")
+ print("-" * 90)
- for container, (py_time, rust_time, is_rust) in results.items():
+ for container, (py_time, rust_time, native_time, is_rust) in results.items():
py_str = f"{py_time:.4f}s"
+ native_str = f"{native_time:.4f}s"
if is_rust and rust_time is not None:
rust_str = f"{rust_time:.4f}s"
speedup = py_time / rust_time
@@ -210,12 +288,14 @@ def main():
rust_str = "N/A"
status = "Pure Py Fallback"
- print(f"{container:<18} | {py_str:<12} | {rust_str:<15} | {status:<18}")
+ print(f"{container:<18} | {py_str:<12} | {rust_str:<15} | {native_str:<17} | {status:<18}")
- print("=============================================================")
- print("Note: Containers marked 'Pure Py Fallback' will run using the")
- print("original python backends until their Rust cores are built.")
- print("=============================================================")
+ print("=========================================================================================")
+ print("Note:")
+ print("1. 'Pure Python' set/map now run on AVL Trees (sorted); the Rust backend uses BTreeSet/BTreeMap.")
+ print("2. 'Native Built-ins' (hash tables) are unsorted and run in O(1) average-case time per operation.")
+ print("=========================================================================================")
+
if __name__ == "__main__":
main()
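Reviewer note on the benchmark harness: `run_benchmark` times each workload once, so a single scheduling hiccup can shift a table entry. A minimal sketch of a steadier measurement follows, assuming it is acceptable to run each workload several times and keep the best result. `time_workload`, `describe_speedup`, and `sample_workload` are hypothetical helpers written for this note, not part of the script.

```python
import gc
import timeit


def time_workload(func, repeats=5):
    """Return the best wall-clock time in seconds over `repeats` runs of func."""
    gc.collect()
    # timeit disables garbage collection while timing, which reduces noise.
    return min(timeit.repeat(func, repeat=repeats, number=1))


def describe_speedup(baseline, candidate):
    """Format a ratio so that values below 1.0 read as 'slower'."""
    ratio = baseline / candidate
    if ratio >= 1.0:
        return f"{ratio:.2f}x faster"
    return f"{1.0 / ratio:.2f}x slower"


def sample_workload():
    # Same push/pop shape as the native stack benchmark, at a smaller size.
    data = []
    for i in range(1000):
        data.append(i)
    for _ in range(1000):
        data.pop()


best = time_workload(sample_workload)
```

Taking the minimum over repeats is a common convention for micro-benchmarks because noise from the rest of the system only ever adds time.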
diff --git a/pyproject.toml b/pyproject.toml
index 43f2086..314c1d0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ module-name = "pythonstl._rust"
[project]
name = "pythonstl"
-version = "1.1.6"
+version = "1.1.7"
description = "C++ STL-style containers implemented in Python using the Facade Design Pattern"
readme = "README.md"
authors = [
diff --git a/pythonstl/__init__.py b/pythonstl/__init__.py
index 7986fb1..8485f8f 100644
--- a/pythonstl/__init__.py
+++ b/pythonstl/__init__.py
@@ -8,7 +8,7 @@
data structures while hiding implementation details from users.
"""
-__version__ = "1.1.6"
+__version__ = "1.1.7"
__author__ = "PySTL Contributors"
from pythonstl.facade.stack import stack
diff --git a/pythonstl/core/avl_tree.py b/pythonstl/core/avl_tree.py
new file mode 100644
index 0000000..b4ed4c2
--- /dev/null
+++ b/pythonstl/core/avl_tree.py
@@ -0,0 +1,286 @@
+"""
+AVL Tree implementation for PythonSTL.
+
+This module provides a pure-Python self-balancing Binary Search Tree (AVL Tree)
+to ensure sorted order and O(log n) operation complexity for Python fallbacks of
+associative containers (set and map), matching C++ STL and Rust BTree semantics.
+"""
+
+from typing import TypeVar, Generic, Optional, Generator, Tuple, Any
+
+K = TypeVar('K')
+V = TypeVar('V')
+
+
+class AVLNode(Generic[K, V]):
+ """A node in the AVL Tree."""
+ __slots__ = ('key', 'value', 'left', 'right', 'height')
+
+ def __init__(self, key: K, value: Optional[V] = None) -> None:
+ self.key = key
+ self.value = value
+ self.left: Optional[AVLNode[K, V]] = None
+ self.right: Optional[AVLNode[K, V]] = None
+ self.height: int = 1
+
+
+class AVLTree(Generic[K, V]):
+ """
+ A self-balancing binary search tree (AVL Tree) implementing dict/set-like operations.
+ """
+
+ def __init__(self) -> None:
+ self.root: Optional[AVLNode[K, V]] = None
+ self._size: int = 0
+
+ def __len__(self) -> int:
+ """Return the number of nodes in the tree."""
+ return self._size
+
+ def __contains__(self, key: K) -> bool:
+ """Check if a key exists in the tree."""
+ return self._find(self.root, key) is not None
+
+ def _find(self, node: Optional[AVLNode[K, V]], key: K) -> Optional[AVLNode[K, V]]:
+ curr = node
+ while curr:
+ if key == curr.key:
+ return curr
+ elif key < curr.key:
+ curr = curr.left
+ else:
+ curr = curr.right
+ return None
+
+ def __getitem__(self, key: K) -> V:
+ """Get the value associated with the key."""
+ node = self._find(self.root, key)
+ if node is None:
+ raise KeyError(key)
+ return node.value
+
+ def _get_height(self, node: Optional[AVLNode[K, V]]) -> int:
+ return node.height if node else 0
+
+ def _get_balance(self, node: Optional[AVLNode[K, V]]) -> int:
+ if not node:
+ return 0
+ return self._get_height(node.left) - self._get_height(node.right)
+
+ def _right_rotate(self, y: AVLNode[K, V]) -> AVLNode[K, V]:
+ x = y.left
+ assert x is not None
+ T2 = x.right
+
+ # Perform rotation
+ x.right = y
+ y.left = T2
+
+ # Update heights
+ y.height = max(self._get_height(y.left), self._get_height(y.right)) + 1
+ x.height = max(self._get_height(x.left), self._get_height(x.right)) + 1
+
+ return x
+
+ def _left_rotate(self, x: AVLNode[K, V]) -> AVLNode[K, V]:
+ y = x.right
+ assert y is not None
+ T2 = y.left
+
+ # Perform rotation
+ y.left = x
+ x.right = T2
+
+ # Update heights
+ x.height = max(self._get_height(x.left), self._get_height(x.right)) + 1
+ y.height = max(self._get_height(y.left), self._get_height(y.right)) + 1
+
+ return y
+
+ def __setitem__(self, key: K, value: V) -> None:
+ """Insert or update a key-value pair in the tree."""
+ self.root = self._insert(self.root, key, value)
+
+ def _insert(self, node: Optional[AVLNode[K, V]], key: K, value: V) -> AVLNode[K, V]:
+ # 1. Standard BST insertion
+ if not node:
+ self._size += 1
+ return AVLNode(key, value)
+
+ if key < node.key:
+ node.left = self._insert(node.left, key, value)
+ elif key > node.key:
+ node.right = self._insert(node.right, key, value)
+ else:
+ # Key already exists, update value and return (no size change or rebalancing needed)
+ node.value = value
+ return node
+
+ # 2. Update height of this ancestor node
+ node.height = max(self._get_height(node.left), self._get_height(node.right)) + 1
+
+ # 3. Get the balance factor
+ balance = self._get_balance(node)
+
+ # Left Left Case
+ if balance > 1 and node.left and key < node.left.key:
+ return self._right_rotate(node)
+
+ # Right Right Case
+ if balance < -1 and node.right and key > node.right.key:
+ return self._left_rotate(node)
+
+ # Left Right Case
+ if balance > 1 and node.left and key > node.left.key:
+ node.left = self._left_rotate(node.left)
+ return self._right_rotate(node)
+
+ # Right Left Case
+ if balance < -1 and node.right and key < node.right.key:
+ node.right = self._right_rotate(node.right)
+ return self._left_rotate(node)
+
+ return node
+
+ def add(self, key: K) -> None:
+ """Add a key to the tree (value defaults to None)."""
+ self[key] = None
+
+ def pop(self, key: K, default: Any = KeyError) -> Any:
+ """Remove key and return the associated value."""
+ node = self._find(self.root, key)
+ if node is None:
+ if default is not KeyError:
+ return default
+ raise KeyError(key)
+ val = node.value
+ self.root = self._delete(self.root, key)
+ return val
+
+ def discard(self, key: K) -> None:
+ """Remove key from the tree if it exists, otherwise do nothing."""
+ if key in self:
+ self.root = self._delete(self.root, key)
+
+ def _min_value_node(self, node: AVLNode[K, V]) -> AVLNode[K, V]:
+ current = node
+ while current.left:
+ current = current.left
+ return current
+
+ def _delete(self, node: Optional[AVLNode[K, V]], key: K) -> Optional[AVLNode[K, V]]:
+ if not node:
+ return node
+
+ if key < node.key:
+ node.left = self._delete(node.left, key)
+ elif key > node.key:
+ node.right = self._delete(node.right, key)
+ else:
+ # Node with only one child or no child
+ if not node.left:
+ temp = node.right
+ self._size -= 1
+ return temp
+ elif not node.right:
+ temp = node.left
+ self._size -= 1
+ return temp
+
+ # Node with two children: Get the inorder successor
+ temp = self._min_value_node(node.right)
+ node.key = temp.key
+ node.value = temp.value
+ node.right = self._delete(node.right, temp.key)
+
+ if not node:
+ return node
+
+ # Update height
+ node.height = max(self._get_height(node.left), self._get_height(node.right)) + 1
+
+ # Get balance factor
+ balance = self._get_balance(node)
+
+ # Left Left Case
+ if balance > 1 and self._get_balance(node.left) >= 0:
+ return self._right_rotate(node)
+
+ # Left Right Case
+ if balance > 1 and self._get_balance(node.left) < 0:
+ node.left = self._left_rotate(node.left)
+ return self._right_rotate(node)
+
+ # Right Right Case
+ if balance < -1 and self._get_balance(node.right) <= 0:
+ return self._left_rotate(node)
+
+ # Right Left Case
+ if balance < -1 and self._get_balance(node.right) > 0:
+ node.right = self._right_rotate(node.right)
+ return self._left_rotate(node)
+
+ return node
+
+ def __iter__(self) -> Generator[K, None, None]:
+ """In-order traversal yielding keys."""
+ yield from self._inorder(self.root)
+
+ def _inorder(self, node: Optional[AVLNode[K, V]]) -> Generator[K, None, None]:
+ if node:
+ yield from self._inorder(node.left)
+ yield node.key
+ yield from self._inorder(node.right)
+
+ def items(self) -> Generator[Tuple[K, V], None, None]:
+ """In-order traversal yielding (key, value) pairs."""
+ yield from self._inorder_items(self.root)
+
+ def _inorder_items(self, node: Optional[AVLNode[K, V]]) -> Generator[Tuple[K, V], None, None]:
+ if node:
+ yield from self._inorder_items(node.left)
+ yield (node.key, node.value)
+ yield from self._inorder_items(node.right)
+
+ def keys(self) -> Generator[K, None, None]:
+ """In-order traversal yielding keys."""
+ yield from self
+
+ def values(self) -> Generator[V, None, None]:
+ """In-order traversal yielding values."""
+ yield from self._inorder_values(self.root)
+
+ def _inorder_values(self, node: Optional[AVLNode[K, V]]) -> Generator[V, None, None]:
+ if node:
+ yield from self._inorder_values(node.left)
+ yield node.value
+ yield from self._inorder_values(node.right)
+
+ def copy(self) -> 'AVLTree[K, V]':
+ """Create a copy of the AVL tree structure."""
+ new_tree = AVLTree[K, V]()
+ new_tree.root = self._copy_node(self.root)
+ new_tree._size = self._size
+ return new_tree
+
+ def _copy_node(self, node: Optional[AVLNode[K, V]]) -> Optional[AVLNode[K, V]]:
+ if not node:
+ return None
+ new_node = AVLNode(node.key, node.value)
+ new_node.height = node.height
+ new_node.left = self._copy_node(node.left)
+ new_node.right = self._copy_node(node.right)
+ return new_node
+
+ def __eq__(self, other: object) -> bool:
+ """Check equality with another AVL tree or dict (based on element order/keys)."""
+ if isinstance(other, dict):
+ return dict(self.items()) == other
+ if not isinstance(other, AVLTree):
+ return False
+ if self._size != other._size:
+ return False
+ return list(self.items()) == list(other.items())
+
+ def __repr__(self) -> str:
+ return f"AVLTree({list(self.items())})"
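Reviewer note on `avl_tree.py`: the rotations are easy to get subtly wrong, so an invariant checker is a cheap safety net. The sketch below is self-contained and uses a stand-in `Node` class. `check_avl` is a hypothetical helper, not part of the module. Because `AVLNode` in this diff exposes the same `key`, `left`, and `right` attributes, the function should also accept `tree.root`, though that has not been run against the module here.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Node:
    """Stand-in for a tree node; only key/left/right are inspected."""
    key: Any
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def check_avl(node, low=None, high=None):
    """Return the subtree height, raising AssertionError on any violation.

    low/high are exclusive bounds inherited from ancestors (None = unbounded).
    """
    if node is None:
        return 0
    assert low is None or low < node.key, f"{node.key!r} breaks BST order"
    assert high is None or node.key < high, f"{node.key!r} breaks BST order"
    left_h = check_avl(node.left, low, node.key)
    right_h = check_avl(node.right, node.key, high)
    assert abs(left_h - right_h) <= 1, f"unbalanced at {node.key!r}"
    return 1 + max(left_h, right_h)


balanced = Node(20, Node(10), Node(30))
chain = Node(10, right=Node(20, right=Node(30)))  # degenerate right chain
```

Calling `check_avl(balanced)` returns the height, while `check_avl(chain)` raises because the root's subtrees differ in height by two. Running such a check after every insert and delete in a test would catch a missed rebalancing case.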
diff --git a/pythonstl/implementations/associative/_map_impl.py b/pythonstl/implementations/associative/_map_impl.py
index e49b9e6..4ad4681 100644
--- a/pythonstl/implementations/associative/_map_impl.py
+++ b/pythonstl/implementations/associative/_map_impl.py
@@ -8,6 +8,7 @@
from typing import TypeVar, Dict
from pythonstl.core.exceptions import KeyNotFoundError
from pythonstl.core.iterator import MapIterator
+from pythonstl.core.avl_tree import AVLTree
K = TypeVar('K')
V = TypeVar('V')
@@ -15,7 +16,7 @@
class _MapImpl:
"""
- Internal implementation of a map using Python's built-in dict.
+ Internal implementation of a map using an AVL Tree.
This class should not be accessed directly by users.
Use the facade class `stl_map` instead.
@@ -28,7 +29,7 @@ def __init__(self) -> None:
Time Complexity:
O(1)
"""
- self._data: Dict[K, V] = {}
+ self._data: AVLTree[K, V] = AVLTree()
def insert(self, key: K, value: V) -> None:
"""
@@ -42,7 +43,7 @@ def insert(self, key: K, value: V) -> None:
If the key already exists, the value is updated.
Time Complexity:
- O(1) average case
+ O(log n)
"""
self._data[key] = value
@@ -57,9 +58,9 @@ def erase(self, key: K) -> None:
Does nothing if the key is not present (matches C++ STL behavior).
Time Complexity:
- O(1) average case
+ O(log n)
"""
- self._data.pop(key, None)
+ self._data.discard(key)
def find(self, key: K) -> bool:
"""
@@ -72,7 +73,7 @@ def find(self, key: K) -> bool:
True if the key exists, False otherwise.
Time Complexity:
- O(1) average case
+ O(log n)
"""
return key in self._data
@@ -90,11 +91,12 @@ def at(self, key: K) -> V:
KeyNotFoundError: If the key does not exist.
Time Complexity:
- O(1) average case
+ O(log n)
"""
- if key not in self._data:
+ try:
+ return self._data[key]
+ except KeyError:
raise KeyNotFoundError(key)
- return self._data[key]
def empty(self) -> bool:
"""
@@ -136,8 +138,6 @@ def end(self) -> MapIterator:
"""
Get iterator to the end of the map.
- Note: In Python dicts, end() returns an exhausted iterator.
-
Returns:
Iterator pointing past the last key-value pair.
@@ -150,15 +150,15 @@ def end(self) -> MapIterator:
def get_data(self) -> Dict[K, V]:
"""
- Get a copy of the internal data for iteration.
+ Get a copy of the internal data as a sorted dictionary.
Returns:
- Copy of the internal data dict.
+ Sorted copy of the internal data dict.
Time Complexity:
O(n) where n is the number of key-value pairs
"""
- return self._data.copy()
+ return dict(self._data.items())
__all__ = ['_MapImpl']
diff --git a/pythonstl/implementations/associative/_set_impl.py b/pythonstl/implementations/associative/_set_impl.py
index a5ff42a..ae7be9e 100644
--- a/pythonstl/implementations/associative/_set_impl.py
+++ b/pythonstl/implementations/associative/_set_impl.py
@@ -5,15 +5,16 @@
following C++ STL semantics. Users should not access this directly.
"""
-from typing import TypeVar, Set as PySet
+from typing import TypeVar, List
from pythonstl.core.iterator import SetIterator
+from pythonstl.core.avl_tree import AVLTree
T = TypeVar('T')
class _SetImpl:
"""
- Internal implementation of a set using Python's built-in set.
+ Internal implementation of a set using an AVL Tree.
This class should not be accessed directly by users.
Use the facade class `stl_set` instead.
@@ -26,7 +27,7 @@ def __init__(self) -> None:
Time Complexity:
O(1)
"""
- self._data: PySet[T] = set()
+ self._data: AVLTree[T, None] = AVLTree()
def insert(self, value: T) -> None:
"""
@@ -36,7 +37,7 @@ def insert(self, value: T) -> None:
value: The element to insert into the set.
Time Complexity:
- O(1) average case
+ O(log n)
"""
self._data.add(value)
@@ -51,7 +52,7 @@ def erase(self, value: T) -> None:
Does nothing if the element is not present (matches C++ STL behavior).
Time Complexity:
- O(1) average case
+ O(log n)
"""
self._data.discard(value)
@@ -66,7 +67,7 @@ def find(self, value: T) -> bool:
True if the element exists, False otherwise.
Time Complexity:
- O(1) average case
+ O(log n)
"""
return value in self._data
@@ -110,8 +111,6 @@ def end(self) -> SetIterator:
"""
Get iterator to the end of the set.
- Note: In Python sets, end() returns an exhausted iterator.
-
Returns:
Iterator pointing past the last element.
@@ -122,17 +121,17 @@ def end(self) -> SetIterator:
it = SetIterator(set())
return it
- def get_data(self) -> PySet[T]:
+ def get_data(self) -> List[T]:
"""
- Get a copy of the internal data for iteration.
+ Get a copy of the internal data as a sorted list.
Returns:
- Copy of the internal data set.
+ Sorted list of the internal elements.
Time Complexity:
O(n) where n is the number of elements
"""
- return self._data.copy()
+ return list(self._data)
__all__ = ['_SetImpl']
diff --git a/tests/test_map.py b/tests/test_map.py
index 1c6884c..6de4d30 100644
--- a/tests/test_map.py
+++ b/tests/test_map.py
@@ -109,3 +109,13 @@ def test_integer_keys(self):
assert m.at(1) == "one"
assert m.at(2) == "two"
assert m.find(3) is True
+
+ def test_sorted_order(self):
+ """Test that map elements are always sorted by key for both backends."""
+ for use_rust in [True, False]:
+ m = stl_map(use_rust=use_rust)
+ m.insert(30, "thirty")
+ m.insert(10, "ten")
+ m.insert(20, "twenty")
+ # Iterating should yield sorted key-value pairs
+ assert list(m) == [(10, "ten"), (20, "twenty"), (30, "thirty")]
diff --git a/tests/test_set.py b/tests/test_set.py
index 0921d85..7475fbb 100644
--- a/tests/test_set.py
+++ b/tests/test_set.py
@@ -95,3 +95,13 @@ def test_mixed_types(self):
assert s.find("banana") is True
s.erase("banana")
assert s.find("banana") is False
+
+ def test_sorted_order(self):
+ """Test that set elements are always sorted for both backends."""
+ for use_rust in [True, False]:
+ s = stl_set(use_rust=use_rust)
+ s.insert(30)
+ s.insert(10)
+ s.insert(20)
+ # Iterating should yield sorted elements
+ assert list(s) == [10, 20, 30]
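Reviewer note on the new tests: `test_sorted_order` inserts three keys and never erases, so it exercises few rotation paths. A randomized comparison against `sorted(set(...))` would cover far more. The sketch below assumes only the interface the tests already use (`insert`, `erase`, iteration). `BisectSet` is a hypothetical stand-in so the example runs without the package, and `check_against_model` is a hypothetical helper.

```python
import bisect
import random


class BisectSet:
    """Hypothetical stand-in exposing insert/erase/iteration like stl_set."""

    def __init__(self):
        self._items = []

    def insert(self, value):
        idx = bisect.bisect_left(self._items, value)
        if idx == len(self._items) or self._items[idx] != value:
            self._items.insert(idx, value)

    def erase(self, value):
        idx = bisect.bisect_left(self._items, value)
        if idx < len(self._items) and self._items[idx] == value:
            del self._items[idx]

    def __iter__(self):
        return iter(self._items)


def check_against_model(make_container, steps=2000, seed=0):
    """Apply random inserts/erases and compare iteration order with sorted(set)."""
    rng = random.Random(seed)
    container = make_container()
    model = set()
    for _ in range(steps):
        value = rng.randrange(100)
        if rng.random() < 0.6:
            container.insert(value)
            model.add(value)
        else:
            container.erase(value)
            model.discard(value)
        assert list(container) == sorted(model)
    return len(model)


final_size = check_against_model(BisectSet)
```

With the package installed, the intended call would be something like `check_against_model(lambda: stl_set(use_rust=False))`, repeated with `use_rust=True` when the Rust backend is available.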