Visitor pattern done right

Ever since I learned to write programs in Haskell using algebraic data types, I have been yearning for a similar feature in more conventional programming languages: one that is just as convenient and type-safe to use, without sacrificing support from the tooling built around the language in question. I've recently found an approach I'm satisfied with.

The approach is, perhaps unsurprisingly, a slight extension of the Visitor pattern, which I'm going to illustrate with a naive binary search tree. A binary search tree is either an empty leaf or an internal node holding a key and two subtrees, and its encoding using the typical Visitor pattern looks like this:

interface Tree {
    interface Visitor<T> {
        T visitLeaf();
        T visitFork(Tree left, int mid, Tree right);
    }

    <T> T accept(Visitor<T> visitor);
}

class Leaf implements Tree {
    @Override
    public <T> T accept(Visitor<T> visitor) {
        return visitor.visitLeaf();
    }
}

class Fork implements Tree {
    private final Tree left, right;
    private final int mid;

    Fork(Tree left, int mid, Tree right) {
        this.left = left;
        this.mid = mid;
        this.right = right;
    }

    @Override
    public <T> T accept(Visitor<T> visitor) {
        return visitor.visitFork(left, mid, right);
    }
}
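To see what case analysis looks like with this encoding, here is a small sketch of a membership test written as one more visitor. The ContainsDemo class and its contains method are my own illustrative additions, not part of the original example; the definitions above are repeated so the file compiles on its own, and the tree is built by hand with the Leaf and Fork constructors.

```java
// Illustrative driver; contains() is a hypothetical extra operation, not from the article.
class ContainsDemo {
    // Membership test that follows the search-tree ordering,
    // so only one branch is visited per level.
    static boolean contains(Tree tree, int key) {
        return tree.accept(new Tree.Visitor<Boolean>() {
            @Override
            public Boolean visitLeaf() {
                return false; // an empty leaf holds no keys
            }

            @Override
            public Boolean visitFork(Tree left, int mid, Tree right) {
                if (key < mid) {
                    return left.accept(this);
                } else if (key > mid) {
                    return right.accept(this);
                } else {
                    return true;
                }
            }
        });
    }

    public static void main(String[] args) {
        // Key 2 at the root, 1 in the left subtree, 3 in the right subtree.
        Tree tree = new Fork(new Fork(new Leaf(), 1, new Leaf()), 2, new Fork(new Leaf(), 3, new Leaf()));
        System.out.println(contains(tree, 3)); // true
        System.out.println(contains(tree, 4)); // false
    }
}

interface Tree {
    interface Visitor<T> {
        T visitLeaf();
        T visitFork(Tree left, int mid, Tree right);
    }

    <T> T accept(Visitor<T> visitor);
}

class Leaf implements Tree {
    @Override
    public <T> T accept(Visitor<T> visitor) {
        return visitor.visitLeaf();
    }
}

class Fork implements Tree {
    private final Tree left, right;
    private final int mid;

    Fork(Tree left, int mid, Tree right) {
        this.left = left;
        this.mid = mid;
        this.right = right;
    }

    @Override
    public <T> T accept(Visitor<T> visitor) {
        return visitor.visitFork(left, mid, right);
    }
}
```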

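Obviously, for larger examples and repetitive definitions of similar structures one would want to use a code generator. Let's say the workflow with the code generator looks like this: you write only the Tree interface with its nested Visitor, and the generator produces the Leaf and Fork classes, deriving each class name from the corresponding visit method (visitLeaf gives Leaf, visitFork gives Fork).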
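A minimal sketch of what such a generator might do, written as a hypothetical reflection-based toy: ToyGenerator, its className and generate methods, and the numbered field names f0, f1, ... are all my own inventions, not any real tool. Given the hand-written Visitor interface, it emits the source of one class per visitXxx method and names that class after the method.

```java
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

// HYPOTHETICAL toy generator, not a real tool: for each visitXxx method of a
// visitor interface it emits the source of a class named Xxx.
class ToyGenerator {
    // The naming rule: visitFork -> Fork. Renaming the method renames the class.
    static String className(String methodName) {
        return methodName.substring("visit".length());
    }

    static String generate(Class<?> visitor, String treeType) {
        Method[] methods = visitor.getDeclaredMethods();
        // Reflection returns methods in no particular order; sort for stable output.
        Arrays.sort(methods, Comparator.comparing(Method::getName));

        StringBuilder out = new StringBuilder();
        for (Method m : methods) {
            // Parameter names are not reliably available via reflection, so number the fields.
            List<String> names = new ArrayList<>();
            List<String> decls = new ArrayList<>();
            Class<?>[] types = m.getParameterTypes();
            for (int i = 0; i < types.length; i++) {
                names.add("f" + i);
                decls.add(types[i].getSimpleName() + " f" + i);
            }

            String cls = className(m.getName());
            out.append("class ").append(cls).append(" implements ").append(treeType).append(" {\n");
            for (String decl : decls) {
                out.append("    private final ").append(decl).append(";\n");
            }
            out.append("\n    ").append(cls).append('(').append(String.join(", ", decls)).append(") {\n");
            for (String name : names) {
                out.append("        this.").append(name).append(" = ").append(name).append(";\n");
            }
            out.append("    }\n\n");
            out.append("    @Override\n");
            out.append("    public <T> T accept(Visitor<T> visitor) {\n");
            out.append("        return visitor.").append(m.getName())
               .append('(').append(String.join(", ", names)).append(");\n");
            out.append("    }\n}\n\n");
        }
        return out.toString();
    }

    public static void main(String[] args) {
        // Only the hand-written interface goes in; Leaf and Fork come out.
        System.out.print(generate(Tree.Visitor.class, "Tree"));
    }
}

// The only hand-written input.
interface Tree {
    interface Visitor<T> {
        T visitLeaf();
        T visitFork(Tree left, int mid, Tree right);
    }

    <T> T accept(Visitor<T> visitor);
}
```

Under this naming rule, renaming visitFork to visitNode makes the next run emit a class Node and no class Fork.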

Let's take a look at how this code is used in a simple program that implements insertion into a naive binary search tree, as well as conversion to a list:

static Tree insert(Tree tree, int key) {
    return tree.accept(new Tree.Visitor<Tree>() {
        @Override
        public Tree visitLeaf() {
            return new Fork(new Leaf(), key, new Leaf());
        }

        @Override
        public Tree visitFork(Tree left, int mid, Tree right) {
            if (key < mid) {
                return new Fork(insert(left, key), mid, right);
            } else if (key > mid) {
                return new Fork(left, mid, insert(right, key));
            } else {
                return tree;
            }
        }
    });
}

// Needs java.util.List and java.util.ArrayList imported in the enclosing file.
static List<Integer> toList(Tree tree) {
    List<Integer> result = new ArrayList<>();

    tree.accept(new Tree.Visitor<Void>() {
        @Override
        public Void visitLeaf() {
            return null;
        }

        @Override
        public Void visitFork(Tree left, int mid, Tree right) {
            left.accept(this);
            result.add(mid);
            right.accept(this);

            return null;
        }
    });

    return result;
}
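The toList above threads a shared, mutable list through a Visitor<Void>. For comparison, here is a sketch of a purely functional alternative in which each visit returns its own list. PureToListDemo is a hypothetical name of mine, and the definitions from above are repeated so the file compiles on its own.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical variant of toList: each visit returns its own list instead of
// appending to a shared accumulator through a Visitor<Void>.
class PureToListDemo {
    static List<Integer> toList(Tree tree) {
        return tree.accept(new Tree.Visitor<List<Integer>>() {
            @Override
            public List<Integer> visitLeaf() {
                return new ArrayList<>(); // a fresh, mutable empty list
            }

            @Override
            public List<Integer> visitFork(Tree left, int mid, Tree right) {
                // In-order traversal: left subtree, then the key, then the right subtree.
                List<Integer> result = left.accept(this);
                result.add(mid);
                result.addAll(right.accept(this));
                return result;
            }
        });
    }

    public static void main(String[] args) {
        Tree tree = new Fork(new Fork(new Leaf(), 1, new Leaf()), 2, new Fork(new Leaf(), 3, new Leaf()));
        System.out.println(toList(tree)); // [1, 2, 3]
    }
}

interface Tree {
    interface Visitor<T> {
        T visitLeaf();
        T visitFork(Tree left, int mid, Tree right);
    }

    <T> T accept(Visitor<T> visitor);
}

class Leaf implements Tree {
    @Override
    public <T> T accept(Visitor<T> visitor) {
        return visitor.visitLeaf();
    }
}

class Fork implements Tree {
    private final Tree left, right;
    private final int mid;

    Fork(Tree left, int mid, Tree right) {
        this.left = left;
        this.mid = mid;
        this.right = right;
    }

    @Override
    public <T> T accept(Visitor<T> visitor) {
        return visitor.visitFork(left, mid, right);
    }
}
```

The trade-off: this version has no shared state, but addAll copies the right subtree's keys at every level, so on a right-leaning chain (which ascending insertions produce in a naive tree) it takes quadratic time, while the accumulator version stays linear.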

The most important point is that this code calls the Fork and Leaf constructors directly, so if someone later wants to rename Fork to, say, Node, a simple IDE refactoring won't do the job. Renaming just visitFork to visitNode would make the code generator produce a class with the new name on its next run, leaving unresolved references to the no-longer-existing Fork class. So the Fork class would have to be renamed as well, in a separate step.

My insight here is as follows: by instantiating Visitor with Tree itself, we get an interface for a factory of Trees. This way, construction and case analysis use the same symbols defined in the code. Here's what it looks like:

interface Tree {
    interface Visitor<T> {
        T leaf();
        T fork(Tree left, int mid, Tree right);
    }

    <T> T accept(Visitor<T> visitor);
}

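enum TreeFactory implements Tree.Visitor<Tree> {
    // The single instance shares the enum's name. After a static import of this
    // constant (e.g. import static com.example.TreeFactory.TreeFactory; the package
    // name is only an example), a call such as TreeFactory.fork(...) resolves to the
    // instance rather than to the type. Elsewhere, without that import, it names the
    // type and fails to compile, because leaf and fork are instance methods.
    TreeFactory;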

    @Override
    public Tree leaf() {
        return new Tree() {
            @Override
            public <T> T accept(Visitor<T> visitor) {
                return visitor.leaf();
            }
        };
    }

    @Override
    public Tree fork(Tree left, int mid, Tree right) {
        return new Tree() {
            @Override
            public <T> T accept(Visitor<T> visitor) {
                return visitor.fork(left, mid, right);
            }
        };
    }
}
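Because TreeFactory is itself a Tree.Visitor<Tree>, it can be passed to accept like any other visitor. The sketch below (FactoryDemo, size, and the local variable factory are my own illustrative names) builds a tree through the factory, counts its keys with a visitor that uses the very same leaf and fork method names, and shows that visiting a tree with the factory rebuilds its outermost node from the same parts. Since a single-file example lives in the unnamed package, where static imports are not possible, it spells the constant out as TreeFactory.TreeFactory.

```java
// Hypothetical driver for the TreeFactory version; size() is my own example operation.
class FactoryDemo {
    static int size(Tree tree) {
        return tree.accept(new Tree.Visitor<Integer>() {
            @Override
            public Integer leaf() {
                return 0;
            }

            @Override
            public Integer fork(Tree left, int mid, Tree right) {
                return left.accept(this) + 1 + right.accept(this);
            }
        });
    }

    public static void main(String[] args) {
        // No static import in the unnamed package, so name the constant explicitly.
        Tree.Visitor<Tree> factory = TreeFactory.TreeFactory;
        Tree tree = factory.fork(factory.leaf(), 1, factory.fork(factory.leaf(), 2, factory.leaf()));
        System.out.println(size(tree)); // 2

        // The factory is a Visitor<Tree>: visiting with it rebuilds the outermost
        // node from the same parts (the subtrees are shared, not copied).
        Tree copy = tree.accept(factory);
        System.out.println(copy != tree); // true: a fresh object
        System.out.println(size(copy));   // 2
    }
}

interface Tree {
    interface Visitor<T> {
        T leaf();
        T fork(Tree left, int mid, Tree right);
    }

    <T> T accept(Visitor<T> visitor);
}

enum TreeFactory implements Tree.Visitor<Tree> {
    TreeFactory;

    @Override
    public Tree leaf() {
        return new Tree() {
            @Override
            public <T> T accept(Visitor<T> visitor) {
                return visitor.leaf();
            }
        };
    }

    @Override
    public Tree fork(Tree left, int mid, Tree right) {
        return new Tree() {
            @Override
            public <T> T accept(Visitor<T> visitor) {
                return visitor.fork(left, mid, right);
            }
        };
    }
}
```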

Notice that I've renamed visitLeaf and visitFork to leaf and fork, respectively. This gives the factory methods that construct Tree values more natural names (compare TreeFactory.leaf with TreeFactory.visitLeaf). Here's how the usage code changes:

static Tree insert(Tree tree, int key) {
    return tree.accept(new Tree.Visitor<Tree>() {
        @Override
        public Tree leaf() {
            return TreeFactory.fork(TreeFactory.leaf(), key, TreeFactory.leaf());
        }

        @Override
        public Tree fork(Tree left, int mid, Tree right) {
            if (key < mid) {
                return TreeFactory.fork(insert(left, key), mid, right);
            } else if (key > mid) {
                return TreeFactory.fork(left, mid, insert(right, key));
            } else {
                return TreeFactory.fork(left, mid, right);
            }
        }
    });
}

// Needs java.util.List and java.util.ArrayList imported in the enclosing file.
static List<Integer> toList(Tree tree) {
    List<Integer> result = new ArrayList<>();

    tree.accept(new Tree.Visitor<Void>() {
        @Override
        public Void leaf() {
            return null;
        }

        @Override
        public Void fork(Tree left, int mid, Tree right) {
            left.accept(this);
            result.add(mid);
            right.accept(this);

            return null;
        }
    });

    return result;
}
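To check the factory-based version end to end, here is a sketch that puts the usage code above into one runnable file. Because it lives in the unnamed package, where the TreeFactory constant cannot be statically imported, it goes through a hypothetical alias field FACTORY (my own name) instead of writing TreeFactory.fork directly; in a named package, a static import of the constant lets the code read exactly as above.

```java
import java.util.ArrayList;
import java.util.List;

// Runnable sketch of the factory-based usage code; FACTORY is a hypothetical alias.
class FactoryInsertDemo {
    // Stands in for a static import of the TreeFactory constant.
    private static final Tree.Visitor<Tree> FACTORY = TreeFactory.TreeFactory;

    static Tree insert(Tree tree, int key) {
        return tree.accept(new Tree.Visitor<Tree>() {
            @Override
            public Tree leaf() {
                return FACTORY.fork(FACTORY.leaf(), key, FACTORY.leaf());
            }

            @Override
            public Tree fork(Tree left, int mid, Tree right) {
                if (key < mid) {
                    return FACTORY.fork(insert(left, key), mid, right);
                } else if (key > mid) {
                    return FACTORY.fork(left, mid, insert(right, key));
                } else {
                    return FACTORY.fork(left, mid, right); // key already present
                }
            }
        });
    }

    static List<Integer> toList(Tree tree) {
        List<Integer> result = new ArrayList<>();
        tree.accept(new Tree.Visitor<Void>() {
            @Override
            public Void leaf() {
                return null;
            }

            @Override
            public Void fork(Tree left, int mid, Tree right) {
                left.accept(this); // in-order: left subtree first
                result.add(mid);
                right.accept(this);
                return null;
            }
        });
        return result;
    }

    public static void main(String[] args) {
        Tree tree = FACTORY.leaf();
        for (int key : new int[] {5, 3, 8, 3, 1}) {
            tree = insert(tree, key);
        }
        System.out.println(toList(tree)); // [1, 3, 5, 8]
    }
}

interface Tree {
    interface Visitor<T> {
        T leaf();
        T fork(Tree left, int mid, Tree right);
    }

    <T> T accept(Visitor<T> visitor);
}

enum TreeFactory implements Tree.Visitor<Tree> {
    TreeFactory;

    @Override
    public Tree leaf() {
        return new Tree() {
            @Override
            public <T> T accept(Visitor<T> visitor) {
                return visitor.leaf();
            }
        };
    }

    @Override
    public Tree fork(Tree left, int mid, Tree right) {
        return new Tree() {
            @Override
            public <T> T accept(Visitor<T> visitor) {
                return visitor.fork(left, mid, right);
            }
        };
    }
}
```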

The punchline here is that the same method names (leaf and fork) are used both to construct values of type Tree and to perform case analysis on them. To perform a renaming, one only has to rename a single symbol. I understand that this might not look like a big deal to you, but it is to me: I can rely on the IDE to perform correct program transformations without thinking about how the project's code generator is implemented (less thinking = fewer opportunities to fuck something up by accident).