⁂ George Ho

Adventures in Manipulating Python ASTs

A while back, I explored the possibility of simplifying[1] PyMC4’s model specification API by manipulating the Python abstract syntax tree (AST) of the model code. The PyMC developers didn’t end up pursuing those API changes any further, but not before I had the chance to learn a lot about Python ASTs.

Enough curious people have asked me about my experience tinkering with ASTs that I figured I’d write a short post about the details of my project, in the hope that someone else will find it useful.

You should read this blog post as a quick overview of my experience with Python ASTs, or an annotated list of links, and not a comprehensive tutorial on model specification APIs or Python ASTs. For a full paper trail of my adventures with Python ASTs, check out my notebooks on GitHub.

The Problem

Originally, PyMC4’s proposed model specification API looked something like this:

def linear_regression(x):
    scale = yield tfd.HalfCauchy(0, 1)
    coefs = yield tfd.Normal(tf.zeros(x.shape[1]), 1)
    predictions = yield tfd.Normal(tf.linalg.matvec(x, coefs), scale)
    return predictions

The main drawback to this API was that the yield keyword was confusing. Many users don’t really understand Python generators, and those who do might only understand yield as a drop-in replacement for return (that is, they might understand what it means for a function to end in yield foo, but would be uncomfortable with bar = yield foo).
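A tiny, purely illustrative generator (not PyMC code) shows what `bar = yield foo` actually does: `yield` hands a value out to whoever is driving the generator, and the assignment then receives whatever that driver sends back in.

```python
def model():
    # `yield "scale prior"` hands a value OUT to whoever drives the generator;
    # the whole expression then evaluates to whatever is sent back IN.
    scale = yield "scale prior"
    return scale * 2


gen = model()        # calling a generator function runs none of its body yet
print(next(gen))     # runs up to the first yield and prints: scale prior
try:
    gen.send(3)      # resumes the body, binding scale = 3
except StopIteration as stop:
    print(stop.value)  # the return value, 6, arrives via StopIteration
```
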

Furthermore, the yield keyword introduces a leaky abstraction[2]: users don’t care about whether model is a function or a generator, and they shouldn’t need to. More generally, users shouldn’t have to know anything about how PyMC works in order to use it: ideally, the only thing users would need to think about would be their data and their model. Having to graft several yield keywords into their code is a fairly big intrusion in that respect.

Finally, this model specification API is essentially moving the problem off of our plates and onto our users. The entire point of the PyMC project is to provide a friendly and easy-to-use interface for Bayesian modelling.

To state the problem more concretely, we wanted to:

  1. Hide the yield keyword from the user-facing model specification API.
  2. Obtain the user-defined model as a generator.

The main difficulty with the first goal is that as soon as we remove yield from the model function, it is no longer a generator. However, the PyMC inference engine needs the model as a generator, since this allows us to interrupt the control flow of the model at various points to do certain things:

In short, the user writes their model as a function, but we require the model as a generator.
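To make "interrupting the control flow" a little more concrete, here is a toy driver. It is a sketch only: the `Normal` class and `run_model` function are hypothetical stand-ins, not PyMC or TensorFlow Probability APIs. It also shows the core difficulty: once `yield` is removed, Python no longer treats the function as a generator function at all.

```python
import inspect
import random


class Normal:
    """Hypothetical stand-in for a distribution object."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)


def model_as_function():
    x = Normal(0, 1)
    return x


def model_as_generator():
    x = yield Normal(0, 1)
    y = yield Normal(x, 1)
    return y


def run_model(model, seed=0):
    """Drive a generator model, sampling a value for every yielded distribution."""
    rng = random.Random(seed)
    gen = model()
    trace = []
    try:
        dist = next(gen)
        while True:
            # The model is paused here: a real engine could inspect `dist`,
            # record it, or substitute an observed value before resuming.
            value = dist.sample(rng)
            trace.append(value)
            dist = gen.send(value)
    except StopIteration as stop:
        return stop.value, trace


print(inspect.isgeneratorfunction(model_as_function))   # False
print(inspect.isgeneratorfunction(model_as_generator))  # True
result, trace = run_model(model_as_generator)
print(len(trace), result == trace[-1])  # 2 True
```
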

I discuss why this problem is challenging in much more detail here.

The Solution

First, I wrote a FunctionToGenerator class:

import ast


class FunctionToGenerator(ast.NodeTransformer):
    """
    This subclass traverses the AST of the user-written, decorated,
    model specification and transforms it into a generator for the
    model. Subclassing in this way is the idiomatic way to transform
    an AST.
    Specifically:

    1. Add `yield` keywords to all assignments
       E.g. `x = tfd.Normal(0, 1)` -> `x = yield tfd.Normal(0, 1)`
    2. Rename the model specification function to
       `_pm_compiled_model_generator`. This is done out of an abundance
       of caution more than anything.
    3. Remove the @Model decorator. Otherwise, we risk running into
       an infinite recursion.
    """

    def visit_Assign(self, node):
        # Wrap the right-hand side in a Yield node:
        # `x = value` becomes `x = yield value`.
        node.value = ast.Yield(value=node.value)

        # Tie up loose ends in the AST: the new Yield node has no line
        # numbers, so it inherits them from the enclosing Assign node.
        ast.fix_missing_locations(node)
        self.generic_visit(node)
        return node

    def visit_FunctionDef(self, node):
        # Rename the function and strip all of its decorators (including
        # @Model), so executing the new code cannot re-trigger Model.
        node.name = "_pm_compiled_model_generator"
        node.decorator_list = []

        # Continue into the function body so visit_Assign runs on it.
        self.generic_visit(node)
        return node

Subclassing ast.NodeTransformer (as FunctionToGenerator does) is the recommended way of modifying ASTs. The functionality of FunctionToGenerator is pretty well described by its docstring: the visit_Assign method adds the yield keyword to every assignment by wrapping the value of each visited Assign node in a Yield node. The visit_FunctionDef method removes the decorator and renames the function to _pm_compiled_model_generator. All told, once the NodeTransformer is done with the AST, we have one function, _pm_compiled_model_generator, which is a modified version of the user-defined function.
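Since the uncompile and parse_snippet helpers are not shown here, one way to experiment with the same idea is to start from a source string instead: parse it with ast.parse, run a NodeTransformer over it, and compile the result. The sketch below uses a hypothetical, stripped-down transformer (`AddYieldToAssignments`) that only handles assignments; it is not the FunctionToGenerator class itself.

```python
import ast
import inspect

SOURCE = '''
def model():
    a = 1
    b = a + 1
    return b
'''


class AddYieldToAssignments(ast.NodeTransformer):
    """Hypothetical, stripped-down transformer: `x = value` -> `x = yield value`."""

    def visit_Assign(self, node):
        self.generic_visit(node)
        node.value = ast.Yield(value=node.value)
        return node


tree = AddYieldToAssignments().visit(ast.parse(SOURCE))
ast.fix_missing_locations(tree)  # give the new Yield nodes line numbers

namespace = {}
exec(compile(tree, filename="<transformed>", mode="exec"), namespace)
model = namespace["model"]

print(inspect.isgeneratorfunction(model))  # True
gen = model()
print(next(gen))     # 1: the first assignment now pauses and yields its value
print(gen.send(10))  # 11: `a` was bound to the sent value 10, so a + 1 == 11
```

The transformed function is a generator even though its source never mentions yield, which is exactly the property the Model decorator relies on.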

Second, the Model class:

import functools

# `uncompile`, `parse_snippet` and `recompile` are helper functions that
# are not shown in this post.


class Model:
    """ pm.Model decorator. """

    def __init__(self, func):
        self.func = func

        # Introspect wrapped function, instead of the decorator class.
        functools.update_wrapper(self, func)

        # Uncompile wrapped function.
        uncompiled = uncompile(func.__code__)

        # Parse AST and modify it.
        tree = parse_snippet(*uncompiled)
        tree = FunctionToGenerator().visit(tree)
        uncompiled[0] = tree

        # Recompile wrapped function.
        self.recompiled = recompile(*uncompiled)

        # Execute the recompiled code (which defines
        # `_pm_compiled_model_generator`) into an explicit namespace dict.
        # An explicit dict is used rather than locals(): since Python 3.13,
        # each locals() call in a function returns a fresh snapshot, so a
        # name exec'd into one call's result cannot be read from another.
        # The wrapped function's globals let names such as `tf` and `tfd`
        # resolve as they would in the user's module.
        # Refer to http://lucumr.pocoo.org/2011/2/1/exec-in-python/
        namespace = {}
        exec(self.recompiled, func.__globals__, namespace)
        self.model_generator = namespace["_pm_compiled_model_generator"]

This class isn’t meant to be instantiated directly: rather, it’s meant to be used as a Python decorator. Essentially, it “uncompiles” the function to get the Python source code of the function. This source code is then passed to the parse_snippet function[3], which returns the AST for the function. We then modify this AST with the FunctionToGenerator class that we defined above. Finally, we recompile this AST and execute it. Executing the recompiled code defines a new function called _pm_compiled_model_generator in the namespace dictionary passed to exec[4]. That function is then looked up in the dictionary and bound to self.model_generator.
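The exec step is easier to reason about in isolation. A dictionary passed to exec as its namespace collects every name the executed code defines, which is how a freshly compiled function can be fetched by name afterwards. The helper `build_function` below is hypothetical and only illustrates that pattern.

```python
import math


def build_function(source, name, global_namespace):
    """Hypothetical helper: compile `source`, exec it, return the function `name`."""
    code = compile(source, filename="<generated>", mode="exec")
    local_namespace = {}
    # Top-level definitions in `source` are bound in `local_namespace`;
    # the function's __globals__ is `global_namespace`, where free names
    # such as `math` are looked up when the function is called.
    exec(code, global_namespace, local_namespace)
    return local_namespace[name]


SOURCE = "def area(r):\n    return math.pi * r ** 2\n"
area = build_function(SOURCE, "area", {"math": math})
print(area(1.0))  # 3.141592653589793
```

One caveat of passing separate globals and locals dictionaries: names defined at the top level of the executed code land in the locals dictionary, so functions defined there cannot find each other through their globals. For a single generated function, as in the Model decorator, that is not a problem.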

Finally, the user-facing API looks like this:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


@Model
def linear_regression(x):
    scale = tfd.HalfCauchy(0, 1)
    coefs = tfd.Normal(tf.zeros(x.shape[1]), 1)
    predictions = tfd.Normal(tf.linalg.matvec(x, coefs), scale)
    return predictions


# The generator body does not run until it is advanced, so the input
# shape is irrelevant here.
linear_regression.model_generator(tf.zeros([3, 10]))
# <generator object _pm_compiled_model_generator at 0x107a5c5c8>

As you can see, users need not write yield while specifying their models, and the PyMC inference engine can now simply call linear_regression.model_generator to produce a generator from _pm_compiled_model_generator, as desired. Success!

Lessons Learnt

Again, PyMC4’s model specification API will not be incorporating these changes: the PyMC developers have since decided that the yield keyword is the most elegant (but not necessarily the easiest) way for users to specify statistical models. This post is just meant to summarize the lessons learnt while pursuing this line of inquiry.

Reading and parsing the AST is perfectly safe: that’s basically just a form of code introspection, which is a perfectly valid thing to do! It’s when you want to modify or even rewrite the AST that things start getting janky, and even dangerous (especially if you want to execute the modified AST instead of the written code, as I was trying to do!).
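As an example of the safe, read-only kind of AST work, a NodeVisitor can report which names are assigned from distribution calls without executing or changing anything, so TensorFlow does not even need to be installed. This sketch assumes distributions are always referenced through the name `tfd`; the `DistributionAssignments` class is hypothetical.

```python
import ast

SOURCE = '''
def linear_regression(x):
    scale = tfd.HalfCauchy(0, 1)
    coefs = tfd.Normal(tf.zeros(x.shape[1]), 1)
    predictions = tfd.Normal(tf.linalg.matvec(x, coefs), scale)
    return predictions
'''


class DistributionAssignments(ast.NodeVisitor):
    """Record `name = tfd.Something(...)` assignments without changing anything."""

    def __init__(self):
        self.found = []

    def visit_Assign(self, node):
        value = node.value
        if (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Attribute)
            and isinstance(value.func.value, ast.Name)
            and value.func.value.id == "tfd"
        ):
            for target in node.targets:
                if isinstance(target, ast.Name):
                    self.found.append((target.id, value.func.attr))
        self.generic_visit(node)


visitor = DistributionAssignments()
visitor.visit(ast.parse(SOURCE))
print(visitor.found)
# [('scale', 'HalfCauchy'), ('coefs', 'Normal'), ('predictions', 'Normal')]
```
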

If you want to programmatically modify the AST (e.g. “insert a yield keyword in front of every assignment of a TensorFlow Distribution”, as in our case), stop and consider whether you’re attempting to modify the semantics of the written code, and whether you’re sure that that’s a good idea (e.g. the yield keywords in the code mean something, and removing those keywords changes the apparent semantics of the code).
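The qualifier "of a TensorFlow Distribution" matters: FunctionToGenerator wraps every assignment, so an ordinary line such as `size = n * 2` would also become a yield and quietly change what the model does. A narrower transformer can inspect the right-hand side first. This is a sketch assuming distributions are always referenced through the name `tfd`; `is_tfd_call`, `YieldDistributionsOnly` and the fake `tfd` object are hypothetical.

```python
import ast
import types

SOURCE = '''
def model(n):
    size = n * 2
    x = tfd.Normal(0, 1)
    return size, x
'''


def is_tfd_call(node):
    """True for expressions of the form `tfd.Something(...)`."""
    return (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Attribute)
        and isinstance(node.func.value, ast.Name)
        and node.func.value.id == "tfd"
    )


class YieldDistributionsOnly(ast.NodeTransformer):
    """Wrap only `name = tfd.Something(...)` assignments in a Yield node."""

    def visit_Assign(self, node):
        self.generic_visit(node)
        if is_tfd_call(node.value):
            node.value = ast.Yield(value=node.value)
        return node


tree = ast.fix_missing_locations(YieldDistributionsOnly().visit(ast.parse(SOURCE)))

# Hypothetical stand-in for `tfd`: Normal just returns a descriptive tuple.
fake_tfd = types.SimpleNamespace(Normal=lambda loc, scale: ("Normal", loc, scale))
namespace = {"tfd": fake_tfd}
exec(compile(tree, filename="<transformed>", mode="exec"), namespace)

gen = namespace["model"](3)
print(next(gen))  # ('Normal', 0, 1): only the distribution assignment pauses
try:
    gen.send(0.5)
except StopIteration as stop:
    print(stop.value)  # (6, 0.5): `size = n * 2` ran as ordinary code
```

Even this version is a guess about intent: a user could bind a distribution through some other name, and the transformer would silently skip it.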

Further Reading

I’ve only given a high-level overview of this project here, and a lot of the technical details were glossed over. If you’re hungry for more, check out the following resources:


  1. Or should I say, complicating? At any rate, changing!

  2. I was subsequently convinced that this isn’t a leaky abstraction after all.

  3. I omitted the implementation of parse_snippet for brevity. If you want to see it, check out the “AST Helper Functions” section of this notebook.

  4. For way more information on exec, eval, locals and globals, check out Armin Ronacher’s blog post and this StackOverflow answer.

#Pymc #Python