Gradient Descent

Fundamentally, much of machine learning boils down to trying to get an output value to match what we expect. Behind a model is just a series of mathematical operations that produce a value (or set of values). The way we get the output values to match our expectations is through gradient descent.

Gradient descent describes how we get as close as possible to the expected output value. We measure how each input affects the output at each stage of the calculation, and adjust the inputs (weights) in small steps so that the output value gets closer to our goal. That measurement is a partial derivative: how much the output changes when one input is nudged and the others are held fixed.
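The adjustment loop described above can be sketched in a few lines of Python. Everything named here is an assumption made for illustration: the one-weight "model" `w * 2.0`, the target `6.0`, the squared-error `loss`, the learning rate, and the step count. The gradient is written out by hand so the update rule itself stays in focus.

```python
def gradient_descent(grad, w0, learning_rate=0.1, steps=100):
    """Repeatedly step against the gradient, starting from w0."""
    w = w0
    for _ in range(steps):
        w = w - learning_rate * grad(w)
    return w


# Hypothetical setup: the model output is w * 2.0 and the value we expect is 6.0.
def loss(w):
    return (w * 2.0 - 6.0) ** 2


def loss_grad(w):
    # d/dw (2w - 6)^2 = 2 * (2w - 6) * 2
    return 4.0 * (2.0 * w - 6.0)


w_final = gradient_descent(loss_grad, w0=0.0)
print(w_final, loss(w_final))
```

With these assumed numbers the weight settles near 3, the value at which the output matches the target.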

The formula to find the gradient through central difference is: $f'(x) \approx \frac{f(x + h) - f(x - h)}{2h}$, writing $h$ for a small step size. You may notice the similarity to the definition of the derivative, just using a finite value instead of an infinitesimally small one. A notable difference, however, is that we sample to the right and left of our input, and find the change over a distance of $2h$ instead of $h$. This reduces the error of our approximation: for a smooth function, the error shrinks in proportion to $h^2$ rather than $h$. The proof that this works is left as an exercise to the reader.
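Here is a minimal sketch of central difference in Python, compared against a one-sided (forward) difference. The function names, the step size `h=1e-4`, and the choice of `math.sin` as the test function are assumptions for illustration.

```python
import math


def central_difference(f, x, h=1e-4):
    """Approximate f'(x) by sampling h to the left and right of x."""
    return (f(x + h) - f(x - h)) / (2 * h)


def forward_difference(f, x, h=1e-4):
    """One-sided approximation, shown only for comparison."""
    return (f(x + h) - f(x)) / h


# The exact derivative of sin at x = 1.0 is cos(1.0).
exact = math.cos(1.0)
central_error = abs(central_difference(math.sin, 1.0) - exact)
forward_error = abs(forward_difference(math.sin, 1.0) - exact)
print(central_error, forward_error)
```

For the same step size, the two-sided estimate lands much closer to the exact value than the one-sided one. Neither is exact.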

This works well enough for some applications, but because it is only an approximation, the errors can compound significantly down the line. Fortunately, we can do a lot better! Once again, if you’ve taken Calc 1, you would have been introduced to the chain rule. Otherwise, here is a quick rundown: to find the derivative of a composite function, $f(g(x))$, we take the derivative of the outer function evaluated at the inner function, and multiply by the derivative of the inner function: $f'(g(x)) \cdot g'(x)$.

We can use the chain rule in conjunction with the known derivatives of commonly used functions to determine the gradient of any function built out of them, which is perfect for us! That’s where automatic differentiation (autodiff) comes in. Automatic differentiation is a better way of finding the gradient with respect to our inputs. Conceptually, we keep track of how each input is used in the function, and repeatedly apply the chain rule to determine its gradient.

Here is a concrete example: given the function , if we want to find the gradient with respect to, we first note that the derivative of a function is , and the derivative of a function is . Then we can rewrite as , and apply the chain rule:
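To illustrate these steps, here is the chain rule applied to one assumed function, $y = \sin(x^2)$. It is chosen purely for illustration and is not necessarily the example above. It relies on two known derivatives, that of $\sin(u)$ and that of $x^2$.

```latex
\begin{aligned}
y &= f(g(x)), \quad f(u) = \sin(u), \quad g(x) = x^2 \\
f'(u) &= \cos(u), \quad g'(x) = 2x \\
\frac{dy}{dx} &= f'(g(x)) \cdot g'(x) = \cos(x^2) \cdot 2x
\end{aligned}
```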

Awesome, now we have a way to get the exact gradient with respect to an input. A new issue arises, however: if we’re only given the output at the end and not the function itself, how do we determine how much each of our inputs affected the output? The solution is to store the history of each number, including the intermediate values calculated along the way.

[insert image of the computation graph and how to interpret it]

For the example above, we would store . Then we would store the result of , the function that was used to produce the value () and the inputs to that function (including their history recursively). This would repeat for where the function is , and the input is the result of itself.

If we stored each number’s history as its own separate recursive copy, each with its own history, there would be a LOT of redundancy and inefficiency. Notice that the input of one function is often the output of another, a number whose history and inputs we need to store anyway. Therefore, we can just reuse the object storing that information and point to it from the object that represents the new value. In other words, we create a new object type, say Scalar, that stores the Scalars used as inputs, and the function used in the computation.
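Here is a minimal sketch of what such a Scalar type could look like in Python. The attribute names `data`, `fn`, and `inputs`, and the two supported operations, are assumptions for illustration rather than a fixed design. The point is that each result holds references to the Scalar objects it was computed from, instead of copies of their histories.

```python
class Scalar:
    """A number that remembers how it was produced (hypothetical sketch)."""

    def __init__(self, data, fn=None, inputs=()):
        self.data = float(data)
        self.fn = fn          # name of the function that produced this value; None for a plain input
        self.inputs = inputs  # the Scalar objects that were fed into that function

    def __add__(self, other):
        return Scalar(self.data + other.data, fn="add", inputs=(self, other))

    def __mul__(self, other):
        return Scalar(self.data * other.data, fn="mul", inputs=(self, other))


x = Scalar(3.0)
y = x * x  # y records fn="mul" and points at x twice
z = y + x  # z records fn="add" and points at the existing y and x objects

print(z.data, z.fn, [s.data for s in z.inputs])
```

Because `z.inputs[0]` is the very same object as `y`, the history of `y` is stored once and shared by every value that uses it.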

Any mathematical function and its intermediate values can be decomposed into a graph of Scalar objects. Moreover, the graph is directed and acyclic, meaning its nodes can be put in a topological order, where every value comes after the values it was computed from. This graph, known as the computation graph, is an important part of gradient descent.

To calculate the output value for a given set of inputs, we traverse the computation graph in forward order, applying the function at each node to its inputs, and saving any context necessary to apply the chain rule later. To find the gradient with respect to each input, we then traverse the graph in reverse order, applying the chain rule at each node and accumulating the partial derivatives. An input that is used in more than one place receives the sum of the contributions from every use.
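Both traversals can be sketched in Python as follows. This is one simple way to do it under stated assumptions, not a definitive implementation: the `Scalar` class, its `local_grads` field (the context saved during the forward pass), and the helpers `topological_order` and `backward` are all hypothetical names introduced here.

```python
import math


class Scalar:
    """Hypothetical node: a value, its inputs, and d(value)/d(input) for each input."""

    def __init__(self, data, inputs=(), local_grads=()):
        self.data = float(data)
        self.inputs = inputs
        self.local_grads = local_grads  # saved in the forward pass for the chain rule
        self.grad = 0.0                 # accumulated d(output)/d(this value)

    def __add__(self, other):
        return Scalar(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        return Scalar(self.data * other.data, (self, other), (other.data, self.data))

    def sin(self):
        return Scalar(math.sin(self.data), (self,), (math.cos(self.data),))


def topological_order(output):
    """List every node so that each one appears after all of its inputs."""
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.inputs:
            visit(parent)
        order.append(node)

    visit(output)
    return order


def backward(output):
    """Walk the graph in reverse order, applying the chain rule at each node."""
    output.grad = 1.0
    for node in reversed(topological_order(output)):
        for parent, local in zip(node.inputs, node.local_grads):
            parent.grad += node.grad * local  # sum over every use of parent


# Forward pass: building the expression computes the values and saves the context.
x = Scalar(2.0)
y = Scalar(3.0)
out = (x * y + x).sin()  # sin(x * y + x)

# Reverse pass: fills in the gradient for every node, including the inputs.
backward(out)
print(out.data, x.grad, y.grad)
```

Here `x` is used twice, once in the product and once in the sum, so its gradient is the sum of both contributions: cos(8) * 3 + cos(8) * 1. Gradients in this sketch start at zero and are never reset, so calling `backward` a second time on the same graph would add to them.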