
Not sure if this question belongs here, but it's closely related to gradient methods in optimization, which seems to be on-topic here. Anyway, feel free to migrate if you think some other community has better expertise in the topic.

In short, I'm looking for a step-by-step example of reverse-mode automatic differentiation. There's not that much literature on the topic out there, and existing implementations (like the one in TensorFlow) are hard to understand without knowing the theory behind them. Thus I'd be very thankful if somebody could show in detail what we pass in, how we process it, and what we take out of the computational graph.

A couple of questions that I have most difficulties with:

  • seeds - why do we need them at all?
  • reverse differentiation rules - I know how to do forward differentiation, but how do we go backward? E.g. in the example from this section, how do we know that $\bar{w_2}=\bar{w_3}w_1$?
  • do we work with symbols only or pass through actual values? E.g. in the same example, are $w_i$ and $\bar{w_i}$ symbols or values?
ffriend

1 Answer


Let's say we have the expression $z = x_1x_2 + \sin(x_1)$ and want to find the derivatives $\frac{dz}{dx_1}$ and $\frac{dz}{dx_2}$. Reverse-mode AD splits this task into two parts, namely a forward and a reverse pass.

Forward pass

First, we decompose our complex expression into a set of primitive ones, i.e. expressions consisting of at most a single function call. Note that I also rename the input and output variables for consistency, though it's not necessary:

$$w_1 = x_1$$ $$w_2 = x_2$$ $$w_3 = w_1w_2$$ $$w_4 = \sin(w_1)$$ $$w_5 = w_3 + w_4$$ $$z = w_5$$

The advantage of this representation is that the differentiation rules for each separate expression are already known. For example, we know that the derivative of $\sin$ is $\cos$, and so $\frac{dw_4}{dw_1} = \cos(w_1)$. We will use this fact in the reverse pass below.
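
To make this concrete, the local rules needed for this particular example can be written down as a tiny table. A minimal Python sketch of my own (not any library's API):

```python
import math

# Local derivative rules for the primitives used in this example.
# Each rule returns the derivatives of the output w.r.t. each of its inputs.
local_rules = {
    'mul': lambda a, b: (b, a),           # d(a*b)/da = b,   d(a*b)/db = a
    'sin': lambda a: (math.cos(a),),      # d(sin(a))/da = cos(a)
    'add': lambda a, b: (1.0, 1.0),       # d(a+b)/da = 1,   d(a+b)/db = 1
}
```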

Essentially, the forward pass consists of evaluating each of these expressions and saving the results. Say our inputs are $x_1 = 2$ and $x_2 = 3$. Then we have:

$$w_1 = x_1 = 2$$ $$w_2 = x_2 = 3$$ $$w_3 = w_1w_2 = 6$$ $$w_4 = \sin(w_1) \approx 0.9$$ $$w_5 = w_3 + w_4 \approx 6.9$$ $$z = w_5 \approx 6.9$$
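
For illustration, the forward pass transcribes almost literally into plain Python (just a hand-written sketch mirroring the $w_i$ above, not the output of any AD tool):

```python
import math

# Forward pass: evaluate each primitive expression and save its value.
x1, x2 = 2.0, 3.0

w1 = x1               # w1 = x1        -> 2.0
w2 = x2               # w2 = x2        -> 3.0
w3 = w1 * w2          # w3 = w1 * w2   -> 6.0
w4 = math.sin(w1)     # w4 = sin(w1)   -> ~0.909
w5 = w3 + w4          # w5 = w3 + w4   -> ~6.909
z = w5                # z  = w5        -> ~6.909

print(z)              # ~6.909
```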

Reverse pass

This is where the magic starts, and it starts with the chain rule. In its basic form, the chain rule states that if you have a variable $t(u(v))$ which depends on $u$, which in turn depends on $v$, then:

$$\frac{dt}{dv} = \frac{dt}{du}\frac{du}{dv}$$

or, if $t$ depends on $v$ via several paths / variables $u_i$, e.g.:

$$u_1 = f(v)$$ $$u_2 = g(v)$$ $$t = h(u_1, u_2)$$

then (see a proof [here](https://math.hmc.edu/calculus/hmc-mathematics-calculus-online-tutorials/multivariable-calculus/multi-variable-chain-rule/)):

$$\frac{dt}{dv} = \sum_i \frac{dt}{du_i}\frac{du_i}{dv}$$
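
As a quick numerical sanity check of this sum rule (with $f$, $g$ and $h$ chosen arbitrarily for illustration, not taken from the running example):

```python
import math

# Arbitrary illustration: u1 = f(v) = sin(v), u2 = g(v) = v**2, t = h(u1, u2) = u1 * u2.
def t_of(v):
    u1, u2 = math.sin(v), v ** 2
    return u1 * u2

v = 1.3
u1, u2 = math.sin(v), v ** 2

# Chain rule: dt/dv = dt/du1 * du1/dv + dt/du2 * du2/dv = u2*cos(v) + u1*2v
chain_rule = u2 * math.cos(v) + u1 * 2 * v

# Central finite difference for comparison.
eps = 1e-6
finite_diff = (t_of(v + eps) - t_of(v - eps)) / (2 * eps)

print(chain_rule, finite_diff)   # both ~2.957
```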

In terms of the expression graph, if we have a final node $z$ and input nodes $w_i$, and the path from $z$ to $w_i$ goes through intermediate nodes $w_p$ (i.e. $z = g(w_p)$ where $w_p = f(w_i)$), we can find the derivative $\frac{dz}{dw_i}$ as

$$\frac{dz}{dw_i} = \sum_{p \in parents(i)} \frac{dz}{dw_p} \frac{dw_p}{dw_i}$$

In other words, to calculate the derivative of the output variable $z$ w.r.t. any intermediate or input variable $w_i$, we only need to know the derivatives of its parents and the formula for the derivative of the primitive expression $w_p = f(w_i)$. In the bar notation used in the question, $\bar{w_i}$ is simply shorthand for $\frac{dz}{dw_i}$, so this rule reads $\bar{w_i} = \sum_p \bar{w_p} \frac{dw_p}{dw_i}$, which is where identities like $\bar{w_2} = \bar{w_3} w_1$ come from.
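
In code, this rule becomes a single backward sweep over the graph in reverse topological order, accumulating $\frac{dz}{dw_i}$ (the adjoint of each node). A minimal runnable sketch for our example graph; the dictionary layout here is purely for illustration, not how any real framework stores its graph:

```python
import math

# A tiny, hand-built graph for the running example. Each node lists the nodes
# it is computed from and the local derivative w.r.t. each of them, using the
# values saved during the forward pass (w1 = 2, w2 = 3).
w1, w2 = 2.0, 3.0
graph = {
    'w3': [('w1', w2), ('w2', w1)],        # w3 = w1*w2: dw3/dw1 = w2, dw3/dw2 = w1
    'w4': [('w1', math.cos(w1))],          # w4 = sin(w1): dw4/dw1 = cos(w1)
    'w5': [('w3', 1.0), ('w4', 1.0)],      # w5 = w3 + w4
    'z':  [('w5', 1.0)],                   # z = w5
}

# Reverse sweep: visit nodes from the output back toward the inputs, accumulating
# adjoint[i] = dz/dw_i = sum over parents p of dz/dw_p * dw_p/dw_i.
adjoint = {'z': 1.0}                       # the seed: dz/dz = 1
for p in ['z', 'w5', 'w4', 'w3']:          # reverse topological order
    for i, local_deriv in graph[p]:
        adjoint[i] = adjoint.get(i, 0.0) + adjoint[p] * local_deriv

print(adjoint['w1'])   # ~2.584 = w2 + cos(w1)
print(adjoint['w2'])   # 2.0    = w1
```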

The reverse pass starts at the end (i.e. $\frac{dz}{dz}$) and propagates backward to all dependencies. Here we have the expression for the "seed":

$$\frac{dz}{dz} = 1$$

That may be read as "a change in $z$ results in exactly the same change in $z$", which is quite obvious.

Then we know that $z = w_5$ and so:

$$\frac{dz}{dw_5} = 1$$

$w_5$ linearly depends on $w_3$ and $w_4$, so $\frac{dw_5}{dw_3} = 1$ and $\frac{dw_5}{dw_4} = 1$. Using the chain rule we find:

$$\frac{dz}{dw_3} = \frac{dz}{dw_5} \frac{dw_5}{dw_3} = 1 \times 1 = 1$$ $$\frac{dz}{dw_4} = \frac{dz}{dw_5} \frac{dw_5}{dw_4} = 1 \times 1 = 1$$

From the definition $w_3 = w_1w_2$ and the rules of partial derivatives, we find that $\frac{dw_3}{dw_2} = w_1$. Thus:

$$\frac{dz}{dw_2} = \frac{dz}{dw_3} \frac{dw_3}{dw_2} = 1 \times w_1 = w_1$$

Which, as we already know from the forward pass, is:

$$\frac{dz}{dw_2} = w_1 = 2$$

Finally, $w_1$ contributes to $z$ via $w_3$ and $w_4$. Once again, from the rules of partial derivatives we know that $\frac{dw_3}{dw_1} = w_2$ and $\frac{dw_4}{dw_1} = \cos(w_1)$. Thus:

$$\frac{dz}{dw_1} = \frac{dz}{dw_3} \frac{dw_3}{dw_1} + \frac{dz}{dw_4} \frac{dw_4}{dw_1} = w_2 + \cos(w_1)$$

And again, given known inputs, we can calculate it:

$$\frac{dz}{dw_1} = w_2 + \cos(w_1) = 3 + \cos(2) \approx 2.58$$

Since $w_1$ and $w_2$ are just aliases for $x_1$ and $x_2$, we get our answer:

$$\frac{dz}{dx_1} \approx 2.58$$ $$\frac{dz}{dx_2} = 2$$

And that's it!
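
For comparison, a real reverse-mode engine such as PyTorch's autograd (mentioned in the comments below) reproduces exactly these numbers; here used only as an external check of the hand computation above:

```python
import torch

# Inputs marked as requiring gradients, same values as in the forward pass.
x1 = torch.tensor(2.0, requires_grad=True)
x2 = torch.tensor(3.0, requires_grad=True)

z = x1 * x2 + torch.sin(x1)   # forward pass, graph is recorded automatically
z.backward()                  # reverse pass, seed dz/dz = 1

print(x1.grad)                # ~2.584 = x2 + cos(x1)
print(x2.grad)                # 2.0    = x1
```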


This description concerns only scalar inputs, i.e. numbers, but in fact it can also be applied to multidimensional arrays such as vectors and matrices. There are two things to keep in mind when differentiating expressions with such objects:

  1. Derivatives may have much higher dimensionality than the inputs or the output, e.g. the derivative of a vector w.r.t. a vector is a matrix, and the derivative of a matrix w.r.t. a matrix is a 4-dimensional array (sometimes referred to as a tensor); see the small example after this list. In many cases such derivatives are very sparse.
  2. Each component of the output array is a function of one or more components of the input array(s), independent of the other output components. E.g. if $y = f(x)$ and both $x$ and $y$ are vectors, $y_i$ never depends on $y_j$, but only on a subset of the $x_k$. In particular, this means that finding the derivative $\frac{dy_i}{dx_j}$ boils down to tracking how $y_i$ depends on $x_j$.
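
A small illustrative example for point 1 (my own, not part of the running example): if $y = f(x)$ with $y_1 = x_1 + x_2$ and $y_2 = x_1x_2$, then the derivative of the vector $y$ w.r.t. the vector $x$ is the $2 \times 2$ Jacobian matrix

$$\frac{dy}{dx} = \begin{pmatrix} \frac{\partial y_1}{\partial x_1} & \frac{\partial y_1}{\partial x_2} \\ \frac{\partial y_2}{\partial x_1} & \frac{\partial y_2}{\partial x_2} \end{pmatrix} = \begin{pmatrix} 1 & 1 \\ x_2 & x_1 \end{pmatrix}$$

and each entry is obtained by exactly the scalar procedure described above.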

The power of automatic differentiation is that it can deal with complicated structures from programming languages, like conditionals and loops. However, if all you need is algebraic expressions and you have a good enough framework for working with symbolic representations, it's possible to construct fully symbolic expressions. In fact, in this example we could produce the expression $\frac{dz}{dw_1} = w_2 + \cos(w_1) = x_2 + \cos(x_1)$ and calculate this derivative for whatever inputs we want.
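
For instance, with a symbolic math library such as SymPy (my choice here purely for illustration, any CAS would do), the fully symbolic derivative of this example falls out directly:

```python
import sympy as sp

x1, x2 = sp.symbols('x1 x2')
z = x1 * x2 + sp.sin(x1)

dz_dx1 = sp.diff(z, x1)    # x2 + cos(x1)
dz_dx2 = sp.diff(z, x2)    # x1

# The symbolic expressions can then be evaluated for any inputs we like:
print(dz_dx1.subs({x1: 2, x2: 3}).evalf())   # ~2.584
print(dz_dx2.subs({x1: 2, x2: 3}))           # 2
```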

ffriend
  • Very useful question/answer. Thanks. Just a little criticism: you seem to move on a tree structure without explaining it (that's when you start talking about parents, etc.) – MadHatter May 28 '17 at 12:48
    Also it won't hurt clarifying why we need seeds. – MadHatter May 28 '17 at 13:11
  • @MadHatter thanks for the comment. I tried to rephrase a couple of paragraphs (these that refer to parents) to emphasize a graph structure. I also added "seed" to the text, although this name itself may be misleading in my opinion: in AD seed is always a fixed expression - $\frac{dz}{dz} = 1$, not something you can choose or generate. – ffriend May 28 '17 at 21:56
  • Thanks! I noticed when you have to set more than one "seed", generally one chooses 1 and 0. I'd like to know why. I mean, one takes the "quotient" of a differential w.r.t. itself, so "1" is at least intuitively justified.. But what about 0? And what if one has to pick more than 2 seeds? – MadHatter May 30 '17 at 08:26
    As far as I understand, more than one seed is used only in forward-mode AD. In this case you set the seed to 1 for an input variable you want to differentiate with respect to and set the seed to 0 for all the other input variables so that they don't contribute to the output value. In reverse-mode you set the seed to an _output_ variable, and you normally have only one output variable. I guess, you can construct reverse-mode AD pipeline with several output variables and set all of them but one to 0 to get the same effect as in forward mode, but I have never investigated this option. – ffriend May 30 '17 at 08:53
  • Very good, that clarifies a lot. Thanks again, there is so few people who understand the subtleties of an autograd, it's not easy to get answers.. – MadHatter May 31 '17 at 10:43
  • @Imabot thanks for your comment (via a suggested edit), I updated the text to clarify. That specific part of description is not about different nodes in a single chain (where you indeed use multiplication), but about several parallel chains. E.g. in the very first example $w_5$ depends on $w_1$ via both - $w_3$ and $w_4$. In this case you need to calculate derivatives for 2 chains: $w_1 \rightarrow w_3 \rightarrow w_5$ and $w_1 \rightarrow w_4 \rightarrow w_5$, and then simply add them. – ffriend Dec 08 '17 at 11:03
    The proof has moved and is now [here](https://math.hmc.edu/calculus/hmc-mathematics-calculus-online-tutorials/multivariable-calculus/multi-variable-chain-rule/) – Metaxal Apr 12 '20 at 12:26
  • This is a great description! I wonder if you could elaborate on the line you started with “ However, if all you need is algebraic expressions...” I don’t quite see how this is different from high school calculus (symbolic differentiation). Are we just splitting apart an expression into simpler expressions (‘A’ normal form), then applying basic rules of calculus? – Shahbaz Jul 05 '20 at 22:54
  • @Shahbaz Yes, basically in this section I refer to symbolic differentiation. Note though that symbolic derivatives become more complicated when you start working with multidimensional data. There are some approaches to deal with it (e.g. [I tried](https://github.com/dfdx/XDiff.jl) one based on Einstein indexing notation), but it's more of research topic than day-to-day approach. In implementations, there are tons of details and options (e.g. operator overloading, tracing, source-to-source), but basic idea of reverse-mode AD stays mostly the same. – ffriend Jul 08 '20 at 21:38
  • Hi, thanks for this amazing explanation. Just wanted to ask you, if this is within the scope of this question, how does this work in a programming language? I can understand this because I know derivatives, like the derivative of sin is cos. But in a code, we can't do this for every function. So how does it work in a code? – Sarvagya Gupta Dec 02 '20 at 20:00
    @SarvagyaGupta Usually you have a set of "primitive" functions that you define derivatives for, and then complex functions composed of these primitives. As an example of practical autodiff engine, take a look at [PyTorch version of AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html). – ffriend Dec 02 '20 at 23:07