Classes Intermediate: Operator Overloading


Postby ichabod801 » Mon Mar 04, 2013 10:31 pm

Before reading this tutorial, make sure you understand everything covered in the Class Basics tutorial.

Operator overloading is a way to make the classes you create work with the standard Python operators and some of the built-in functions. As a side effect, it makes your classes work well with other parts of Python that use those operators and built-in functions.

To illustrate some examples, I am going to use a Vector class:

class Vector(object):
   """
   Vector
   
   A two dimensional bound Cartesian vector.
   
   Attributes:
   x: The x coordinate of the vector terminus (int or float)
   y: The y coordinate of the vector terminus (int or float)
   
   Overridden Methods:
   __init__
   """

   def __init__(self, x, y = 0):
      """
      Initialize a vector from a tuple or pair of numbers.
      
      Parameters:
      x: The x coordinate or a sequence of x and y coordinates (sequence or number)
      y: The y coordinate (number)
      """
      # check for sequence initialization
      if isinstance(x, (tuple, list)):
         self.x = x[0]
         self.y = x[1]
      # otherwise use number initialization
      else:
         self.x = x
         self.y = y


You may think there's a lot of documentation going on here. It's useful, though. Run the above code and type "help(Vector)". It will print out all of the docstrings for the class. This is useful when you are trying to debug, especially when trying to debug code you wrote three years ago that your boss asked you to add a new feature to. As a way of introducing our first overloaded operator, imagine you are trying to debug a problem with a particular vector:

>>> a = Vector(8, 1)
>>> a
<__main__.Vector object at 0x02403430>


Well, that's not very helpful, is it? We could get around this with:

>>> a = Vector(8, 1)
>>> print(a.x, a.y)
8 1


That's a pain in the butt, and usually when you're debugging you have a pain in the butt already. The key to getting around this is knowing that typing 'a' at the interactive prompt displays repr(a), and typing 'print(a)' displays str(a). Python gets those strings from the special methods __repr__ and __str__, so we can use operator overloading to make Python do what we want in those situations:

   def __repr__(self):
      """
      Computer readable representation.
      """
      return 'Vector({}, {})'.format(self.x, self.y)
      
   def __str__(self):
      """
      Human readable representation.
      """
      return '({}, {})'.format(self.x, self.y)


Add the above code to your Vector class definition. Now we get:

>>> a = Vector(8, 1)
>>> a
Vector(8, 1)
>>> print(a)
(8, 1)


Much better. But you may be wondering why I made the easier-to-type version (a) return a more complicated result, and why I called it "computer readable." The convention for repr is that eval(repr(x)) == x, so the return value of __repr__ should be something that evaluates into an equivalent object. For complicated objects that change a lot after they are created, this can be a pain. In that case the convention is to put the object type and some other useful information inside angle brackets, which is exactly what we saw before we implemented __repr__ for the Vector class.
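To see the two repr conventions side by side, here is a minimal sketch. The stripped-down Vector follows the eval-able form. The Connection class is purely hypothetical and stands in for an object whose state can't sensibly be rebuilt from one expression, so its repr uses angle brackets. Using type(self).__name__ and {!r} is my own design assumption (it keeps subclass names and string fields correct), not something the convention requires.

```python
class Vector(object):
    def __init__(self, x, y=0):
        self.x = x
        self.y = y

    def __repr__(self):
        # eval-able form; type(self).__name__ keeps subclass names correct,
        # and !r quotes any non-numeric coordinate values properly
        return '{}({!r}, {!r})'.format(type(self).__name__, self.x, self.y)


class Connection(object):
    """Hypothetical object whose state can't be rebuilt from its repr."""

    def __init__(self, host):
        self.host = host
        self.open = True

    def __repr__(self):
        # angle-bracket form: the type plus some useful information
        status = 'open' if self.open else 'closed'
        return '<{} to {!r} ({})>'.format(type(self).__name__, self.host, status)


a = Vector(8, 1)
print(repr(a))                          # Vector(8, 1)
b = eval(repr(a))                       # builds a new, equivalent Vector
print(b.x, b.y)                         # 8 1
print(repr(Connection('example.com')))  # <Connection to 'example.com' (open)>
```

Note that b is a separate object from a, which matters for the equality discussion that follows.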

So let's try out our new repr:

>>> a = Vector(8, 1)
>>> eval(repr(a)) == a
False


So why didn't that work the way I said it would? Because we haven't told Python how to judge the equality of Vector objects. == is an operator after all, and it's one we haven't overloaded yet. Without an __eq__ method, Python judges equality by identity: two Vectors are equal only if they are the same object (in CPython, the same memory address). eval(repr(a)) creates a new object, separate from a itself. So let's tell Python how to judge equality:

   def __eq__(self, other):
      """
      Equality testing.
      
      Parameters:
      other: What to check equality against. (Vector)
      """
      return self.x == other.x and self.y == other.y


Now it works:

>>> a = Vector(8, 1)
>>> eval(repr(a)) == a
True
>>> b = Vector(8, 1)
>>> a == b
True
>>> c = Vector(8, 2)
>>> a == c
False
>>> a == (8, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "vector.py", line 39, in __eq__
    return self.x == other.x and self.y == other.y
AttributeError: 'tuple' object has no attribute 'x'


That last AttributeError is actually bad form. The convention in Python is to return the special object NotImplemented if the comparison is not supported, which lets Python try the other operand's method instead. On the other hand, we might want to support equality with tuples. So let's rewrite our __eq__ method:

   def __eq__(self, other):
      """
      Equality testing.
      
      Parameters:
      other: What to check equality against. (Vector, tuple, or list)
      """
      # vector to vector
      if isinstance(other, Vector):
         return self.x == other.x and self.y == other.y
      # vector to sequence
      elif isinstance(other, (tuple, list)):
         return self.x == other[0] and self.y == other[1]
      # vector to anything else
      else:
         return NotImplemented


And when we try it out:

>>> a = Vector(8, 1)
>>> b = Vector(8, 1)
>>> a == b
True
>>> a == (8, 1)
True
>>> b == [8, 1]
True
>>> a == 5
False
>>> a == (8, 1, -1)
True


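Now when we compare a Vector to something that doesn't make sense, __eq__ returns NotImplemented, the other object doesn't know how to compare itself to a Vector either, and Python falls back to an identity check, which gracefully gives False instead of an error. And note that we may have wanted that last comparison to be False. I'm trying to be simple here, but it is good to be careful what you code lest you get unexpected results.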
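If you do want that last comparison to be False, one option (my assumption about the desired behavior, not the only reasonable one) is to check the sequence length before indexing. This sketch also shows the NotImplemented fallback at work: a == 5 is simply False, and a != ... is derived from __eq__ automatically in Python 3.

```python
class Vector(object):
    def __init__(self, x, y=0):
        self.x = x
        self.y = y

    def __eq__(self, other):
        # vector to vector
        if isinstance(other, Vector):
            return self.x == other.x and self.y == other.y
        # vector to sequence: only two-element sequences can match
        elif isinstance(other, (tuple, list)):
            return len(other) == 2 and self.x == other[0] and self.y == other[1]
        # anything else: let Python try the other operand, then fall back
        else:
            return NotImplemented


a = Vector(8, 1)
print(a == (8, 1))        # True
print(a == (8, 1, -1))    # False: wrong length
print(a == (8,))          # False, where the original version raised IndexError
print(a == 5)             # False: both sides said NotImplemented
print(a.__eq__(5))        # NotImplemented
print(a != (8, 1, -1))    # True (Python 3 derives != from __eq__)
```

The length check runs first, so the short-circuiting `and` never indexes past the end of a short sequence.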
Craig "Ichabod" O'Brien
Minimalist, buddhist, theist, and programmer
Current languages: Python, SAS, and C++
Previous serious languages: R, Java, VBA, Lisp, HyperTalk, BASIC
ichabod801
 
Posts: 84
Joined: Sat Feb 09, 2013 12:54 pm
Location: Outside Washington DC


Postby ichabod801 » Mon Mar 04, 2013 10:32 pm

Just as we can override ==, we can override the other comparison operators:

__ge__ overrides >=
__gt__ overrides >
__le__ overrides <=
__lt__ overrides <
__ne__ overrides !=
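Two details about the list above are worth seeing in a small Python 3 sketch (Python 2 behaves differently for !=). First, Python 3 derives != from __eq__ if you don't write __ne__. Second, the ordering methods come in reflected pairs (__lt__ with __gt__, __le__ with __ge__): when the left operand returns NotImplemented, Python tries the reflection on the right operand. So with only __eq__ and __lt__ defined, a > b works but a >= b does not.

```python
class Vector(object):
    def __init__(self, x, y=0):
        self.x = x
        self.y = y

    def length(self):
        return (self.x ** 2 + self.y ** 2) ** 0.5

    def __eq__(self, other):
        if isinstance(other, Vector):
            return self.x == other.x and self.y == other.y
        return NotImplemented

    def __lt__(self, other):
        if isinstance(other, Vector):
            return self.length() < other.length()
        return NotImplemented


a = Vector(8, 1)
b = Vector(3, 4)
print(a != b)   # True: != is derived from __eq__
print(a > b)    # True: no __gt__, so Python tries the reflection b.__lt__(a)

try:
    result = a >= b     # neither __ge__ nor its reflection __le__ exists
except TypeError as error:
    result = None
    print('TypeError:', error)
```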


Now let's say we had rewritten our __eq__ method to handle tuples of length 3 and some other odd cases. At that point our __eq__ method is starting to get rather complicated. Rewriting that code five more times to cover all the comparison operators is not only a pain, but it's a perfect opportunity to introduce bugs. Fortunately, the functools module (Python 2.7, and 3.2 or later) provides a shortcut named total_ordering. total_ordering is a class decorator. I'm not going to get into the details of decorators here, but basically they precede a definition and modify that definition. For our purposes we add a few lines to the start of our code:

import functools

@functools.total_ordering
class Vector(object):
   ...


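Now, total_ordering requires us to define one of __ge__, __gt__, __le__, or __lt__, and we should also define __eq__. total_ordering then uses those methods to fill in the rest of the ordering comparisons. (In Python 3, != is derived from __eq__ automatically; in Python 2 you need to write __ne__ yourself.) So we need another method, and I usually use __lt__: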

   def __lt__(self, other):
      """
      Less than testing.
      
      Parameters:
      other: What to check less than against. (Vector, tuple, or list)
      """
      # vector to vector
      if isinstance(other, Vector):
         return self.length() < other.length()
      # vector to sequence
      elif isinstance(other, (tuple, list)):
         return self.length() < vector_length(other)
      # vector to anything else
      else:
         return NotImplemented
      
   def length(self):
      """
      The length of the vector
      """
      return (self.x ** 2 + self.y ** 2) ** 0.5
      
def vector_length(vector):
   """
   Calculate the vector length for a list or tuple.
   
   Parameters:
   vector: A sequence of at least two numbers (list or tuple)
   """
   return (vector[0] ** 2 + vector[1] ** 2) ** 0.5


That's actually two more methods and a function external to the class. I wanted to order the vectors by length, but we're going to want to know the vector's length in other situations. So we might as well define a method for it. Since __lt__ uses length, it only has to be changed in one place if it needs to be changed. And if we are going to compare to sequences, we will need a way to calculate a length for them as well.

>>> a = Vector(8, 1)
>>> b = Vector(3, 4)
>>> a < b
False
>>> b < a
True
>>> a >= b
True
>>> a != b
True
>>> c = Vector(3, -4)
>>> b < c
False
>>> b > c
True
>>> c < b
False


WTF? How can b be greater than c, but c not be less than b!? The problem is that we defined __eq__ and __lt__ in different terms. Our __eq__ method is effectively comparing magnitude and direction, while __lt__ is only testing magnitude. When total_ordering combines them, it produces some screwy results. It defines b > c as 'not b < c, and b != c'. Since b and c have the same length, b < c is False, and since their coordinates differ, b != c is True, so b > c comes out True. But c < b only compares lengths, so it is False. This is the classic trap of operator overloading: doing things that don't really make sense. We don't generally order bound Cartesian vectors by magnitude, and forcing that on the system screwed things up.
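If what you actually want is to sort vectors by magnitude, a safer route is to leave < alone and pass the length as a sort key; sorted(), min(), and max() all accept a key argument. This minimal sketch assumes a stripped-down Vector with only the length method from above (plus a repr for readable output).

```python
class Vector(object):
    def __init__(self, x, y=0):
        self.x = x
        self.y = y

    def __repr__(self):
        return 'Vector({}, {})'.format(self.x, self.y)

    def length(self):
        """The length (magnitude) of the vector."""
        return (self.x ** 2 + self.y ** 2) ** 0.5


vectors = [Vector(8, 1), Vector(3, 4), Vector(3, -4), Vector(0, 1)]
# order by magnitude without touching <, so == keeps its coordinate meaning
by_length = sorted(vectors, key=Vector.length)
print(by_length)                        # [Vector(0, 1), Vector(3, 4), Vector(3, -4), Vector(8, 1)]
print(max(vectors, key=Vector.length))  # Vector(8, 1)
```

Because sorted() is stable, the two vectors of length 5 keep their original relative order, and nobody is tempted to ask whether one of them is "less than" the other.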

Let's take another case that is perhaps more subtle. len(x) can be overridden with x.__len__. Why did I write a length method for the Vector class instead of using __len__? Because "len" in Python isn't "length" in mathematics. In Python, len is for containers, such as lists, dictionaries, and sets, and tells how many items are in the container. So __len__ for our Vector class isn't really meaningful, and if we really wanted to do it, __len__ would always return 2 for the two dimensions.

Postby ichabod801 » Mon Mar 04, 2013 10:34 pm

Okay, then what operations do we normally do on vectors? We normally add them together and multiply them by scalars. Just as the comparison operators have methods allowing you to override them, so do addition, multiplication, and so on:

__add__ overrides +
__sub__ overrides -
__mul__ overrides *
__truediv__ overrides /
__floordiv__ overrides //
__mod__ overrides %
__pow__ overrides **
__lshift__ overrides <<
__rshift__ overrides >>
__and__ overrides &
__xor__ overrides ^
__or__ overrides |


Note that in Python 2.x, / only calls __truediv__ if you have done 'from __future__ import division'; otherwise it calls __div__. Python 3.x never uses __div__. If you are trying to develop for both, implement both __div__ and __truediv__.
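Here's a sketch of that both-versions advice, assuming scalar division is the only division a Vector should support. Aliasing __div__ to __truediv__ in the class body means there is only one implementation to maintain.

```python
class Vector(object):
    def __init__(self, x, y=0):
        self.x = x
        self.y = y

    def __truediv__(self, other):
        """Scalar division: Vector / number."""
        if isinstance(other, (float, int)):
            return Vector(self.x / other, self.y / other)
        return NotImplemented

    # Python 2 (without the __future__ import) calls __div__ for /,
    # so point it at the same method; Python 3 simply never looks at __div__.
    __div__ = __truediv__


a = Vector(8, 1)
half = a / 2
print(half.x, half.y)   # 4.0 0.5 in Python 3
```

Under Python 2 without the __future__ import, the / inside the method is itself integer division for int coordinates, so you would get 4 and 0 there.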

But we're just going to implement __add__ and __mul__ for our Vector class:

   def __add__(self, other):
      """
      Addition.
      
      Parameters:
      other: the other summand. (Vector, list, or tuple)
      """
      if isinstance(other, Vector):
         return Vector(self.x + other.x, self.y + other.y)
      elif isinstance(other, (tuple, list)):
         return Vector(self.x + other[0], self.y + other[1])
      else:
         return NotImplemented
   
   def __mul__(self, other):
      """
      Scalar multiplication.
      
      Parameters:
      other: The multiplier (float or int)
      """
      # vector * number
      if isinstance(other, (float, int)):
         return Vector(other * self.x, other * self.y)
      # vector * other
      else:
         return NotImplemented


When we try it out we get:

>>> a = Vector(8, 1)
>>> b = Vector(3, 4)
>>> a + b
Vector(11, 5)
>>> b + a
Vector(11, 5)
>>> a + b == b + a
True
>>> a + (3, 4)
Vector(11, 5)
>>> 5 * a
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unsupported operand type(s) for *: 'int' and 'Vector'


So why didn't the multiplication work? Here's a clue:

>>> (3, 4) + a
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: can only concatenate tuple (not "Vector") to tuple
>>> a * 5
Vector(40, 5)


So __add__ and __mul__ are only working when the vector is on the left side. Here's what is going on: When Python sees 'spam + eggs', it tries 'spam.__add__(eggs)'. If spam has no __add__, or it returns NotImplemented, Python tries 'eggs.__radd__(spam)' (radd for right addition). Now, if all you are ever going to do is add Vectors to Vectors, you just need __add__. But when you start mixing in tuples and lists you have to account for __radd__ as well. If you were always going to have the scalar on the left, as is standard in mathematics, all you would ever need is __rmul__. But then you'd run the risk of someone typing in 'a * 5', so it's best to implement both.

Fortunately, it's easy to do so:

   def __radd__(self, other):
      """
      Right addition.
      
      Parameters:
      other: The other summand (Vector, tuple, or list)
      """
      return self.__add__(other)
         
   def __rmul__(self, other):
      """
      Right multiplication.
      
      Parameters:
      other: The other multiplier (float or int)
      """
      return self.__mul__(other)


There may be objects for which a + b != b + a, but that is not the case with Vectors. So we can just use the functionality we already wrote for __add__ and __mul__ to handle the right hand cases. Again, that means that if we want to change that functionality, we only need to change it in one place. Now things work as expected:

>>> a = Vector(8, 1)
>>> (3, 4) + a
Vector(11, 5)
>>> 5 * a
Vector(40, 5)
>>> 3 * a == a + a + a
True
>>> a * (3, 4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: can't multiply sequence by non-int of type 'Vector'


We get the last TypeError because we didn't implement multiplication by tuples, lists, or Vectors. But that's the same kind of TypeError Python raises for any unsupported combination of operands, so we are being Pythonic by causing the expected error.
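To watch the __add__/__radd__ dispatch order described above, here's a small tracing sketch. The Traced class and the calls list are hypothetical; they only record which special method Python calls, and Traced only knows how to add plain numbers.

```python
calls = []


class Traced(object):
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        calls.append('__add__')
        if isinstance(other, (int, float)):
            return Traced(self.value + other)
        return NotImplemented

    def __radd__(self, other):
        calls.append('__radd__')
        # adding numbers is commutative, so reuse __add__
        return self.__add__(other)


t = Traced(10)

del calls[:]
t + 1                         # t.__add__(1) succeeds directly
after_left = list(calls)      # ['__add__']

del calls[:]
1 + t                         # int can't add a Traced, so t.__radd__(1) runs
after_right = list(calls)     # ['__radd__', '__add__']

del calls[:]
try:
    t + 'spam'                # __add__ returns NotImplemented, str has no __radd__
except TypeError:
    pass
after_fail = list(calls)      # ['__add__']
```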

Postby ichabod801 » Mon Mar 04, 2013 10:35 pm

So what else can we override? Lots of stuff. I'm going to assume you have a grasp of the basic idea of operator overloading, and talk about some of the other special methods for operator overloading. The ones I'm not going to talk about get into the nuts and bolts of Python. If you mess with them without knowing what you are doing, you can totally screw things up. If you want to screw things up, or want more details on the special methods I've talked about, go to the Data Model section of the Language Reference section of the documentation for your version of Python.

In addition to __repr__ and __str__ there are also __bytes__ and __format__. __bytes__ (Python 3) is called when bytes(a) is used to make a byte-string representation of an object, and should return a bytes object. __format__ is called by the format() built-in function and by string formatting with str.format(). It receives a format specification as a parameter, and allows you to customize how string formatting is performed on your object. For example, if some numeric formatting specifying a number of digits was passed to __format__, you could apply that formatting to the x and y attributes of your Vector object.
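Here is a sketch of the __format__ idea above. Passing the spec straight through to each coordinate is my own assumption about what a format spec should mean for a Vector; you could just as well invent your own mini-language.

```python
class Vector(object):
    def __init__(self, x, y=0):
        self.x = x
        self.y = y

    def __format__(self, format_spec):
        # apply the caller's spec (e.g. '.2f') to each coordinate;
        # an empty spec formats the numbers the default way
        return '({}, {})'.format(format(self.x, format_spec),
                                 format(self.y, format_spec))


a = Vector(8, 1.23456)
print('{:.2f}'.format(a))   # (8.00, 1.23)
print(format(a, '>5'))      # (    8, 1.23456)
print('{}'.format(a))       # (8, 1.23456)
```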

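__hash__ is called by the built-in function hash(). It should return an integer, and if two instances of your class evaluate as equal, they must return the same integer. This is very useful, because implementing it allows your objects to be used in sets and frozensets, and as keys in dicts. A good way to implement it in the Vector class would be to return hash((self.x, self.y)), making use of the tuple class's hashing function. Note that in Python 3, a class that defines __eq__ but not __hash__ gets __hash__ set to None, so our Vector instances are currently unhashable. Also, since a Vector can be changed after it is created, don't change its coordinates while it is in a set or used as a dict key.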
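A sketch of that suggestion, reusing the tuple-aware __eq__ from earlier. Because a Vector equal to (8, 1) also hashes like (8, 1), the tuple works as a dict key lookup too. The small EqOnly class is hypothetical and only there to show the Python 3 rule about defining __eq__ without __hash__.

```python
class Vector(object):
    def __init__(self, x, y=0):
        self.x = x
        self.y = y

    def __eq__(self, other):
        if isinstance(other, Vector):
            return self.x == other.x and self.y == other.y
        elif isinstance(other, (tuple, list)):
            return self.x == other[0] and self.y == other[1]
        return NotImplemented

    def __hash__(self):
        # equal Vectors have equal (x, y) tuples, so they hash the same
        return hash((self.x, self.y))


a = Vector(8, 1)
b = Vector(8, 1)
print(len({a, b}))            # 1: equal vectors collapse to one set member
lookup = {a: 'first vector'}
print(lookup[Vector(8, 1)])   # first vector
print(lookup[(8, 1)])         # first vector: the tuple hashes and compares equal


class EqOnly(object):
    def __eq__(self, other):
        return NotImplemented

print(EqOnly.__hash__)        # None in Python 3, so EqOnly instances are unhashable
```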

__bool__ is called whenever your object is used for truth testing (in Python 2 it's called __nonzero__). For Vector we might say it returns False for (0, 0), and True for everything else. If __bool__ is not implemented, __len__ is called, and the object is true if __len__ returns a non-zero value. If neither is implemented, every instance is true.
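A minimal sketch of the zero-vector rule suggested above, with the Python 2 name aliased so the same class works in both versions.

```python
class Vector(object):
    def __init__(self, x, y=0):
        self.x = x
        self.y = y

    def __bool__(self):
        # only the zero vector is false
        return self.x != 0 or self.y != 0

    # Python 2 looks for __nonzero__ instead of __bool__
    __nonzero__ = __bool__


print(bool(Vector(0, 0)))   # False
print(bool(Vector(8, 1)))   # True
if not Vector(0):
    print('zero vector')    # runs, because Vector(0) is (0, 0)
```

In Python 3, __bool__ must return an actual bool; the comparisons here guarantee that.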

One of my favorite special methods is __call__. Don't ask me why. It allows instances to act like functions. For example, say we implemented __call__ for our Vector class, and then made an instance with 'a = Vector(8, 1)'. Then 'a('spam', 'eggs')' would call 'a.__call__('spam', 'eggs')', which is the same as 'Vector.__call__(a, 'spam', 'eggs')'.
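Since it isn't obvious what calling a Vector should mean, here is a sketch on a different, purely hypothetical class: a Scaler that multiplies its argument by a stored factor. It shows that instances with __call__ can be used anywhere a function is expected, and unlike a plain function they can keep state between calls.

```python
class Scaler(object):
    """Hypothetical callable object that multiplies by a stored factor."""

    def __init__(self, factor):
        self.factor = factor
        self.calls = 0          # state kept between calls

    def __call__(self, value):
        # double(21) runs Scaler.__call__(double, 21)
        self.calls += 1
        return value * self.factor


double = Scaler(2)
print(double(21))                    # 42
print(list(map(double, [1, 2, 3])))  # [2, 4, 6]
print(double.calls)                  # 4
print(callable(double))              # True
```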

There are several methods for making objects that act as containers. When doing that you should also add the methods the container type you are emulating generally has, such as append and index for sequences and get for dictionaries. Also take a look at the collections module (collections.abc in Python 3.3 and later); it provides abstract base classes that are useful for emulating containers. The basics are (assuming an instance named foo):

__len__ overrides len(foo)
__getitem__(key) overrides foo[key]
__setitem__(key, value) overrides foo[key] = value
__delitem__(key) overrides del foo[key]
__iter__() overrides iter(foo), which for loops use to iterate over foo
__reversed__() overrides reversed(foo)
__contains__(value) overrides value in foo
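As a sketch of the collections suggestion: if you subclass the Sequence abstract base class and write only __len__ and __getitem__, you get __iter__, __contains__, __reversed__, index, and count for free. The Path class and its points are hypothetical.

```python
try:
    from collections.abc import Sequence   # Python 3.3 and later
except ImportError:
    from collections import Sequence       # older Pythons


class Path(Sequence):
    """Hypothetical read-only container of (x, y) points."""

    def __init__(self, points):
        self._points = list(points)

    def __len__(self):
        return len(self._points)

    def __getitem__(self, index):
        # integer indexes and slices both pass through to the list
        return self._points[index]


p = Path([(0, 0), (3, 4), (8, 1)])
print(len(p))              # 3
print(p[1])                # (3, 4)
print((8, 1) in p)         # True: __contains__ comes from Sequence
print(list(reversed(p)))   # [(8, 1), (3, 4), (0, 0)]: so does __reversed__
print(p.index((3, 4)))     # 1: and index()
```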


For making things that act like numbers I skipped over __divmod__, which overrides the divmod() built-in function. Also note that __pow__ overrides both a ** x and the pow() built-in function. In addition to the __add__ and __radd__ type methods for mathematical operators, there are also __iadd__, __isub__, __imul__, and so on. These override augmented assignments such as +=, -=, and *=. If you don't define them, Python falls back to treating 'a += b' as 'a = a + b', but this may not be the functionality you are looking for. For unary +, -, and ~ use the __pos__, __neg__, and __invert__ methods. For overriding abs() and round() use __abs__ and __round__ (__round__ is Python 3 only). For converting to numbers, use __int__, __float__, and __complex__.
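A sketch of a few of these on a stripped-down Vector. Having abs() return the magnitude and having += mutate the vector in place are my own design assumptions, not requirements; the point is how Python dispatches to these methods. Without __iadd__, 'a += ...' would rebind a to a brand-new Vector and alias would still be Vector(3, 4).

```python
class Vector(object):
    def __init__(self, x, y=0):
        self.x = x
        self.y = y

    def __repr__(self):
        return 'Vector({}, {})'.format(self.x, self.y)

    def __add__(self, other):
        if isinstance(other, Vector):
            return Vector(self.x + other.x, self.y + other.y)
        return NotImplemented

    def __iadd__(self, other):
        # in-place version: mutate self and return it
        if isinstance(other, Vector):
            self.x += other.x
            self.y += other.y
            return self
        return NotImplemented

    def __neg__(self):
        return Vector(-self.x, -self.y)              # -v

    def __pos__(self):
        return Vector(self.x, self.y)                # +v, a copy

    def __abs__(self):
        return (self.x ** 2 + self.y ** 2) ** 0.5    # abs(v), the magnitude


a = Vector(3, 4)
alias = a
a += Vector(1, 1)                # __iadd__ changes the object alias refers to
print(alias)                     # Vector(4, 5)
print(-a, abs(Vector(3, 4)))     # Vector(-4, -5) 5.0
```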

I hope this massive summary impresses upon you how much you can do with operator overloading, and how much you can mess up with operator overloading. So go forth and be careful.

And remember that it's not my fault. ;)