35.1 Gradient descent

Gradient descent is a general-purpose optimization algorithm that is used (or, more frequently, one of its variants is used) throughout many machine learning applications. When thinking about gradient descent conceptually, the scenario described previously of walking around a meadow blindfolded is again useful. If we were trying to find the lowest point, we would probably feel around us and find the direction that seemed to slope downward most steeply. We would then take a step in that direction. If we were being careful, we would check the ground around us again after each step, evaluating the gradient immediately around us, and then continue on in the steepest downhill direction. This is the basic idea of gradient descent: we start at a random location on some surface and keep stepping in the steepest downhill direction until we can’t go down any further.
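Before turning to regression, here is a minimal sketch in R that makes the idea concrete on a made-up one-dimensional surface, \(f(\theta) = (\theta - 3)^2\), whose lowest point is at \(\theta = 3\). The function, starting point, and learning rate are illustrative assumptions. They are not part of the example that follows.

```r
# Made-up one-dimensional "surface": f(theta) = (theta - 3)^2.
# Its derivative (the slope under our feet) is f'(theta) = 2 * (theta - 3).
f_prime <- function(theta) 2 * (theta - 3)

theta <- -4            # arbitrary starting location
learning_rate <- 0.1   # how far we move per unit of slope

for (i in 1:100) {
  # step in the downhill direction, i.e., against the slope
  theta <- theta - learning_rate * f_prime(theta)
}

theta  # very close to 3, the lowest point
```

Each step here shrinks the remaining distance to the minimum by a constant factor (0.8 with these settings). As a result, the walk naturally slows down as the ground flattens out.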

Consider a simple example with a linear regression model. First, let’s simulate some data.

library(tidyverse)

n <- 1000
x <- rnorm(n)

a <- 5   # true intercept
b <- 1.3 # true slope
e <- 4   # residual standard deviation

y <- a + b*x + rnorm(n, sd = e)

sim_d <- tibble(x = x, y = y)

ggplot(sim_d, aes(x, y)) +
  geom_point()

We can estimate the relation between \(x\) and \(y\) using standard OLS, as follows:

sim_ols <- lm(y ~ x)
summary(sim_ols)
## Call:
## lm(formula = y ~ x)
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -11.6899  -2.6353  -0.0333   2.6513  14.3506 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)    
## (Intercept)   4.9797     0.1248   39.89   <2e-16 ***
## x             1.3393     0.1245   10.75   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## Residual standard error: 3.946 on 998 degrees of freedom
## Multiple R-squared:  0.1039, Adjusted R-squared:  0.103 
## F-statistic: 115.7 on 1 and 998 DF,  p-value: < 2.2e-16

This not only provides us with the best linear unbiased estimator (BLUE) of the relation between \(x\) and \(y\), but it is also computed extremely fast (around a tenth of a second on my computer).

Let’s see if we can replicate these values using gradient descent. We will be estimating two parameters, the intercept and the slope of the line (as above). Our objective function is the mean squared error (MSE). That is, we want to find the line running through the data that minimizes the average squared vertical distance between the line and the points. In the case of simple linear regression, the mean squared error is defined by

\[ \frac{1}{n} \sum_{i=1}^{n} (y_i - (a + bx_i ))^2 \]

where \(a\) is the intercept of the line and \(b\) is the slope of the line. Let’s write a function in R that computes the mean squared error for any line.

mse <- function(a, b, x = sim_d$x, y = sim_d$y) {
  prediction <- a + b*x # model prediction, given intercept/slope
  residuals <- y - prediction # distance between prediction & observed
  squared_residuals <- residuals^2 # squared to avoid summation to zero
  ssr <- sum(squared_residuals) # sum of squared distances
  ssr # returned as-is: this equals n * MSE, so it has the same minimizer
}

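Notice that we set the default x and y values to our simulated data, but the function is general enough to compute the error for any set of data. Note also that, despite its name, the function returns the sum of squared residuals (SSR) rather than dividing it by the number of observations. Because \(n\) is fixed, the SSR is simply \(n\) times the MSE, so both are minimized by exactly the same intercept and slope. The mse values printed below are therefore on the SSR scale.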
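A quick way to convince yourself of this is to compare the two criteria on a small, made-up data set. In the sketch below, toy_x, toy_y, ssr_fun(), and mse_fun() are hypothetical names used only for illustration. They are not the simulated data or the mse() function above.

```r
# Hypothetical toy data (not sim_d)
toy_x <- c(1, 2, 3, 4)
toy_y <- c(2.1, 3.9, 6.2, 7.8)

# Sum of squared residuals and mean squared error for a candidate line
ssr_fun <- function(a, b) sum((toy_y - (a + b * toy_x))^2)
mse_fun <- function(a, b) ssr_fun(a, b) / length(toy_y)

# The SSR is always n times the MSE (n = 4 here)
ssr_fun(0, 2)
mse_fun(0, 2)

# So a grid search over either criterion picks the same intercept/slope
toy_grid <- expand.grid(a = seq(-1, 1, by = 0.05), b = seq(1, 3, by = 0.05))
ssr_vals <- mapply(ssr_fun, toy_grid$a, toy_grid$b)
mse_vals <- mapply(mse_fun, toy_grid$a, toy_grid$b)
toy_grid[which.min(ssr_vals), ]
toy_grid[which.min(mse_vals), ]
```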

Just to confirm that our function works, let’s check that our MSE with the OLS coefficients matches what we get from the model.

mse(a = coef(sim_ols)[1], b = coef(sim_ols)[2])
## [1] 15539.95

Is this the same thing we get if we compute it from the model residuals?

sum(resid(sim_ols)^2)
## [1] 15539.95

It is!

So now we have a general function that can be used to evaluate our objective function for any intercept/slope combination. In principle, we could evaluate every possible combination and pick the one with the lowest value. Let’s look at several thousand combinations.

# candidate values
grid <- expand.grid(a = seq(-5, 10, 0.1), b = seq(-5, 5, 0.1)) %>% 
  as_tibble()
grid
## # A tibble: 15,251 x 2
##        a     b
##    <dbl> <dbl>
##  1  -5      -5
##  2  -4.9    -5
##  3  -4.8    -5
##  4  -4.7    -5
##  5  -4.6    -5
##  6  -4.5    -5
##  7  -4.4    -5
##  8  -4.3    -5
##  9  -4.2    -5
## 10  -4.1    -5
## # … with 15,241 more rows

We could, of course, overlay all of these lines on our data, but with 15,251 candidate lines the plot would be very difficult to interpret. Instead, let’s compute the MSE for each candidate.

mse_grid <- grid %>% 
  rowwise(a, b) %>% 
  summarize(mse = mse(a, b), .groups = "drop")

## # A tibble: 15,251 x 3
##        a     b     mse
##    <dbl> <dbl>   <dbl>
##  1  -5      -5 152243.
##  2  -4.9    -5 150290.
##  3  -4.8    -5 148357.
##  4  -4.7    -5 146444.
##  5  -4.6    -5 144551.
##  6  -4.5    -5 142677.
##  7  -4.4    -5 140824.
##  8  -4.3    -5 138991.
##  9  -4.2    -5 137178.
## 10  -4.1    -5 135384.
## # … with 15,241 more rows

Let’s take a look at this grid.

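Notice that this surface is essentially one big valley. That is because this is a very simple problem: the MSE of a linear model is a quadratic function of the intercept and slope, so it has a single minimum. We want to find the combination that minimizes the MSE which, in this case, is: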

mse_grid %>% 
  arrange(mse) %>% 
  slice(1)
## # A tibble: 1 x 3
##       a     b    mse
##   <dbl> <dbl>  <dbl>
## 1     5   1.3 15542.

How does this compare to what we estimated with OLS?

coef(sim_ols)
## (Intercept)           x 
##    4.979742    1.339272

Pretty similar.
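Why is the surface a single valley? Differentiating the MSE twice with respect to \(a\) and \(b\) gives its matrix of second derivatives (the Hessian). Here \(\bar{x}\) is the mean of \(x\) and \(\overline{x^2} = \frac{1}{n}\sum_{i=1}^{n} x_i^2\):

\[ \nabla^2 \, \text{MSE}(a, b) = 2 \begin{bmatrix} 1 & \bar{x} \\ \bar{x} & \overline{x^2} \end{bmatrix}, \qquad \det\left(\nabla^2 \, \text{MSE}\right) = 4\left(\overline{x^2} - \bar{x}^2\right) = \frac{4}{n} \sum_{i=1}^{n} (x_i - \bar{x})^2 \]

Three facts follow from this. The Hessian does not depend on \(a\) or \(b\). Its diagonal entries are positive. Its determinant is positive whenever the \(x_i\) are not all identical. Together, these mean the MSE is a strictly convex bowl with exactly one minimum, which is why the grid shows a single valley and why walking downhill cannot get trapped anywhere else.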

But this still isn’t gradient descent. It is a brute-force grid search, which is only feasible because the problem is so simple: two parameters searched over a bounded range.

So what is gradient descent and how do we implement it? Conceptually, it’s similar to our blindfolded walk: we start in a random location and try to walk downhill until we get to what we think is the lowest point. A more technical description is given in the box below.

Gradient Descent

  • Define a cost function (such as MSE).
  • Calculate the partial derivative of the cost function with respect to each parameter. Together, these form the gradient, which tells us how steep the surface is and in which direction it rises. We move in the opposite direction to minimize the cost function.
  • Define a learning rate. This controls the size of the “step” we take downhill.
  • Multiply the learning rate by each partial derivative and subtract the result from the current parameter values (this is how we actually “step” down).
  • Compute the gradient at the new location and continue iterating (\(\text{gradient} \rightarrow \text{step} \rightarrow \text{gradient} \rightarrow \text{step} \dots\)) until no further meaningful improvements are made.
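The steps in the box can be written as one generic loop. The sketch below is a minimal version. It assumes we stop once a step becomes shorter than a small tolerance (tol) or after max_iter iterations. Both gradient_descent() and toy_grad() are hypothetical helpers, not functions from any package.

```r
# Generic gradient descent following the steps in the box above.
# grad_fn: function returning the gradient (vector of partial derivatives)
gradient_descent <- function(grad_fn, init, learning_rate = 0.1,
                             tol = 1e-8, max_iter = 10000) {
  pars <- init
  for (i in seq_len(max_iter)) {
    this_step <- learning_rate * grad_fn(pars)  # learning rate x gradient
    pars <- pars - this_step                    # step downhill
    if (sqrt(sum(this_step^2)) < tol) break     # stop when steps are negligible
  }
  list(pars = pars, iterations = i)
}

# Made-up cost: (p1 - 1)^2 + 2 * (p2 + 2)^2, minimized at p = (1, -2)
toy_grad <- function(p) c(2 * (p[1] - 1), 4 * (p[2] + 2))

fit <- gradient_descent(toy_grad, init = c(0, 0))
fit$pars        # close to c(1, -2)
fit$iterations  # number of steps taken before stopping
```

A tolerance on the step length is only one possible stopping rule. Another common choice is to stop when the cost itself stops decreasing by more than some small amount.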

Let’s try applying gradient descent to our simple linear regression problem. First, we have to take the partial derivative of our cost function with respect to each parameter, \(a\) and \(b\). These are defined as

\[ \begin{bmatrix} \frac{\partial \, \text{MSE}}{\partial a} \\ \frac{\partial \, \text{MSE}}{\partial b} \\ \end{bmatrix} = \begin{bmatrix} \frac{1}{n} \sum_{i=1}^{n} -2(y_i - (a + bx_i)) \\ \frac{1}{n} \sum_{i=1}^{n} -2x_i(y_i - (a + bx_i)) \\ \end{bmatrix} \]
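One way to double-check derivatives like these is to compare them with central finite differences, \((f(a + h) - f(a - h)) / 2h\), which approximate each partial derivative numerically. The sketch below does this on made-up, deterministic data. toy_x, toy_y, and the helper functions are illustrative assumptions, not part of the chapter’s code.

```r
# Made-up deterministic data (no random numbers needed)
toy_x <- seq(-2, 2, length.out = 50)
toy_y <- 2 + 0.5 * toy_x + sin(1:50)

toy_mse <- function(a, b) mean((toy_y - (a + b * toy_x))^2)

# Partial derivatives from the formulas above
analytic_grad <- function(a, b) {
  r <- toy_y - (a + b * toy_x)
  c(mean(-2 * r), mean(-2 * toy_x * r))
}

# Central finite differences in each parameter
numeric_grad <- function(a, b, h = 1e-6) {
  c((toy_mse(a + h, b) - toy_mse(a - h, b)) / (2 * h),
    (toy_mse(a, b + h) - toy_mse(a, b - h)) / (2 * h))
}

analytic_grad(1, -1)
numeric_grad(1, -1)  # should agree to many decimal places
```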

Let’s write a function to calculate the gradient (the vector of partial derivatives) for any values of the two parameters. Similar to the mse() function we wrote, we’ll define this assuming the sim_d data, but have it be general enough that other data could be provided.

compute_gradient <- function(a, b, x = sim_d$x, y = sim_d$y) {
  n <- length(y)
  predictions <- a + (b * x)
  residuals <- y - predictions
  da <- (1/n) * sum(-2*residuals)   # partial derivative w.r.t. the intercept
  db <- (1/n) * sum(-2*x*residuals) # partial derivative w.r.t. the slope
  c(da, db)
}

Great! Next, we’ll write a function that uses the above function to calculate the gradient and then actually takes a step downhill. We do this by first multiplying our partial derivatives by our learning rate (which scales the size of each step) and then subtracting the result from the current parameter values. We subtract because the gradient points uphill and we’re trying to go “downhill”. If we were trying to maximize our objective function, we would add these values to our current parameters instead (technically, gradient ascent).

The learning rate controls the size of the step we take downhill. Higher learning rates will get us closer to the optimal solution faster, but may “step over” the minimum. When training a model, start with a relatively high learning rate (e.g., \(0.1\)) and adjust as needed. Before finalizing your model, consider reducing the learning rate so the algorithm settles into the minimum rather than bouncing around it.

gd_step <- function(a, b, 
                    learning_rate = 0.1, 
                    x = sim_d$x, 
                    y = sim_d$y) {
  grad <- compute_gradient(a, b, x, y)
  step_a <- grad[1] * learning_rate # step for the intercept
  step_b <- grad[2] * learning_rate # step for the slope
  c(a - step_a, b - step_b)         # subtract to move downhill
}

And finally, we choose a location to start (often this is chosen at random) and begin our walk! Let’s begin at 0 for each parameter.

walk <- gd_step(0, 0)
walk
## [1] 0.9890313 0.2433964

After just a single step, our parameters have changed quite a bit. Remember that our true values are 5 and 1.3. Both parameters appear to be heading in the right direction. Let’s take a few more steps. Notice that each step below starts from the previous location.

walk <- gd_step(walk[1], walk[2])
walk
## [1] 1.7815134 0.4429925
walk <- gd_step(walk[1], walk[2])
walk
## [1] 2.4165300 0.6065743
walk <- gd_step(walk[1], walk[2])
walk
## [1] 2.9253881 0.7405655

Our parameters continue to head in the correct direction. However, the amount the values change shrinks with each iteration. This is because the gradient becomes less steep as we approach the minimum. So even though the learning rate stays the same, each step (learning rate × gradient) carries us a shorter distance.
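We can check this shrinking stride directly by recording the length of each step, \(\text{learning rate} \times \text{gradient}\), over the first few iterations. The sketch below uses made-up, deterministic data rather than sim_d so that it runs on its own. toy_x, toy_y, and toy_grad() are illustrative names.

```r
# Made-up deterministic data (not sim_d)
toy_x <- seq(-2, 2, length.out = 101)
toy_y <- 5 + 1.3 * toy_x + sin(1:101)

# Gradient of the MSE with respect to the intercept and slope
toy_grad <- function(pars) {
  r <- toy_y - (pars[1] + pars[2] * toy_x)
  c(mean(-2 * r), mean(-2 * toy_x * r))
}

pars <- c(0, 0)
learning_rate <- 0.1
step_lengths <- numeric(10)
for (i in 1:10) {
  this_step <- learning_rate * toy_grad(pars)  # same learning rate every time...
  step_lengths[i] <- sqrt(sum(this_step^2))    # ...but the step keeps shrinking
  pars <- pars - this_step
}
round(step_lengths, 3)
```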

Let’s speed this up a bit (although you could continue on to “watch” the parameters change) by using a loop to quickly take 25 more steps downhill.

for(i in 1:25) {
  walk <- gd_step(walk[1], walk[2])
}
walk
## [1] 4.971523 1.335808

And now we are very close! What if we took 25 more steps?

for(i in 1:25) {
  walk <- gd_step(walk[1], walk[2])
}
walk
## [1] 4.979709 1.339254

We get almost exactly the same thing. Why? Because we were already essentially at the minimum. There the gradient is close to zero, so each additional step barely moves the parameters.
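To see the learning-rate trade-off described earlier, the sketch below runs the same number of gradient descent steps twice: once with a modest learning rate and once with an overly large one. It uses made-up, deterministic data so it runs on its own. run_gd() and the toy_ names are hypothetical.

```r
# Made-up deterministic data (not sim_d)
toy_x <- seq(-2, 2, length.out = 101)
toy_y <- 5 + 1.3 * toy_x + sin(1:101)

# Take n_steps gradient descent steps from (0, 0) with a given learning rate
run_gd <- function(learning_rate, n_steps = 50) {
  pars <- c(0, 0)
  for (i in seq_len(n_steps)) {
    r <- toy_y - (pars[1] + pars[2] * toy_x)
    grad <- c(mean(-2 * r), mean(-2 * toy_x * r))
    pars <- pars - learning_rate * grad
  }
  pars
}

run_gd(0.1)              # settles near the least-squares estimates
run_gd(1.5)              # overshoots further each step and blows up
coef(lm(toy_y ~ toy_x))  # reference values
```

With the large learning rate, every step jumps past the minimum and lands farther away than where it started. The error therefore grows instead of shrinking.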

Let’s rewrite our functions a bit to make the results a little easier to store and inspect later.

estimate_gradient <- function(pars_tbl, learning_rate = 0.1, x = sim_d$x, y = sim_d$y) {
  pars <- gd_step(pars_tbl[["a"]], pars_tbl[["b"]],
                  learning_rate, x, y)
  # inside tibble(), `a` and `b` refer to the columns created just before
  tibble(a = pars[1], b = pars[2], mse = mse(a, b, x, y))
}

# initialize
grad <- estimate_gradient(tibble(a = 0, b = 0))

# loop through
for(i in 2:50) {
  grad[i, ] <- estimate_gradient(grad[i - 1, ])
}
grad
## # A tibble: 50 x 3
##        a     b    mse
##    <dbl> <dbl>  <dbl>
##  1 0.989 0.243 32446.
##  2 1.78  0.443 26428.
##  3 2.42  0.607 22552.
##  4 2.93  0.741 20057.
##  5 3.33  0.850 18450.
##  6 3.66  0.940 17415.
##  7 3.92  1.01  16748.
##  8 4.13  1.07  16318.
##  9 4.30  1.12  16042.
## 10 4.43  1.16  15863.
## # … with 40 more rows

Finally, let’s add our iteration number into the data frame, and plot it.

grad <- grad %>% 
  mutate(iteration = row_number())

ggplot(grad, aes(iteration, mse)) +
  geom_line()

As we would expect, the MSE drops very quickly as we start to walk downhill (with each iteration) but eventually (around 20 or so iterations) starts to level out when no more progress can be made.

Let’s look at this slightly differently, by plotting our cost surface and the path we take down it.

You can see that, as we would expect, the algorithm takes us straight “downhill”.

Finally, because this is simple linear regression, we can also plot out the line through the iterations (as the algorithm “learns” the optimal intercept/slope combination).

ggplot(sim_d, aes(x, y)) +
  geom_point() +
  geom_abline(aes(intercept = a, slope = b),
              data = grad,
              color = "gray60",
              size = 0.3) +
  geom_abline(aes(intercept = a, slope = b),
              data = grad[nrow(grad), ],
              color = "magenta")

Or, just for fun, we could animate it.

library(gganimate)

ggplot(grad) +
  geom_point(aes(x, y), sim_d) +
  geom_smooth(aes(x, y), sim_d, 
              method = "lm", se = FALSE) +
  geom_abline(aes(intercept = a,
                  slope = b),
              color = "#de4f60") +
  transition_manual(frames = iteration)

So in the end, gradient descent gets us essentially the same answer we get with OLS. So why would we use this approach instead? Well, if we’re estimating a model like linear regression, we wouldn’t. As is probably apparent, this approach is going to be more computationally expensive, particularly if we have a poor starting location. But in many cases we don’t have a closed-form solution to the problem. This is where gradient descent (or a variant thereof) can help. In the case of boosted trees, we start by building a weak model with only a few splits (potentially only one). We then build a new (weak) model from the residuals of this model. More generally, each new tree is fit to the negative gradient of the cost function, so adding it to the ensemble amounts to a step downhill. This ensures that our ensemble of trees builds toward the minimum of our cost function.
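The boosting idea can be illustrated with a deliberately tiny sketch. It assumes squared-error loss \(\frac{1}{2}(y - f)^2\), for which the negative gradient with respect to the current prediction \(f\) is exactly the residual \(y - f\). It also uses single-split “stumps” as the weak models. The data, fit_stump(), and predict_stump() are hypothetical and written only for this example. Real boosting libraries do much more (deeper trees, regularization, other loss functions).

```r
# Made-up deterministic data
toy_x <- seq(0, 10, length.out = 200)
toy_y <- sin(toy_x) + 0.1 * cos(5 * toy_x)

# Fit a one-split stump to residuals r by least squares
fit_stump <- function(x, r) {
  best <- list(sse = Inf)
  for (s in unique(x)[-1]) {                  # candidate split points
    left <- x < s
    pred <- ifelse(left, mean(r[left]), mean(r[!left]))
    sse <- sum((r - pred)^2)
    if (sse < best$sse) {
      best <- list(sse = sse, split = s,
                   left = mean(r[left]), right = mean(r[!left]))
    }
  }
  best
}

predict_stump <- function(stump, x) ifelse(x < stump$split, stump$left, stump$right)

learning_rate <- 0.1
pred <- rep(mean(toy_y), length(toy_y))       # start from a constant model
mse_path <- numeric(100)
for (m in 1:100) {
  r <- toy_y - pred                           # negative gradient of 0.5 * (y - f)^2
  stump <- fit_stump(toy_x, r)                # weak model fit to the residuals
  pred <- pred + learning_rate * predict_stump(stump, toy_x)  # small step downhill
  mse_path[m] <- mean((toy_y - pred)^2)
}
mse_path[c(1, 10, 100)]  # training MSE shrinks as stumps are added
```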

As is probably clear, however, gradient descent is a fairly complicated topic. We chose to focus on a relatively “easy” implementation within a simple linear regression framework, which has only two parameters. There are a number of potential pitfalls related to gradient descent. One is exploding gradients, where the updates grow larger with every iteration and the search algorithm gets “lost”, essentially wandering out into space. This can happen if the learning rate is too high. Normalizing your features so they are all on a common scale can also help prevent exploding gradients. Vanishing gradients are a related problem, where the gradient is so small that the parameters barely change and the algorithm gets “stuck”. This is okay (and expected) if the algorithm has reached the global minimum, but it is a problem if it is only at a local minimum (i.e., stuck in a valley on a hill). There are numerous variants of gradient descent that can help with these challenges, as we will discuss in subsequent chapters.
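The effect of feature scaling is easy to see in a small, made-up example. When \(x\) sits far from zero (values around 100 below), the learning rate of 0.1 used earlier makes the updates grow without bound. Standardizing \(x\) first lets gradient descent converge. gd_fit() and the toy_ names are hypothetical helpers for this sketch.

```r
# Made-up deterministic data with x far from zero (not sim_d)
toy_x <- seq(95, 105, length.out = 101)
toy_y <- 5 + 1.3 * toy_x + sin(1:101)

# Plain gradient descent on the MSE, starting from (0, 0)
gd_fit <- function(x, y, learning_rate = 0.1, n_steps = 200) {
  pars <- c(0, 0)
  for (i in seq_len(n_steps)) {
    r <- y - (pars[1] + pars[2] * x)
    grad <- c(mean(-2 * r), mean(-2 * x * r))
    pars <- pars - learning_rate * grad
  }
  pars
}

gd_fit(toy_x, toy_y)     # raw x: the parameters blow up to non-finite values

# Standardize x (mean 0, standard deviation 1) and try again
x_std <- (toy_x - mean(toy_x)) / sd(toy_x)
gd_fit(x_std, toy_y)     # converges
coef(lm(toy_y ~ x_std))  # reference values on the standardized scale
```

Keep in mind that after standardizing, the slope is expressed per standard deviation of \(x\). Dividing it by sd(toy_x) converts it back to the original units.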