DIY tagged unions with exhaustiveness checking in Python

Here is something I came up with today while my mind wandered off from the book I was reading: there is a simple way to implement tagged unions in Python that allows exhaustiveness checking (at runtime) and keeps the syntax passable.

Without further ado, here is what matching looks like in a function that implements bind (aka flatMap) for a Haskell-like Either type:

def either_bind(x, fn):
    with either_handler(x) as h:
        for value in h.handle_tag('right'):
            h.keep_value(fn(value))
        for error in h.handle_tag('left'):
            h.keep_value(left(error))
        return h.get_value()

Commenting out the second handler in the function produces the following error message:

>>> either_bind(subtract_one(2), subtract_one)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "cursed.py", line 65, in either_bind
    return h.get_value()
  File "cursed.py", line 30, in __exit__
    raise Exception("forgot to handle tag %s!" % tag)
Exception: forgot to handle tag left!
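
The call above uses subtract_one, which is not defined anywhere in this post. The following is a hypothetical definition that fits the call, assuming it should fail once the result would drop below zero; Tagged, left and right are repeated from the complete listing further down so that the snippet runs on its own.

```python
from collections import namedtuple

# Same shape as the definitions in the complete listing further down.
Tagged = namedtuple('Tagged', ['tag', 'data'])


def left(error):
    return Tagged('left', error)


def right(value):
    return Tagged('right', value)


def subtract_one(n):
    # Hypothetical helper: report a failure with `left` when the result
    # would drop below zero, otherwise wrap the result in `right`.
    if n <= 0:
        return left('cannot subtract one from %d' % n)
    return right(n - 1)
```

Under this assumption subtract_one(2) is right(1), so with both handlers in place either_bind(subtract_one(2), subtract_one) would evaluate to right(0).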

Readers familiar with the subject are encouraged to figure out the implementation themselves (which I assume is possible from this snippet). The main ideas are as follows:

  1. Use the handler state to keep track of the tags that have already been seen.
  2. Implement an __exit__ method that verifies that every possible tag has been handled.
  3. Return either an empty or a singleton list of values from the handle_tag method.
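
The third idea is the least obvious one, so here is a minimal sketch of it in isolation. The names singleton_if and describe are made up for this illustration and are not part of the listing: a for loop over a list with zero or one elements behaves like an if that also binds the payload to a name.

```python
def singleton_if(condition, value):
    """Return [value] when condition holds, otherwise an empty list."""
    return [value] if condition else []


def describe(tag, data):
    # Each `for` body runs at most once, so it acts as an `if` that also
    # binds the payload to a name. handle_tag relies on the same trick.
    results = []
    for payload in singleton_if(tag == 'right', data):
        results.append('right branch got %r' % (payload,))
    for payload in singleton_if(tag == 'left', data):
        results.append('left branch got %r' % (payload,))
    return results
```

Only the loop whose tag matches contributes anything to the result; the other loop body is skipped entirely.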

Here is the complete listing, in case you would like to play with it but don't want to rewrite it from scratch:

from collections import namedtuple

Tagged = namedtuple('Tagged', ['tag', 'data'])

class TaggedHandler():
    def __init__(self, all_tags, value):
        self.all_tags = all_tags
        self.value = value
        self.seen_tags = []
        self.result = None

    def __enter__(self):
        return self

    def handle_tag(self, tag):
        if tag not in self.all_tags:
            raise Exception("tag %s is none of %s" % (tag, self.all_tags))
        if tag in self.seen_tags:
            raise Exception("trying to handle tag %s twice" % tag)
        self.seen_tags.append(tag)

        if tag == self.value.tag:
            return [self.value.data]
        else:
            return []

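    def __exit__(self, exc_type, exc_value, traceback):
        # Don't mask an exception that is already propagating from the body.
        if exc_type is not None:
            return False
        for tag in self.all_tags:
            if tag not in self.seen_tags:
                raise Exception("forgot to handle tag %s!" % tag)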

    def keep_value(self, value):
        if self.result is not None:
            raise Exception("trying to keep more than one result!")

        self.result = value

    def get_value(self):
        if self.result is None:
            raise Exception("no saved value!")

        return self.result

def left(error):
    return Tagged('left', error)

def right(value):
    return Tagged('right', value)

def either_handler(value):
    return TaggedHandler(['left', 'right'], value)

With a sprinkle of metaprogramming, this can be used to define full tagged union types quite succinctly and with less ugliness, but this is left as an exercise for the reader as well.
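
One possible direction for that exercise, as a rough sketch rather than a finished design: a factory that generates the constructors and the handler factory from a list of tags. The name tagged_union and the use of SimpleNamespace are choices made for this illustration; handler_cls is assumed to accept (all_tags, value) the way TaggedHandler does.

```python
from collections import namedtuple
from types import SimpleNamespace

# Same shape as Tagged in the listing above.
Tagged = namedtuple('Tagged', ['tag', 'data'])


def tagged_union(tags, handler_cls):
    """Build one constructor per tag plus a `handler` factory.

    handler_cls is assumed to take (all_tags, value), like TaggedHandler.
    """
    if 'handler' in tags:
        raise ValueError("'handler' is reserved for the handler factory")

    def make_constructor(tag):
        # Bind `tag` separately for every constructor; a lambda written
        # directly in a loop would see only the last tag.
        def constructor(data):
            return Tagged(tag, data)
        constructor.__name__ = tag
        return constructor

    members = {tag: make_constructor(tag) for tag in tags}
    members['handler'] = lambda value: handler_cls(list(tags), value)
    return SimpleNamespace(**members)
```

With the class from the listing, Either = tagged_union(['left', 'right'], TaggedHandler) would provide Either.left, Either.right and Either.handler in place of the hand-written left, right and either_handler.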