
PyO3 Advanced: Custom Python Protocols and Traits

Python's richness comes from protocols—informal contracts defined by dunder methods (__len__, __iter__, __enter__). Implementing these on your PyO3 classes makes them feel native to Python and unlocks syntactic sugar like for item in obj:, len(obj), and with obj:. This article teaches you to implement the most useful protocols: container (__len__, __getitem__), iterator (__iter__, __next__), context manager (__enter__, __exit__), and comparison (__eq__, __lt__). By the end, you will author Rust classes that integrate seamlessly with Python idioms.

Protocols are not magic; they are just method names Python knows to call. Implementing them correctly requires understanding when Python calls them, what they receive, and what they must return. This knowledge bridges the gap between "working" and "Pythonic."
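
To make the "method names Python knows to call" point concrete, here is a small pure-Python sketch that records which dunder method each piece of syntax triggers. The Traced class is hypothetical and exists only for illustration; a class exported from PyO3 receives the same calls from the interpreter.

```python
class Traced:
    """Hypothetical class that logs every protocol call it receives."""

    def __init__(self):
        self.calls = []

    def __len__(self):
        self.calls.append("__len__")
        return 2

    def __getitem__(self, idx):
        self.calls.append(f"__getitem__({idx!r})")
        return idx

    def __enter__(self):
        self.calls.append("__enter__")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.calls.append("__exit__")
        return False


t = Traced()
len(t)      # built-in len() -> __len__
t[0]        # subscript syntax -> __getitem__
with t:     # with statement -> __enter__ on entry, __exit__ on exit
    pass

print(t.calls)
# ['__len__', '__getitem__(0)', '__enter__', '__exit__']
```

The log shows the three questions to answer for every protocol: which syntax triggers the call, which arguments arrive, and what the caller expects back.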

Container Protocol: __len__ and __getitem__

A container lets you query its size and access elements by index. Implement __len__ (called by len(obj)) and __getitem__ (called by obj[i]):

use pyo3::exceptions::PyIndexError;
use pyo3::prelude::*;

#[pyclass]
struct Vector {
    data: Vec<f64>,
}

#[pymethods]
impl Vector {
    #[new]
    fn new(data: Vec<f64>) -> Self {
        Vector { data }
    }

    fn __len__(&self) -> usize {
        self.data.len()
    }

    fn __getitem__(&self, idx: isize) -> PyResult<f64> {
        // Negative indices count from the end. If idx is below -len, the sum stays
        // negative and the cast wraps to a huge usize, so `get` returns None.
        let index = if idx < 0 {
            (self.data.len() as isize + idx) as usize
        } else {
            idx as usize
        };
        self.data
            .get(index)
            .copied()
            .ok_or_else(|| PyIndexError::new_err("index out of range"))
    }

    fn __setitem__(&mut self, idx: isize, value: f64) -> PyResult<()> {
        // Same index normalization as in __getitem__.
        let index = if idx < 0 {
            (self.data.len() as isize + idx) as usize
        } else {
            idx as usize
        };
        self.data
            .get_mut(index)
            .map(|slot| *slot = value)
            .ok_or_else(|| PyIndexError::new_err("index out of range"))
    }
}

#[pymodule]
fn vector_ext(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<Vector>()?;
    Ok(())
}

From Python:

from vector_ext import Vector

v = Vector([1.0, 2.0, 3.0])
print(len(v)) # Output: 3
print(v[0]) # Output: 1.0
print(v[-1]) # Output: 3.0 (negative indexing)
v[1] = 5.0
print(v[1]) # Output: 5.0

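The isize type accepts both positive and negative indices, following Python's convention. A negative index is resolved by adding the length, so v[-1] refers to the last element. An index that is still out of range after that adjustment raises IndexError.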
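
The index rule can be written down as a few lines of pure Python, which is handy as a reference when testing the Rust class. This is a sketch of the convention only; the normalize_index name is hypothetical and not part of PyO3 or the Vector class.

```python
def normalize_index(idx: int, length: int) -> int:
    """Map a possibly negative Python index onto 0..length-1, or raise IndexError."""
    resolved = idx + length if idx < 0 else idx
    if not 0 <= resolved < length:
        raise IndexError("index out of range")
    return resolved


data = [1.0, 2.0, 3.0]
print(normalize_index(0, len(data)))   # 0
print(normalize_index(-1, len(data)))  # 2 (last element)

# Both directions of overflow are rejected.
rejected = []
for bad in (3, -4):
    try:
        normalize_index(bad, len(data))
    except IndexError:
        rejected.append(bad)
print(rejected)  # [3, -4]
```

A test suite for the extension could compare v[i] against data[normalize_index(i, len(data))] for a range of i values, including ones below -len(data).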

Iterator Protocol: __iter__ and __next__

Iterators allow for item in obj: syntax. Implement __iter__ (returns the iterator) and __next__ (returns the next item or raises StopIteration):

use pyo3::prelude::*;

#[pyclass]
struct Counter {
    count: usize,
    max: usize,
}

#[pymethods]
impl Counter {
    #[new]
    fn new(max: usize) -> Self {
        Counter { count: 0, max }
    }

    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf // The iterator is the object itself
    }

    fn __next__(&mut self) -> PyResult<usize> {
        if self.count < self.max {
            self.count += 1;
            Ok(self.count - 1)
        } else {
            Err(PyErr::new::<pyo3::exceptions::PyStopIteration, _>("done"))
        }
    }
}

#[pymodule]
fn counter_ext(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<Counter>()?;
    Ok(())
}

From Python:

from counter_ext import Counter

for i in Counter(3):
    print(i)
# Output: 0, 1, 2

# Or explicitly:
c = Counter(3)
print(next(c)) # Output: 0
print(next(c)) # Output: 1

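Note: PyStopIteration is PyO3's Rust-side name for Python's built-in StopIteration exception. Returning it as an Err from __next__ raises StopIteration, which the for loop catches and treats as the end of iteration. PyO3 also lets __next__ return Option<T>, where None ends the iteration; that form is usually more idiomatic.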
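
The following pure-Python sketch shows roughly what a for loop does with these two methods. The Counter class here is a stand-in written in Python (an assumption for illustration, mirroring the Rust struct), and manual_for is a hypothetical helper name.

```python
class Counter:
    """Pure-Python stand-in for the Rust Counter."""

    def __init__(self, max):
        self.count = 0
        self.max = max

    def __iter__(self):
        return self  # the iterator is the object itself

    def __next__(self):
        if self.count < self.max:
            self.count += 1
            return self.count - 1
        raise StopIteration


def manual_for(iterable):
    """Roughly what `for item in iterable:` does under the hood."""
    items = []
    it = iter(iterable)          # calls __iter__
    while True:
        try:
            item = next(it)      # calls __next__
        except StopIteration:    # end of iteration: leave the loop quietly
            break
        items.append(item)
    return items


c = Counter(3)
print(manual_for(c))  # [0, 1, 2]
print(manual_for(c))  # [] because the same object is already exhausted
```

Because __iter__ returns the object itself, a Counter can be consumed only once. If a type should support repeated loops, one common design is to have __iter__ return a fresh, separate iterator object instead.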

Context Manager Protocol: __enter__ and __exit__

Context managers enable the with statement. __enter__ is called on entry, and __exit__ on exit (even if an exception occurs):

use pyo3::prelude::*;
use pyo3::types::PyType; // PyType is not part of the prelude

#[pyclass]
struct FileWrapper {
    path: String,
    #[pyo3(get)]
    is_open: bool,
}

#[pymethods]
impl FileWrapper {
    #[new]
    fn new(path: String) -> Self {
        FileWrapper {
            path,
            is_open: false,
        }
    }

    // The return value is what `as f` binds to in the with statement.
    fn __enter__(&mut self) -> PyResult<String> {
        self.is_open = true;
        Ok(format!("Opened {}", self.path))
    }

    // All three arguments are None when the block exits without an exception.
    fn __exit__(
        &mut self,
        _exc_type: Option<&Bound<'_, PyType>>,
        _exc_val: Option<&Bound<'_, PyAny>>,
        _exc_tb: Option<&Bound<'_, PyAny>>,
    ) -> PyResult<bool> {
        self.is_open = false;
        println!("Closed {}", self.path);
        Ok(false) // False = do not suppress exceptions
    }
}

#[pymodule]
fn context_ext(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<FileWrapper>()?;
    Ok(())
}

From Python:

from context_ext import FileWrapper

with FileWrapper("data.txt") as f:
    print(f) # Output: Opened data.txt
# Output: Closed data.txt

__exit__ receives the exception type, value, and traceback if an exception occurred inside the block; all three are None on a normal exit. Returning False (Ok(false) in Rust) lets the exception propagate; returning True suppresses it.
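
A short pure-Python sketch of these rules, using a hypothetical Recorder class that stores what __exit__ received and returns a configurable suppress flag. It models the protocol only and is not the FileWrapper implementation.

```python
class Recorder:
    """Hypothetical context manager that records the arguments of __exit__."""

    def __init__(self, suppress):
        self.suppress = suppress
        self.seen = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.seen = (exc_type, exc_val is not None, exc_tb is not None)
        return self.suppress


# Normal exit: all three arguments are None.
quiet = Recorder(suppress=False)
with quiet:
    pass
print(quiet.seen)  # (None, False, False)

# Returning True swallows the exception; execution continues after the block.
swallow = Recorder(suppress=True)
with swallow:
    raise ValueError("boom")
print(swallow.seen[0].__name__)  # ValueError

# Returning False lets the exception propagate to the caller.
passthrough = Recorder(suppress=False)
propagated = False
try:
    with passthrough:
        raise ValueError("boom")
except ValueError:
    propagated = True
print(propagated)  # True
```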

Comparison Protocol: __eq__, __lt__, __le__, __gt__, __ge__

Rich comparison methods enable sorting and equality checks:

use pyo3::prelude::*;

#[pyclass]
#[derive(Clone)]
struct Point {
    x: f64,
    y: f64,
}

#[pymethods]
impl Point {
    #[new]
    fn new(x: f64, y: f64) -> Self {
        Point { x, y }
    }

    fn __eq__(&self, other: &Point) -> bool {
        self.x == other.x && self.y == other.y
    }

    fn __lt__(&self, other: &Point) -> bool {
        // Compare by distance from origin
        let dist_self = (self.x.powi(2) + self.y.powi(2)).sqrt();
        let dist_other = (other.x.powi(2) + other.y.powi(2)).sqrt();
        dist_self < dist_other
    }

    fn __repr__(&self) -> String {
        format!("Point({}, {})", self.x, self.y)
    }
}

#[pymodule]
fn point_ext(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<Point>()?;
    Ok(())
}

From Python:

from point_ext import Point

p1 = Point(1.0, 1.0)
p2 = Point(2.0, 2.0)

print(p1 == p1) # Output: True
print(p1 == p2) # Output: False
print(p1 < p2) # Output: True (closer to origin)

points = [Point(3.0, 4.0), Point(1.0, 0.0), Point(0.0, 2.0)]
sorted_points = sorted(points) # Uses __lt__
for p in sorted_points:
    print(p)
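
As a reference for the expected results, here is a pure-Python stand-in for Point (an assumption for illustration, with the same comparison rules as the Rust code). It shows that sorted() needs only __lt__, and that Python answers b > a by calling the reflected a < b when __gt__ is missing. The Rust class is expected to behave the same way, but that is worth confirming with a test.

```python
import math


class Point:
    """Pure-Python stand-in for the Rust Point."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return self.x == other.x and self.y == other.y

    def __lt__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        # Compare by distance from the origin.
        return math.hypot(self.x, self.y) < math.hypot(other.x, other.y)

    def __repr__(self):
        return f"Point({self.x}, {self.y})"


points = [Point(3.0, 4.0), Point(1.0, 0.0), Point(0.0, 2.0)]
ordered = sorted(points)  # uses __lt__ only
print(ordered)
# [Point(1.0, 0.0), Point(0.0, 2.0), Point(3.0, 4.0)]

a, b = Point(1.0, 1.0), Point(2.0, 2.0)
print(b > a)  # True: no __gt__, so Python evaluates the reflected a < b
```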

Reflection Protocol: __str__, __repr__, __hash__

String representations and hashing are essential for usability:

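// Add these methods to the existing #[pymethods] block for Point. A second
// #[pymethods] block for the same type only compiles with PyO3's
// `multiple-pymethods` feature, and __repr__ must not be defined twice.
#[pymethods]
impl Point {
    // ... new, __eq__, __lt__, and __repr__ from the previous section ...

    fn __str__(&self) -> String {
        format!("({}, {})", self.x, self.y)
    }

    fn __hash__(&self) -> u64 {
        // Simple hash based on the bit patterns of the coordinates.
        // Caveat: 0.0 and -0.0 compare equal in __eq__ but have different bits.
        let x_bits = self.x.to_bits();
        let y_bits = self.y.to_bits();
        x_bits ^ y_bits.wrapping_mul(31)
    }
}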

From Python:

p = Point(1.0, 2.0)
print(str(p)) # Output: (1.0, 2.0)
print(repr(p)) # Output: Point(1.0, 2.0)
print(hash(p)) # Output: (some integer hash)

# Use in sets and dicts
points_set = {p}
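
Sets and dicts rely on one invariant: objects that compare equal must have equal hashes. The sketch below models the bit-pattern hash in pure Python to show where it breaks that invariant. The helper names f64_bits and bit_hash are hypothetical; struct is used to mimic Rust's f64::to_bits.

```python
import struct


def f64_bits(value: float) -> int:
    """Bit pattern of a float as an unsigned 64-bit integer (like f64::to_bits)."""
    return struct.unpack("<Q", struct.pack("<d", value))[0]


def bit_hash(x: float, y: float) -> int:
    """Model of the __hash__ above: x_bits XOR (y_bits * 31, wrapping at 64 bits)."""
    return f64_bits(x) ^ ((f64_bits(y) * 31) % 2**64)


# Equal coordinates give equal hashes, as required.
print(bit_hash(1.0, 2.0) == bit_hash(1.0, 2.0))    # True

# Signed zeros compare equal but have different bit patterns...
print(0.0 == -0.0)                                 # True
# ...so the bit-based hash differs for points that __eq__ treats as equal.
print(bit_hash(0.0, 1.0) == bit_hash(-0.0, 1.0))   # False
```

One possible fix, if signed zeros can occur in your data, is to normalize -0.0 to 0.0 before taking the bits. NaN coordinates need a separate decision, since NaN never compares equal to itself.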

Protocol Mapping Reference

Python Protocol | Rust Methods | Use Case
Container | __len__, __getitem__, __setitem__ | Indexing, slicing, length queries
Iterator | __iter__, __next__ | for loops
Context Manager | __enter__, __exit__ | with statements
Comparison | __eq__, __lt__, __le__, __gt__, __ge__ | Sorting, equality checks
Reflection | __str__, __repr__, __hash__ | String representation, hashing
Callable | __call__ | obj() syntax
Numeric | __add__, __mul__, __truediv__ | Operator overloading
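
The Numeric row is the only one without an example above, so here is a pure-Python sketch of how the operators map onto those methods. Vec2 is a hypothetical class used only for illustration; returning NotImplemented for unsupported operands lets Python try the other operand or raise TypeError.

```python
class Vec2:
    """Hypothetical 2D vector illustrating the numeric protocol."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __add__(self, other):      # v + w
        if not isinstance(other, Vec2):
            return NotImplemented
        return Vec2(self.x + other.x, self.y + other.y)

    def __mul__(self, k):          # v * 2
        if not isinstance(k, (int, float)):
            return NotImplemented
        return Vec2(self.x * k, self.y * k)

    __rmul__ = __mul__             # 2 * v (reflected form)

    def __truediv__(self, k):      # v / 2
        if not isinstance(k, (int, float)):
            return NotImplemented
        return Vec2(self.x / k, self.y / k)

    def __repr__(self):
        return f"Vec2({self.x}, {self.y})"


v = Vec2(1.0, 2.0)
print(v + Vec2(3.0, 4.0))  # Vec2(4.0, 6.0)
print(2 * v)               # Vec2(2.0, 4.0)
print(v / 2)               # Vec2(0.5, 1.0)
```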

Key Takeaways

  • Dunder methods define Python protocols; implementing them makes Rust classes feel native.
  • __len__ and __getitem__ enable indexing; __setitem__ enables mutation.
  • __iter__ and __next__ enable for loops; raise PyStopIteration to signal end.
  • __enter__ and __exit__ enable with statements; __exit__ receives the exception details (or None values) and returns whether to suppress the exception.
  • Rich comparison methods (__eq__, __lt__) enable sorting and equality.
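  • Dunder methods have strict signatures; PyO3 typically rejects mismatches such as a wrong argument count at compile time, while arguments of the wrong type surface as Python exceptions at runtime.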

Frequently Asked Questions

What if I only implement __eq__ but not __lt__?

That is fine. Implement only the methods your use case needs. If code uses an operator whose method is missing, and the other operand does not provide the reflected method either, Python raises TypeError.
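
A pure-Python sketch of that behavior, using a hypothetical Tag class that defines only __eq__:

```python
class Tag:
    """Hypothetical class with equality but no ordering."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        if not isinstance(other, Tag):
            return NotImplemented
        return self.name == other.name


a, b = Tag("x"), Tag("y")
print(a == Tag("x"))  # True
print(a != b)         # True: != is derived from __eq__
print(a == 42)        # False: both sides decline, Python falls back to identity

try:
    a < b
except TypeError as exc:
    message = str(exc)
print(message)  # '<' not supported between instances of 'Tag' and 'Tag'
```

The same TypeError appears if such objects are passed to sorted(), since sorting relies on the < operator.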

Can I implement __call__ to make my class callable (e.g., obj())?

Yes. Add fn __call__(&self, args: ...) -> PyResult<T> to make instances callable like functions.
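
For reference, this is what the call protocol looks like from the Python side. Scaler is a hypothetical pure-Python class; positional and keyword arguments written inside obj(...) are passed on to __call__.

```python
class Scaler:
    """Hypothetical callable object: multiplies by a factor, then adds an offset."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, value, *, offset=0.0):
        return value * self.factor + offset


double = Scaler(2.0)
print(double(3.0))              # 6.0
print(double(3.0, offset=1.0))  # 7.0
print(callable(double))         # True
```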

How do I implement __contains__ for membership testing (item in obj)?

Add fn __contains__(&self, item: T) -> bool. Python calls this method when you use the in operator.
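
The sketch below shows the dispatch in pure Python with two hypothetical classes. Bag implements __contains__ directly; IterOnly does not, so Python falls back to iterating and comparing each element, which works but takes linear time.

```python
class Bag:
    """Hypothetical container with an explicit membership test."""

    def __init__(self, items):
        self.items = list(items)
        self.contains_calls = 0

    def __contains__(self, item):
        self.contains_calls += 1
        return item in self.items


class IterOnly:
    """Hypothetical container that only supports iteration."""

    def __init__(self, items):
        self.items = list(items)

    def __iter__(self):
        return iter(self.items)


bag = Bag([1.0, 2.0])
print(2.0 in bag)          # True
print(5.0 in bag)          # False
print(bag.contains_calls)  # 2

# Without __contains__, `in` walks the iterator until it finds a match.
print(2.0 in IterOnly([1.0, 2.0]))  # True
```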

What is the difference between __str__ and __repr__?

__str__ is a user-friendly string (called by str() and print()). __repr__ is a developer-friendly representation (called by repr() and in the REPL). Convention: __repr__ should be unambiguous; __str__ should be readable. If you implement only one, implement __repr__.
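
The reason for preferring __repr__ is visible in a short pure-Python sketch (both classes are hypothetical): str() falls back to __repr__ when __str__ is missing, and containers always display their elements with repr().

```python
class OnlyRepr:
    def __repr__(self):
        return "OnlyRepr()"


class Both:
    def __repr__(self):
        return "Both(x=1)"

    def __str__(self):
        return "one"


print(str(OnlyRepr()))            # OnlyRepr()  (str falls back to __repr__)
print(str(Both()), repr(Both()))  # one Both(x=1)
print([Both()])                   # [Both(x=1)]  (containers use repr)
print(f"{Both()} {Both()!r}")     # one Both(x=1)
```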

Can I use #[derive(...)] to auto-implement some dunder methods?

Partly. PartialEq and Hash are derives from Rust's standard library, not PyO3 features, and deriving them does not by itself expose __eq__ or __hash__ to Python. You can implement the dunder methods explicitly in #[pymethods], or, in newer PyO3 releases (0.22 and later), pass options such as #[pyclass(eq)], #[pyclass(eq, ord)], or #[pyclass(eq, hash, frozen)] to have PyO3 generate them from the type's PartialEq, PartialOrd, and Hash implementations.

Further Reading