How to compare/sort Python objects?
It’s easy to compare Python built-in objects, such as 1 < 2, 'abc' < 'abd', or (0, 1, 3) < (0, 2, 1). But how do you compare your own objects? Say we have a Point class like this (I’m using Python 3.6.5):
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y
Python will complain about Point(0, 1) < Point(1, 0) since it doesn’t know how to compare them:
TypeError: '<' not supported between instances of 'Point' and 'Point'
Here are 3 solutions for comparing and sorting these objects.
Solution 1: Use the key parameter of sort()
Say we want to compare points by their distance to (0, 0).
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        # display x and y instead of address
        return f'Point(x={self.x}, y={self.y})'

points = [Point(0, 1), Point(3, 0), Point(1, 0), Point(2, 0)]
# sort by squared distance to (0, 0); it gives the same order as the distance itself
points.sort(key=lambda p: p.x * p.x + p.y * p.y)
print(points)
The output is the sorted list:
[Point(x=0, y=1), Point(x=1, y=0), Point(x=2, y=0), Point(x=3, y=0)]
Both sort() (sorts in place) and sorted() (returns a new list) accept a key function, which is often written as a lambda. The reverse parameter is also very useful: reverse=True sorts in descending order.
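The following sketch combines the key and reverse parameters. It uses sorted() to return the points farthest-first and leaves the original list untouched. The helper name squared_distance is hypothetical and is introduced here only for illustration.

```python
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        # display x and y instead of address
        return f'Point(x={self.x}, y={self.y})'


def squared_distance(p):
    # hypothetical helper: squared distance to (0, 0), no sqrt needed for ordering
    return p.x * p.x + p.y * p.y


points = [Point(0, 1), Point(3, 0), Point(1, 0), Point(2, 0)]

# sorted() builds a new list; reverse=True puts the largest key first
farthest_first = sorted(points, key=squared_distance, reverse=True)
print(farthest_first)
# [Point(x=3, y=0), Point(x=2, y=0), Point(x=0, y=1), Point(x=1, y=0)]
print(points)  # unchanged: [Point(x=0, y=1), Point(x=3, y=0), Point(x=1, y=0), Point(x=2, y=0)]
```

Point(0, 1) and Point(1, 0) have the same key. They keep their original relative order because Python's sort is stable, and that remains true even with reverse=True.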
Solution 2: Override dunder methods
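We can define how a Point compares with another object by overriding __eq__ (==) and one of __lt__ (<), __le__ (≤), __gt__ (>) or __ge__ (≥). The functools.total_ordering decorator then fills in the remaining comparison methods. Dunder is short for double underscore.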
import math
from functools import total_ordering

@total_ordering
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        # display x and y instead of address
        return f'Point(x={self.x}, y={self.y})'

    @property
    def distance(self):
        return math.sqrt(self.x * self.x + self.y * self.y)

    def __lt__(self, other):
        # p1 < p2 calls p1.__lt__(p2)
        if not isinstance(other, Point):
            return NotImplemented  # let Python raise TypeError for non-Points
        return self.distance < other.distance

    def __eq__(self, other):
        # p1 == p2 calls p1.__eq__(p2)
        if not isinstance(other, Point):
            return NotImplemented  # fall back to default (identity) comparison
        return self.distance == other.distance

print(sorted([
    Point(0, 1), Point(3, 0), Point(1, 0), Point(2, 0)
]))
This prints the same sorted list as Solution 1.
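Because of @total_ordering, only __lt__ and __eq__ have to be written by hand. The quick check below repeats the class so it runs on its own. Each method includes an isinstance guard that returns NotImplemented, so comparing a Point with a non-Point falls back to Python's default behavior instead of raising AttributeError. The check shows the operators the decorator derives.

```python
import math
from functools import total_ordering


@total_ordering
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return f'Point(x={self.x}, y={self.y})'

    @property
    def distance(self):
        return math.sqrt(self.x * self.x + self.y * self.y)

    def __lt__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return self.distance < other.distance

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return self.distance == other.distance


# only __lt__ and __eq__ are hand-written; total_ordering derives the rest
print(Point(3, 0) > Point(1, 0))    # True  (derived __gt__)
print(Point(0, 1) <= Point(1, 0))   # True  (derived __le__, equal distances)
print(Point(0, 1) >= Point(2, 0))   # False (derived __ge__)
print(Point(0, 1) == 'origin')      # False (NotImplemented -> identity fallback)
# min() and max() work without a key because Point is now comparable
print(min(Point(3, 0), Point(0, 1)), max(Point(3, 0), Point(0, 1)))
```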
Solution 3: Convert to built-in objects
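Wrap each point in a tuple (distance_square, index, point). Python compares tuples element by element, so the tuples are ordered by distance_square (an integer) first and by index second. The index serves two purposes:
1. It keeps points with the same distance in their original relative order, i.e. it makes the sort stable.
2. Because every index is unique, the comparison never falls through to the Point objects, which Python doesn’t know how to compare.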
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        # display x and y instead of address
        return f'Point(x={self.x}, y={self.y})'

points = [Point(0, 1), Point(3, 0), Point(1, 0), Point(2, 0)]
tuples = [(p.x * p.x + p.y * p.y, i, p)
          for i, p in enumerate(points)]
print(tuples)
tuples.sort()
print(tuples)
The first line shows the tuples before sorting and the second after:
[(1, 0, Point(x=0, y=1)), (9, 1, Point(x=3, y=0)), (1, 2, Point(x=1, y=0)), (4, 3, Point(x=2, y=0))]
[(1, 0, Point(x=0, y=1)), (1, 2, Point(x=1, y=0)), (4, 3, Point(x=2, y=0)), (9, 1, Point(x=3, y=0))]
The result is a list of sorted tuples, so you still need one more line to extract the points, e.g. [p for _, _, p in tuples].
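The sketch below shows why the index belongs in the tuple. Point(0, 1) and Point(1, 0) both have squared distance 1. Without an index, their tuples tie on the first element, so Python goes on to compare the Point objects themselves. That raises the same TypeError as before. With a unique index in the middle, the tie is settled before the Points are ever reached.

```python
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return f'Point(x={self.x}, y={self.y})'


points = [Point(0, 1), Point(3, 0), Point(1, 0), Point(2, 0)]

# (distance_square, point): ties on distance fall through to Point < Point
without_index = [(p.x * p.x + p.y * p.y, p) for p in points]
sort_failed = False
try:
    without_index.sort()
except TypeError as exc:
    sort_failed = True
    print('sort failed:', exc)

# (distance_square, index, point): the unique index breaks every tie
with_index = sorted((p.x * p.x + p.y * p.y, i, p) for i, p in enumerate(points))
sorted_points = [p for _, _, p in with_index]  # drop the helper fields
print(sorted_points)
# [Point(x=0, y=1), Point(x=1, y=0), Point(x=2, y=0), Point(x=3, y=0)]
```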
Conclusion
Solution 1 needs the least code, and less code usually means fewer bugs. Solution 2 is the most flexible because the comparison rule lives inside the class, so points can be compared directly, e.g. Point(0, 1) < Point(3, 0). Solution 3 takes extra space, but it’s very useful for a priority queue (such as one built on Python’s heapq module), since a typical entry is a tuple (priority number, data).
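The priority-queue use of Solution 3 can be sketched with the standard heapq module. This version uses itertools.count() as the tie-breaker instead of enumerate. That is an illustrative choice: it keeps producing unique, increasing numbers as items keep arriving.

```python
import heapq
import itertools


class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return f'Point(x={self.x}, y={self.y})'


heap = []
counter = itertools.count()  # unique, increasing tie-breaker

for p in [Point(0, 1), Point(3, 0), Point(1, 0), Point(2, 0)]:
    # entry layout: (priority, tie-breaker, data); the Point is never compared
    heapq.heappush(heap, (p.x * p.x + p.y * p.y, next(counter), p))

# heappop always returns the entry with the smallest (priority, tie-breaker)
closest_first = [heapq.heappop(heap)[2] for _ in range(len(heap))]
print(closest_first)
# [Point(x=0, y=1), Point(x=1, y=0), Point(x=2, y=0), Point(x=3, y=0)]
```

Unlike sorting once, a heap lets you push new points and pop the closest one at any time, which is the usual reason to reach for a priority queue.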