Optimization

Lecture 13

Dr. Colin Rundel

Optimization

Optimization problems underlie nearly everything we do in Machine Learning and Statistics. Most models can be formulated as

\[ P \; : \; \underset{\boldsymbol{x} \in \boldsymbol{D}}{\text{arg min}} f(\boldsymbol{x}) \]

  • Formulating a problem \(P\) is not the same as being able to solve \(P\) in practice

  • Many different algorithms exist for optimization, but their performance varies widely depending on the exact nature of the problem

Gradient Descent

Naive Gradient Descent

The basic idea behind this approach is that the gradient of a function points in the direction of steepest ascent, so the negative gradient points in the direction of steepest descent. Therefore, to find a minimum we take each step in the direction of the negative gradient, which (locally) is the fastest way downhill toward the nearest minimum.


Given an \(n\)-dimensional function \(f(x_1, \ldots, x_n)\) and a current position \(\boldsymbol{x}_k\), our update rule becomes,

\[ \boldsymbol{x}_{k+1} = \boldsymbol{x}_{k} - \alpha \nabla f(\boldsymbol{x}_k) \]

where \(\alpha\) is the step length (or learning rate), which determines how large a step we take.

Implementation

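import warnings
import numpy as np

def grad_desc_1d(x, f, grad, step, max_step=100, tol = 1e-6):
  # store the trajectory of positions and function values
  res = {"x": [x], "f": [f(x)]}

  try:
    for i in range(max_step): 
      # fixed-size step in the direction of the negative gradient
      x = x - grad(x) * step
      # stop once successive positions differ by less than tol
      if np.abs(x - res["x"][-1]) < tol: 
        break

      res["f"].append( f(x) )
      res["x"].append( x )
      
  except OverflowError as err:
    # a step size that is too large can make x and f(x) overflow
    print(f"{type(err).__name__}: {err}")
  
  if i == max_step-1:
    warnings.warn("Failed to converge!", RuntimeWarning)
  
  return res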

A basic example

\[ \begin{aligned} f(x) &= x^2 \\ \nabla f(x) &= 2x \end{aligned} \]

f = lambda x: x**2
grad = lambda x: 2*x
opt = grad_desc_1d(-2., f, grad, step=0.25)
plot_1d_traj( (-2, 2), f, opt )

opt = grad_desc_1d(-2., f, grad, step=0.5)
plot_1d_traj( (-2, 2), f, opt )

Where can it go wrong?

If you pick a bad step size …

opt = grad_desc_1d(-2, f, grad, step=0.9)
plot_1d_traj( (-2,2), f, opt )

opt = grad_desc_1d(-2, f, grad, step=1)
plot_1d_traj( (-2,2), f, opt )
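For \(f(x) = x^2\) the update rule can be worked out by hand: \(x_{k+1} = x_k - \alpha \cdot 2x_k = (1 - 2\alpha)\,x_k\). The iterates shrink only when \(|1 - 2\alpha| < 1\), i.e. \(0 < \alpha < 1\), which accounts for the four runs above. A minimal sketch (the helper name `gd_iterates_x2` is ours, not part of the lecture code):

```python
def gd_iterates_x2(x0, step, n):
    """Fixed-step gradient descent on f(x) = x**2, whose gradient is 2x."""
    xs = [x0]
    for _ in range(n):
        # x_{k+1} = x_k - step * 2 * x_k, i.e. (1 - 2*step) * x_k
        xs.append(xs[-1] - step * 2 * xs[-1])
    return xs

for step in (0.25, 0.5, 0.9, 1.0):
    print(step, gd_iterates_x2(-2.0, step, 4))
```

With \(\alpha = 0.5\) the factor is 0, so a single step lands exactly on the minimum; with \(\alpha = 0.9\) it is \(-0.8\), so the iterates oscillate but still converge; with \(\alpha = 1\) it is \(-1\) and the iterates bounce between \(-2\) and \(2\) forever.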

Local minima

This function has multiple local minima, so both the starting point and the step size affect the solution we obtain,

\[ \begin{aligned} f(x) &= x^4 + x^3 -x^2 - x \\ \nabla f(x) &= 4x^3 + 3x^2 - 2x - 1 \end{aligned} \]

f = lambda x: x**4 + x**3 - x**2 - x 
grad = lambda x: 4*x**3 + 3*x**2 - 2*x - 1
opt = grad_desc_1d(-1.5, f, grad, step=0.2)
plot_1d_traj( (-1.5, 1.5), f, opt )

opt = grad_desc_1d(-1.5, f, grad, step=0.25)
plot_1d_traj( (-1.5, 1.5), f, opt)
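Because \(\nabla f\) is a cubic here, we can locate every stationary point directly and classify it with the second derivative \(f''(x) = 12x^2 + 6x - 2\). A short numpy sketch (using `np.roots`, which is not part of the lecture code) finds two local minima, at \(x=-1\) and \(x=(1+\sqrt{17})/8 \approx 0.64\), separated by a local maximum at \(x=(1-\sqrt{17})/8 \approx -0.39\). Only the minimum near 0.64 is global.

```python
import numpy as np

# f(x) = x^4 + x^3 - x^2 - x, so f'(x) = 4x^3 + 3x^2 - 2x - 1
f   = lambda x: x**4 + x**3 - x**2 - x
d2f = lambda x: 12*x**2 + 6*x - 2

# stationary points are the (real) roots of the cubic f'(x)
crit = np.sort(np.roots([4, 3, -2, -1]).real)

for x in crit:
    kind = "local min" if d2f(x) > 0 else "local max"
    print(f"x = {x: .4f}  f(x) = {f(x): .4f}  {kind}")
```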

Alternative starting points

opt = grad_desc_1d(1.5, f, grad, step=0.2)
plot_1d_traj( (-1.75, 1.5), f, opt )

opt = grad_desc_1d(1.25, f, grad, step=0.2)
plot_1d_traj( (-1.75, 1.5), f, opt)

Problematic step sizes

If the step size is too large, the iterates can diverge until the floating point arithmetic overflows,

opt = grad_desc_1d(-1.5, f, grad, step=0.75)
OverflowError: (34, 'Result too large')
plot_1d_traj( (-3, 3), f, opt)

opt['x']
[-1.5, 2.0625, -29.986083984375, 78789.99556875888, -1467366557235808.0, 9.478445237313853e+45]
opt['f']
[0.9375, 20.552993774414062, 780666.4923959533, 3.853805712579921e+19, 4.636117851941789e+60, 8.071391646153008e+183]

Gradient Descent w/ backtracking

As we have just seen, too large a step can be problematic; one solution is to allow the step size to adapt.

Backtracking involves checking whether the proposed move is downhill (i.e. \(f(x_k+\alpha p_k) < f(x_k)\), where \(p_k = -\nabla f(x_k)\) is the search direction),

  • If it is downhill then accept \(x_{k+1} = x_k+\alpha p_k\).

  • If not, shrink \(\alpha\) by a factor \(\tau\) (e.g. 0.5), i.e. \(\alpha \leftarrow \tau\alpha\), and check again.

Pick a larger \(\alpha\) to start (but not so large as to overflow) and then let the backtracking tune it.
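A single backtracking step can be sketched as follows. The helper `backtrack_step` is our own illustration, not lecture code; unlike `grad_desc_1d_bt`, it returns the original point if no downhill move is found within `max_back` tries.

```python
def backtrack_step(x, f, grad, step, tau=0.5, max_back=10):
    """One backtracking gradient step in 1d; returns (new_x, accepted_step)."""
    g = grad(x)
    for _ in range(max_back):
        x_new = x - step * g          # proposed move along p_k = -f'(x_k)
        if f(x_new) < f(x):           # downhill: accept
            return x_new, step
        step *= tau                   # uphill: shrink the step and retry
    return x, step                    # no downhill move found; stay put

f = lambda x: x**4 + x**3 - x**2 - x
grad = lambda x: 4*x**3 + 3*x**2 - 2*x - 1
print(backtrack_step(-1.5, f, grad, step=0.75))
```

Starting from \(x=-1.5\), the first trial step of 0.75 overshoots to \(x = 2.0625\) (the same point the fixed-step run with step 0.75 reaches), so the step is halved once to 0.375, giving \(x = 0.28125\), which is accepted.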

Implementation

def grad_desc_1d_bt(x, f, grad, step, tau=0.5, max_step=100, max_back=10, tol = 1e-6):
  res = {"x": [x], "f": [f(x)]}
  
  for i in range(max_step):
    grad_f = grad(x)
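    # backtracking: shrink the step by tau until the move is downhill;
    # the shrunken step is kept for all later iterations
    for j in range(max_back):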
      x = res["x"][-1] - step * grad_f
      f_x = f(x)
      if (f_x < res["f"][-1]): 
        break
      step = step * tau
    
    if np.abs(x - res["x"][-1]) < tol: 
      break
    res["x"].append(x)
    res["f"].append(f_x)
    
  if i == max_step-1:
    warnings.warn("Failed to converge!", RuntimeWarning)
  
  return res

opt = grad_desc_1d_bt(
  -1.5, f, grad, step=0.75, tau=0.5
)
plot_1d_traj( (-1.5, 1.5), f, opt )

opt = grad_desc_1d_bt(
  1.5, f, grad, step=0.25, tau=0.5
)
plot_1d_traj( (-1.5, 1.5), f, opt)

A 2d cost function

We will be using mk_quad() to create quadratic functions with varying conditioning (as specified by the epsilon parameter).

\[ \begin{align} f(x,y) &= 0.33(x^2 + \epsilon^2 y^2 ) \\ \nabla f(x,y) &= \left[ \begin{matrix} 0.66 \, x \\ 0.66 \, \epsilon^2 \, y \end{matrix} \right] \\ \nabla^2 f(x,y) &= \left[\begin{array}{cc} 0.66 & 0 \\ 0 & 0.66 \, \epsilon^2 \end{array}\right] \end{align} \]
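`mk_quad()` is a helper from the lecture's setup code that is not shown here. The sketch below is our own hypothetical reconstruction from the formulas above (the name `mk_quad_sketch` and its array handling are assumptions). It also computes the condition number of the constant Hessian, \(0.66 / (0.66\,\epsilon^2) = 1/\epsilon^2\), which quantifies the conditioning: about 2 for \(\epsilon=0.7\) and 400 for \(\epsilon=0.05\).

```python
import numpy as np

def mk_quad_sketch(epsilon):
    """Hypothetical stand-in for mk_quad(): returns f, grad, hess for
    f(x, y) = 0.33 * (x**2 + epsilon**2 * y**2)."""
    def f(x):
        x = np.asarray(x, dtype=float)
        return 0.33 * (x[0]**2 + epsilon**2 * x[1]**2)

    def grad(x):
        x = np.asarray(x, dtype=float)
        return np.array([0.66 * x[0], 0.66 * epsilon**2 * x[1]])

    def hess(x):
        # the Hessian of a quadratic does not depend on x
        return np.array([[0.66, 0.0], [0.0, 0.66 * epsilon**2]])

    return f, grad, hess

# condition number of the Hessian is 1 / epsilon**2
for eps in (0.7, 0.05):
    _, _, hess = mk_quad_sketch(eps)
    print(eps, np.linalg.cond(hess((0, 0))))
```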

Examples

f, grad, hess = mk_quad(0.7)
plot_2d_traj(
  (-1,2), (-1,2), f, title="ε=0.7"
)

f, grad, hess = mk_quad(0.05)
plot_2d_traj(
  (-1,2), (-1,2), f, title="ε=0.05"
)

\(n\)-d gradient descent w/ backtracking

def grad_desc(x, f, grad, step, tau=0.5, max_step=100, max_back=10, tol = 1e-8):
  res = {"x": [x], "f": [f(x)]}
  
  for i in range(max_step):
    grad_f = grad(x)
    
    for j in range(max_back):
      x = res["x"][-1] - grad_f * step
      f_x = f(x)
      if (f_x < res["f"][-1]): 
        break
      step = step * tau

    if np.sqrt(np.sum((x - res["x"][-1])**2)) < tol: 
      break  
    
    res["x"].append(x)
    res["f"].append(f_x)
      
  if i == max_step-1:
    warnings.warn("Failed to converge!", RuntimeWarning)
    
  return res

Well conditioned cost function

f, grad, hess = mk_quad(0.7)
opt = grad_desc((1.6, 1.1), f, grad, step=1)
plot_2d_traj(
  (-1,2), (-1,2), f, title="ε=0.7", traj=opt
)

f, grad, hess = mk_quad(0.7)
opt = grad_desc((1.6, 1.1), f, grad, step=2)
plot_2d_traj(
  (-1,2), (-1,2), f, title="ε=0.7", traj=opt
)

Ill-conditioned cost function

f, grad, hess = mk_quad(0.05)
opt = grad_desc((1.6, 1.1), f, grad, step=1)
plot_2d_traj(
  (-1,2), (-1,2), f, title="ε=0.05", traj=opt
)

f, grad, hess = mk_quad(0.05)
opt = grad_desc((1.6, 1.1), f, grad, step=2)
plot_2d_traj(
  (-1,2), (-1,2), f, title="ε=0.05", traj=opt
)
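Why is progress so slow when \(\epsilon = 0.05\)? For this quadratic, a gradient step with a fixed step size \(\alpha\) multiplies each coordinate independently by \(1 - \alpha\lambda_i\), where \(\lambda_1 = 0.66\) and \(\lambda_2 = 0.66\,\epsilon^2\) are the Hessian's eigenvalues. A step that is stable in the steep \(x\) direction barely moves along the flat \(y\) direction. A small numpy sketch, assuming the step stays fixed (i.e. no backtracking is triggered):

```python
import numpy as np

def contraction_factors(epsilon, step):
    """Per-coordinate factor |1 - step * lambda_i| for f = 0.33(x^2 + eps^2 y^2)."""
    lam = np.array([0.66, 0.66 * epsilon**2])   # Hessian eigenvalues
    return np.abs(1 - step * lam)

def iters_to_shrink(factor, by=10.0):
    """Steps needed to shrink the error in one coordinate by `by` (assumes 0 < factor < 1)."""
    return int(np.ceil(np.log(by) / -np.log(factor)))

for eps in (0.7, 0.05):
    fac = contraction_factors(eps, step=1.0)
    print(eps, fac, [iters_to_shrink(c) for c in fac])
```

With step 1 and \(\epsilon = 0.05\), the \(x\) error shrinks tenfold in about 3 steps, while the \(y\) error needs more than 1000 steps for the same reduction.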

Rosenbrock function (very ill-conditioned)

f, grad, hess = mk_rosenbrock()
opt = grad_desc((1.6, 1.1), f, grad, step=0.25)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

f, grad, hess = mk_rosenbrock()
opt = grad_desc((-0.5, 0), f, grad, step=0.25)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)
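`mk_rosenbrock()` is another setup helper that is not shown. Assuming it uses the standard form \(f(x,y) = (1-x)^2 + 100\,(y - x^2)^2\) (an assumption; the lecture's constants may differ), a hypothetical sketch of \(f\), its gradient, and its Hessian follows. At the minimum \((1, 1)\) the Hessian's condition number is roughly 2500, far worse than the \(\epsilon = 0.05\) quadratic.

```python
import numpy as np

def mk_rosenbrock_sketch(a=1.0, b=100.0):
    """Hypothetical stand-in for mk_rosenbrock(): the standard Rosenbrock function
    f(x, y) = (a - x)^2 + b (y - x^2)^2 with its gradient and Hessian."""
    def f(p):
        x, y = p
        return (a - x)**2 + b * (y - x**2)**2

    def grad(p):
        x, y = p
        return np.array([-2 * (a - x) - 4 * b * x * (y - x**2),
                         2 * b * (y - x**2)])

    def hess(p):
        x, y = p
        return np.array([[2 - 4 * b * y + 12 * b * x**2, -4 * b * x],
                         [-4 * b * x, 2 * b]])

    return f, grad, hess

f, grad, hess = mk_rosenbrock_sketch()
print(grad((1.0, 1.0)), np.linalg.cond(hess((1.0, 1.0))))
```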

Some regression examples

from sklearn.datasets import make_regression
X, y, coef = make_regression(
  n_samples=200, n_features=20, n_informative=4, 
  bias=3, noise=1, random_state=1234, coef=True
)
y
array([ -36.2252,    9.6357,   66.4583,   48.9574,   24.1885,  -13.2444,
         18.1455, -135.047 ,  116.5772,   60.2524,   30.9319,  107.148 ,
         21.6209,   66.2401, -132.8878,   58.636 ,   22.186 ,   60.3852,
        -85.0383,   55.1704,  -31.3817,  -57.0697,   67.3215,    2.878 ,
        -29.5613,  -41.3973,  -30.3048,  -41.5597,   52.7531,  -63.5633,
          3.5671,   63.712 ,    9.9833,   78.4881,  -76.126 ,   13.4331,
        122.6162,   79.0354,   91.2171,   48.7344,  103.6366,   52.5964,
         35.0064,  -65.8423,  -47.3045,  -25.6876,    1.8359,   35.4113,
         28.0687,   56.3528,    3.6755,  -72.3309,   57.143 ,  -16.9438,
         54.1445,   72.6828,   -5.0538, -180.6135,  -44.6205,    9.2071,
         -5.5324,  -29.6013,  135.3656,  114.241 ,  -97.4878,   15.0648,
         14.7958,   71.503 ,   -4.6583,  -36.791 ,   -5.3845, -119.8073,
         11.174 ,   36.3008,   82.5499,  -20.0869,   14.7146,  -59.0765,
         39.4171,   48.4013,  -61.9613,   -5.6247,  103.2374,   41.2613,
       -129.2273,   10.5113,   32.4936,   78.6921,    5.2956,   64.4473,
         88.8358,   39.4851,  -11.4866,  -52.5082,  112.7248,   -9.7006,
         13.8393,  -36.4004,   68.4865,   19.5335,  -75.447 ,  -87.9538,
         79.4784,  -75.094 ,   25.6229,   84.9034,   71.2779,  -66.4093,
         77.6444,   40.8875,   31.3165,  -22.7143,   84.562 ,    6.8075,
          9.778 ,  -65.9149,  106.6952,   -3.1901,   41.1555,   32.6265,
        -36.5738,   38.9966,  -78.664 ,  -56.0434,    2.9191,   42.6286,
         51.3644,  -21.8072,  -21.9779,  -15.7102,  -23.5586,    1.3801,
         20.4269,   55.7188,  -45.6388,  -55.1542,   74.6067,   -7.2716,
        -31.1045,   48.1571,   14.7487,   41.6956,  -59.6062,  -33.0811,
         81.0177,   -9.4896,  164.1317,   25.3507,    6.0141,   46.3718,
         84.2983,  -63.2593,  -17.4733,  -26.2977,  -56.4681,   17.003 ,
         53.1867,  -94.5398,  -18.2541,  -49.343 ,   40.8724,  -90.5986,
         27.9392,   41.7287,   49.8082,   -9.6384,  -66.7551,  122.9159,
        -41.3566,  -98.6863,  -45.0718,    9.9327,  -22.0927,   10.6199,
        -12.2831,    7.4184,   57.6091,  -27.3456,  -36.4045,  -51.659 ,
         28.8175,  -23.9402,  -51.0637,    4.3618,   10.8402,  -11.087 ,
        -29.9801,  113.6633,   66.5601,    1.3808,  -19.4875,   40.812 ,
         43.0652,   35.4802,   77.0732,  -49.7352,   65.7192,   73.8539,
        -59.4116,   72.9501])
X
array([[-0.6465,  2.0803,  0.1412, -0.8419, -0.1595,  1.3321, -0.4262,
        -0.0351, -0.1938, -0.6093, -0.3433,  0.6126,  0.3777, -1.2062,
        -0.2277, -0.8896, -0.4674, -1.3566,  1.4989, -0.7468],
       [-0.3834, -0.3631, -1.2196,  0.6   ,  0.3315,  1.1056,  0.2662,
        -0.7239,  0.0259, -0.2172, -0.6841,  0.0991,  0.2794, -1.208 ,
        -0.7818, -1.7348, -1.3397, -0.5723, -0.5882,  0.2717],
       [-0.1637, -0.8118,  0.9551,  0.5711,  0.8719, -0.9619,  1.9846,
        -1.1806, -1.1261,  0.297 ,  1.2499,  0.7109, -0.1183,  0.6708,
         0.6895,  1.4705,  0.0634, -0.3079, -2.2512, -0.0216],
       [-0.9292, -0.4897, -2.1196, -1.142 ,  1.266 , -0.2988,  1.0016,
        -2.1969, -1.0739, -0.1149,  0.5122,  0.302 , -0.0974,  1.3461,
         0.1909,  1.1223,  0.6268,  2.2035, -0.5135,  2.0118],
       [ 0.1645, -0.5847,  0.2708, -3.5635,  0.1526,  0.5283,  0.7674,
         1.392 , -0.0819,  1.3211,  0.4644, -1.0279,  0.9849, -1.069 ,
        -0.4301,  0.0798, -0.5119, -0.3448,  0.8166, -0.4   ],
       [ 0.4134,  1.9511, -0.5013, -1.4894,  0.4191, -1.4104,  0.2617,
        -0.6981,  0.0368, -1.151 ,  2.0752,  0.5001, -0.2428,  0.45  ,
         0.7176,  1.3846,  0.5155,  0.4459, -0.2784, -0.2864],
       [-0.0628, -1.424 , -1.1023,  0.1445, -0.4836,  1.4795, -0.5921,
         1.6423, -0.5013,  0.4435,  2.0044,  0.6221,  0.0747, -1.4117,
        -0.202 , -1.3071, -0.8656, -1.311 ,  0.0424,  0.7255],
       [-0.6642,  1.4317, -0.0658, -0.7379, -0.9153,  0.8653,  0.7143,
         1.0912, -1.3773, -2.6022, -0.2955, -0.3985,  0.0918,  0.3851,
         0.502 , -0.4665,  1.6432, -0.2438, -0.4943,  1.4753],
       [ 1.5247, -1.3419, -0.4453, -0.6141,  2.0632, -1.0742, -1.4419,
        -1.4923,  0.3135,  0.7691,  0.5383, -0.9741,  0.8457, -0.0014,
         0.3895,  0.2118,  1.0977, -0.4036, -1.5496, -0.3672],
       [ 1.5524, -1.1109, -0.5624, -0.9106, -0.0506,  0.8533, -0.5452,
        -1.7836, -0.8365,  2.171 , -0.6158,  0.2523, -1.8707,  0.6142,
         0.7962,  0.0706, -0.2386, -0.4144, -0.0898, -0.4745],
       [-0.7206, -2.0213,  0.0157, -1.191 , -0.3127,  0.2891,  0.8596,
        -2.2427,  0.0021,  1.4327,  0.4714,  0.9533, -0.6365,  1.3212,
         0.8872,  1.15  , -1.5469,  0.4055, -0.3341,  0.9919],
       [ 0.2328,  0.5523, -0.2356, -1.2547,  0.6686, -2.1204, -0.186 ,
         1.4915, -1.1353,  2.3889,  0.3449, -0.6703, -0.2358, -2.1923,
        -0.4635, -0.9962, -0.1116,  0.0605,  0.0027, -1.439 ],
       [ 1.8092, -1.5857, -0.9765,  1.6171,  0.368 , -0.2947,  1.5897,
        -0.8878,  1.0547, -0.0427, -1.187 ,  0.7605,  1.2381, -0.5014,
         1.0201, -0.5773, -0.632 , -0.502 , -1.6914,  0.803 ],
       [ 0.1935,  0.5289, -0.7559, -0.1047, -0.3334, -1.0275,  1.0327,
        -0.8811,  0.0483,  1.8504,  1.5727,  0.3325, -1.7398, -0.2383,
        -0.4967,  0.3939,  1.9322,  0.062 , -1.1205, -0.95  ],
       [ 0.6914, -2.0308,  1.1554, -0.4219, -1.6257, -0.1138, -0.9225,
        -1.9216,  1.2995, -1.5084, -0.863 ,  0.2528,  1.3636,  0.2059,
         0.0381,  1.1124,  1.73  ,  0.4496, -0.1806,  0.7681],
       [-0.0862, -0.2131, -0.5343, -0.1066, -0.8403,  1.3862,  0.5885,
        -1.089 , -0.8571,  2.0178,  2.6078, -0.5807, -0.3466, -0.5166,
        -0.7863,  0.2918, -0.1904, -0.8012, -1.6868,  0.2538],
       [-1.0714, -0.4582,  0.4255,  0.5657, -0.1743,  2.0978, -0.8453,
        -0.9807, -0.0414,  0.5851,  0.2645, -0.3602,  0.4151,  1.2829,
        -0.0485, -0.4278,  0.2703,  0.821 , -1.338 ,  1.4986],
       [ 0.2126,  2.1145, -0.1471,  1.7549,  0.9465, -1.3906, -1.0954,
        -0.5224,  0.5338,  0.0591, -0.2671,  1.5731,  0.3903,  0.6137,
        -0.5277,  0.6306,  0.7467,  1.7232,  0.62  ,  2.0249],
       [ 0.7085,  1.312 , -0.6134,  0.8665, -1.4706,  0.2597, -0.1606,
        -0.7118,  0.2154, -0.7415, -0.608 , -0.3412,  1.0772,  0.4695,
        -0.1285,  0.0654,  0.4922, -0.6707, -1.8229, -0.4215],
       [ 1.2418, -1.9068, -0.6066,  0.1639,  0.986 , -0.1853, -0.0303,
         1.152 , -0.161 ,  0.0226,  0.8991,  0.9874, -0.802 , -0.7241,
         0.2466,  0.747 , -0.9682, -1.1908,  0.4313, -1.2039],
       [-0.6263,  0.2757,  0.9388,  1.3835, -0.5935,  0.4409, -1.4681,
         0.0114, -0.3643, -0.3373, -1.3341,  0.0036,  0.5513, -0.1016,
         0.6814, -1.4258, -1.3869, -2.0679, -1.6482,  1.0062],
       [-0.2397,  0.6481,  1.7758,  0.0166, -1.7724, -0.1862,  1.118 ,
        -0.8409,  0.6136,  0.5269, -0.2908, -0.2294,  0.1747, -0.3881,
        -0.2667, -0.7601,  0.4313, -0.7488, -0.7594, -0.4084],
       [-1.0989,  1.1887, -0.5288,  1.6782,  0.3827,  0.4309, -1.3949,
         0.6801, -1.2572,  0.6585,  0.7674, -1.5397,  1.1786,  1.2429,
        -1.1094,  1.2524, -0.7556,  0.4051, -0.3198,  0.6704],
       [ 0.0582,  0.1247,  0.1058, -0.4947, -0.1381,  1.3226,  0.3375,
         0.0445,  1.2923,  0.67  , -1.3132, -0.7997, -0.1669,  1.5938,
        -0.7805, -0.3689, -2.5977, -1.2921, -1.2897, -0.074 ],
       [-0.9515, -1.0973,  1.5675,  0.0103, -1.1347,  0.165 ,  0.0289,
        -0.6242, -1.3193,  0.2246,  0.7557, -0.9032,  2.1041, -0.6316,
        -0.1271, -0.4006, -0.8671, -0.5601, -0.0713, -1.1371],
       [ 0.4599,  0.5513,  1.6362, -1.2392, -0.3352,  1.0237,  1.7626,
        -0.5441,  1.3217, -1.2237,  2.5112, -1.7501, -0.0857,  0.8239,
        -0.6406, -1.05  , -0.635 , -2.1445,  1.4129,  0.2546],
       [ 0.1871, -2.2206,  1.2475,  1.2345, -1.5021,  1.1434, -1.0406,
         0.0709,  1.2826,  0.5946, -0.176 ,  0.0639, -1.4364, -0.3326,
        -0.4648,  0.0733, -1.5075,  0.7799, -0.6549,  0.2562],
       [ 0.084 ,  0.9564,  0.366 , -0.6843, -0.6239,  0.3233, -0.4753,
        -0.7024, -0.8606, -0.8089,  1.7968, -0.9079, -0.1103, -0.8212,
         1.328 , -1.2039, -2.1219, -0.8672, -1.345 , -1.0769],
       [-0.807 , -0.037 ,  0.7597,  0.9556,  0.1334, -0.0225,  1.9088,
        -0.423 ,  0.267 ,  0.7138,  0.996 ,  0.0679,  0.1559,  0.1314,
        -0.342 ,  0.1817,  0.4344,  1.383 , -0.1708,  0.2745],
       [-0.3675, -0.93  ,  1.2117,  1.0203, -1.1554, -0.0461,  0.827 ,
         1.793 , -1.0029, -0.7901,  0.0797,  0.992 , -0.5725,  1.3592,
         1.2639,  1.3791,  0.021 , -2.5727, -0.2494,  2.0499],
       ...,
       [ 1.6243, -1.2413, -0.4177,  0.2389, -0.2734, -0.6785, -1.0147,
        -0.2772, -1.711 , -1.1543,  0.2933,  1.487 ,  0.7526,  0.6561,
         0.4132,  0.1095,  0.1406, -0.6598,  1.2687,  1.2148],
       [ 0.6854, -0.7399, -1.0681,  0.2991,  0.0382, -0.9321, -0.8341,
         0.0215,  0.0612, -0.129 ,  0.8795, -0.1681,  0.8851,  1.2921,
         0.3478,  1.5717,  2.4181, -0.0638,  1.3938,  0.884 ],
       [ 2.2326, -1.7645,  1.9779, -1.6875, -0.8401,  0.1057,  1.1688,
         0.3301, -0.5216,  1.207 , -1.5042,  1.6341, -1.0896, -0.7015,
        -1.7587,  1.4814,  0.6081, -0.7485,  2.1342, -0.4016],
       [ 2.2171, -0.6177,  0.1949, -1.0798,  0.586 , -0.859 ,  2.5508,
        -0.8039,  0.1503, -0.1069, -0.6496, -0.2479,  0.1649,  0.765 ,
         0.8986, -0.3648,  0.6722, -0.2408,  0.7112,  0.1551],
       [ 1.1797,  2.0229,  0.2965, -0.4986,  0.6617,  0.8841, -0.8252,
        -0.3799, -1.1173, -1.3918,  0.9206, -0.076 , -0.8812, -1.7954,
         0.2918,  0.7677,  0.4183,  1.236 , -0.1036, -0.0952],
       [-1.4325,  1.5241,  0.4914,  0.4466, -1.5217, -0.5697, -0.1623,
        -0.0357, -1.3161,  1.733 ,  0.7872,  1.4468,  1.8372,  1.0749,
        -2.0308, -0.2996, -1.1323,  0.6271,  0.7217,  0.8836],
       [ 0.4932, -1.2439,  0.5748, -0.1409,  1.7359, -1.1483, -0.4902,
        -0.5052, -0.4267, -0.6533,  0.242 ,  0.7283,  0.2963, -1.2347,
         0.6998,  0.7025, -0.5894, -2.7557, -1.1078, -0.5546],
       [ 0.7622, -1.2367, -0.9891,  1.7989,  0.1187, -1.8608, -0.559 ,
         0.775 ,  0.0616, -1.6046,  0.2385, -1.5886, -0.1833, -0.7817,
         1.8364, -0.5933, -0.3687,  0.3881,  1.2738,  1.2086],
       [-0.7038, -1.2434,  0.5626,  0.3224,  0.3713,  0.7815,  1.957 ,
         0.4423,  0.6326, -1.956 ,  1.1085,  0.1665,  0.841 ,  2.1472,
         0.5566, -0.2651, -0.9084,  2.0134,  0.3486,  1.2223],
       [-0.8597, -1.2391,  0.9525, -0.7438, -0.9162,  0.1223,  0.6288,
         0.9881, -0.223 , -0.3202,  0.5368, -0.9382,  0.1865, -1.4094,
         0.226 , -0.0726,  1.423 ,  2.1237,  0.1397, -0.5506],
       [-0.0956,  0.2007, -0.3942,  0.812 ,  0.6777, -2.5834, -1.3294,
        -0.6009, -1.0962, -0.359 ,  0.0455, -0.5706,  0.0263, -0.9308,
        -0.0649,  0.6586, -0.0469, -1.0283,  0.3524, -0.3676],
       [-1.4168, -0.0501, -0.5261, -0.3774, -0.9222,  0.253 , -0.2044,
         0.0524,  0.8973,  0.0268,  1.6289,  0.4335, -2.1098, -0.488 ,
         0.8635, -1.7442,  0.6374, -1.3713, -0.6505,  0.0552],
       [-0.1955, -0.6314,  1.0877, -0.9238, -1.2701,  0.3528,  0.9894,
         0.4388, -0.2405,  0.3558, -0.2266,  0.5029,  1.3886, -1.8156,
        -0.4634, -0.9616, -0.9101,  0.5856, -0.7043,  1.2456],
       [ 0.1021, -0.9809,  0.3537,  1.6762, -0.7037, -0.2284, -0.278 ,
        -0.4083,  0.666 ,  0.681 , -0.9055,  1.054 , -0.0522,  0.3645,
         1.1951, -1.8104, -1.5148,  1.0655,  0.3521, -0.9033],
       [ 0.2733, -0.0657, -0.5118, -0.9471, -0.4646,  1.4039,  0.5143,
         0.0839, -1.1759,  0.5828,  1.7003, -0.4077,  0.5941, -0.7209,
        -0.2797,  0.6193,  1.001 ,  1.6007, -0.6754,  0.2706],
       [ 1.0588, -2.1785,  0.2624,  0.1752, -0.0675,  0.0093, -1.8534,
        -1.7246,  1.5724, -0.5295,  0.4144,  0.6353, -0.7018, -1.1284,
        -0.1179,  0.2766, -1.2591,  2.0028,  0.312 ,  1.073 ],
       [ 0.428 , -0.2036, -1.3212,  1.1555, -0.439 , -1.6271, -0.1925,
         0.622 , -1.5984, -0.5565, -0.5401,  0.0358,  0.0133, -1.4583,
         2.0069,  0.1263,  0.4161, -0.9847,  0.7966,  0.486 ],
       [-0.1173,  0.6381, -0.2812, -0.3578,  1.9299,  2.1058,  2.3049,
         1.0974,  0.5727,  0.9601, -0.211 ,  0.6298,  0.7401, -1.7709,
        -1.0439,  0.8751, -1.5677,  0.3412, -0.1276, -0.8195],
       [ 0.2991, -0.7227, -0.4858,  0.2055,  0.8665,  0.5538,  0.3632,
         0.3877,  0.9835,  0.313 ,  1.5294, -0.3187,  1.8937,  0.3538,
         1.0765,  0.0236, -0.2756,  0.0235,  0.1774, -0.6602],
       [ 0.2811,  0.94  ,  0.2862,  0.2612, -0.3128,  1.8086, -0.1915,
         0.5764,  0.5697,  0.8199, -2.1527, -1.3712, -0.6986,  0.3284,
        -0.4288,  0.2109, -0.4925, -0.3614,  0.4904, -2.0268],
       [-0.1443, -1.0635, -1.0286, -0.5993, -0.0703,  1.1204, -0.0037,
        -0.4246, -0.3002, -0.7602,  1.4965,  2.5779,  0.7647,  0.1203,
         0.1931,  0.7629,  1.2026, -0.8532,  0.1837,  0.5151],
       [ 0.4857,  0.4299,  0.1458,  0.7444, -0.0979,  1.0103,  0.6701,
        -0.2955,  0.0965,  1.4683, -1.7762,  0.9793, -0.2762, -0.0483,
         0.0385, -0.5835,  0.6415,  0.0514,  0.3048,  1.6422],
       [ 0.4843,  0.7492,  1.2246,  0.5665,  0.2853, -1.8185, -0.7811,
        -1.2811, -0.1822,  0.5036,  0.2912, -0.4508, -0.468 ,  0.0471,
         1.3635,  0.8755,  0.3948,  0.6807, -0.2039, -1.7107],
       [-1.0209, -0.6634, -0.9163,  1.3832,  0.0976,  1.3228,  0.1802,
         0.2148,  1.6929,  0.4813, -0.2126,  0.7982, -1.1308, -0.6505,
         0.0008,  1.4194, -0.0373,  0.3867,  0.2306, -0.5345],
       [-0.7579,  1.397 ,  1.337 , -1.8439,  2.0066,  1.6138, -1.5925,
        -0.2433,  0.1118,  0.11  , -0.0658,  0.3186,  0.2924, -0.2974,
         1.016 , -0.231 ,  1.639 ,  0.4316, -0.8798, -0.3389],
       [ 0.7059,  1.752 , -0.0646, -0.0381, -0.2057,  0.6298, -0.0253,
        -1.205 , -0.1709, -0.6655, -2.0803,  0.4152, -0.1783,  0.8111,
        -2.6128, -3.8809,  2.1338,  0.7489,  0.485 ,  0.9745],
       [-1.1806, -0.3993, -0.2226, -2.3027,  1.3672, -0.1536, -1.9393,
         0.2231, -0.0869,  0.9348, -0.4875, -0.8112,  0.6287,  0.2276,
        -0.2909,  0.0516,  1.3023,  0.5602,  0.515 ,  0.4992],
       [-0.6427, -0.3821,  0.1691,  0.1282,  0.4937,  0.4471, -0.2522,
         0.8716, -1.5433,  1.6954, -0.9341, -0.5208, -1.0103,  0.9605,
        -1.877 ,  0.8052, -1.7248, -1.3744, -0.4322,  1.1537],
       [ 0.8806, -0.7055, -0.0158,  0.172 , -0.4213,  0.8811,  0.7577,
        -0.3878, -0.6382, -1.365 ,  0.1341,  1.7316, -0.6366, -0.6532,
        -1.4726,  0.8897, -1.32  ,  0.7008, -1.2858,  1.1342],
       [-0.5563, -1.0968, -0.9132, -1.2012,  1.8192, -1.1012,  0.805 ,
         2.0475,  2.122 ,  0.0836, -0.1508, -0.6059,  1.1388, -1.2726,
        -1.0138, -0.8948,  0.8435,  1.1215,  0.2176,  1.0591]],
      shape=(200, 20))
coef
array([ 0.    ,  0.    ,  0.    ,  9.6106, 43.4239,  0.    ,  0.    ,
        0.    ,  0.    , 34.453 ,  9.2929,  0.    ,  0.    ,  0.    ,
        0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ])

A jax implementation of GD

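import warnings
import jax
import jax.numpy as jnp

def grad_desc_jax(x, f, step, tau=0.5, max_step=100, max_back=10, tol = 1e-8):
  # jax.grad() derives the gradient function from f via automatic differentiation
  grad_f = jax.grad(f)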
  f_x = f(x)
  converged = False

  for i in range(max_step):
    grad_f_x = grad_f(x)
    
    for j in range(max_back):
      new_x = x - grad_f_x * step
      new_f_x = f(new_x)
      if (new_f_x < f_x): 
        break
      step *= tau

    cur_tol = jnp.sqrt(jnp.sum((x - new_x)**2))

    x = new_x 
    f_x = new_f_x

    if cur_tol < tol: 
      converged = True
      break
  
  if not converged:
    warnings.warn("Failed to converge!", RuntimeWarning)

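  return {
    "x": x,
    "n_iter": i,            # zero-based index of the last iteration
    "converged": converged,
    "final_tol": tol,       # note: this is the tol argument, not the achieved cur_tol
    "final_step": step
  }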

Linear regression

from sklearn.linear_model import LinearRegression
lm = LinearRegression().fit(X,y)
np.r_[lm.intercept_, lm.coef_]
array([ 3.0616, -0.0121, -0.0096,  0.096 ,  9.6955, 43.406 ,  0.0253,
        0.0284,  0.0962,  0.1069, 34.4884,  9.3445, -0.0165, -0.0147,
       -0.0396,  0.0969, -0.1057, -0.0943,  0.11  , -0.0096, -0.0875])
def jax_linear_regression(X, y, beta):
  Xb = jnp.c_[jnp.ones(X.shape[0]), X]
  return jnp.sum((y - Xb @ beta)**2)

grad_desc_jax(
  np.zeros(X.shape[1]+1), 
  lambda beta: jax_linear_regression(X,y,beta), 
  step = 1, tau = 0.5
)
{'x': Array([ 3.0617, -0.0121, -0.0097,  0.0961,  9.6958, 43.4061,  0.0255,
        0.0285,  0.0963,  0.1066, 34.488 ,  9.3445, -0.0166, -0.0146,
       -0.0396,  0.0966, -0.1056, -0.0941,  0.1101, -0.0099, -0.0879],      dtype=float32), 'n_iter': 18, 'converged': True, 'final_tol': 1e-08, 'final_step': 7.450580596923828e-09}

Ridge regression

from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
r = GridSearchCV(Ridge(), param_grid = {"alpha": np.logspace(-3,0)}).fit(X,y)
r.best_estimator_
Ridge(alpha=np.float64(0.05963623316594643))
np.r_[r.best_estimator_.intercept_, r.best_estimator_.coef_]
array([ 3.0631, -0.0122, -0.0102,  0.0957,  9.6928, 43.3936,  0.0265,
        0.0282,  0.0959,  0.1074, 34.478 ,  9.3399, -0.0179, -0.0134,
       -0.0403,  0.0956, -0.106 , -0.0933,  0.1101, -0.0108, -0.0877])
def jax_ridge(X, y, beta, alpha):
  Xb = jnp.c_[jnp.ones(X.shape[0]), X]
  ls_loss = jnp.sum((y - Xb @ beta)**2)
  coef_loss = alpha * jnp.sum(beta[1:]**2)
  return ls_loss + coef_loss

grad_desc_jax(
  np.zeros(X.shape[1]+1), 
  lambda beta: jax_ridge(X, y, beta, r.best_estimator_.alpha), 
  step = 1, tau = 0.5, tol=1e-8
)
{'x': Array([ 3.0633, -0.0122, -0.0104,  0.0957,  9.6931, 43.3938,  0.0267,
        0.0282,  0.0961,  0.1071, 34.4775,  9.3398, -0.0179, -0.0134,
       -0.0404,  0.0952, -0.1059, -0.093 ,  0.1102, -0.011 , -0.0882],      dtype=float32), 'n_iter': 18, 'converged': True, 'final_tol': 1e-08, 'final_step': 1.862645149230957e-09}
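Ridge regression also has a closed-form solution. With the intercept left unpenalized, as in `jax_ridge`, setting the gradient to zero gives \((X_b^T X_b + \alpha D)\beta = X_b^T y\), where \(X_b\) includes the column of ones and \(D\) is the identity with its first diagonal entry zeroed. A numpy sketch on synthetic data (our own example) that checks the gradient vanishes at this solution:

```python
import numpy as np

rng = np.random.default_rng(1)                # synthetic data, not the lecture's X, y
n, p, alpha = 100, 5, 0.5
X = rng.normal(size=(n, p))
Xb = np.column_stack([np.ones(n), X])
y = Xb @ rng.normal(size=p + 1) + rng.normal(size=n)

# penalty matrix: penalize every coefficient except the intercept
D = np.eye(p + 1)
D[0, 0] = 0.0

# minimizer of sum((y - Xb b)^2) + alpha * sum(b[1:]^2)
beta_ridge = np.linalg.solve(Xb.T @ Xb + alpha * D, Xb.T @ y)

# gradient of the ridge objective should vanish at the solution
g = -2 * Xb.T @ (y - Xb @ beta_ridge) + 2 * alpha * D @ beta_ridge
print(np.abs(g).max())
```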

Lasso

from sklearn.linear_model import Lasso
ls = GridSearchCV(Lasso(), param_grid = {"alpha": np.logspace(-3,0)}).fit(X,y)
ls.best_estimator_
Lasso(alpha=np.float64(0.10481131341546852))
np.r_[ls.best_estimator_.intercept_, ls.best_estimator_.coef_]
array([ 3.0555, -0.    , -0.    ,  0.    ,  9.5849, 43.312 ,  0.    ,
        0.    ,  0.    ,  0.    , 34.3712,  9.2233, -0.    ,  0.    ,
       -0.    ,  0.    , -0.    , -0.    ,  0.    , -0.    , -0.    ])
def jax_lasso(X, y, beta, alpha):
  n = X.shape[0]
  Xb = jnp.c_[jnp.ones(n), X]
  ls_loss = (1/(2*n))*jnp.sum((y - Xb @ beta)**2)
  coef_loss = alpha * jnp.sum(jnp.abs(beta[1:]))
  return ls_loss + coef_loss

grad_desc_jax(
  np.zeros(X.shape[1]+1), 
  lambda beta: jax_lasso(X, y, beta, ls.best_estimator_.alpha), 
  step = 1, tau = 0.5, tol=1e-10
)
{'x': Array([ 3.0623,  0.    , -0.    ,  0.    ,  9.5842, 43.303 ,  0.    ,
       -0.    ,  0.0031,  0.0098, 34.3752,  9.2158,  0.    ,  0.    ,
       -0.    ,  0.    , -0.0144, -0.0124,  0.0191, -0.    , -0.    ],      dtype=float32), 'n_iter': 44, 'converged': True, 'final_tol': 1e-10, 'final_step': 1.4551915228366852e-11}
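Notice that GD leaves several lasso coefficients small but nonzero (e.g. 0.0098, -0.0144), whereas sklearn's `Lasso` returns exact zeros. The \(\ell_1\) penalty is not differentiable at zero, so plain gradient steps tend to hover around zero rather than land on it. One common fix, and a different algorithm from the lecture's GD, is proximal gradient descent (ISTA): take a gradient step on the smooth least squares part only, then apply soft-thresholding, which sets small coefficients exactly to zero. A minimal numpy sketch on synthetic data, with the intercept unpenalized as in `jax_lasso`; `soft_threshold` and `lasso_ista` are our own names:

```python
import numpy as np

def soft_threshold(z, t):
    """Proximal operator of t*|.|: shrink toward zero, exactly zero if |z| <= t."""
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)

def lasso_ista(X, y, alpha, n_iter=5000):
    """Minimize (1/(2n)) ||y - Xb b||^2 + alpha ||b[1:]||_1 by proximal GD."""
    n = X.shape[0]
    Xb = np.column_stack([np.ones(n), X])
    step = n / np.linalg.eigvalsh(Xb.T @ Xb).max()     # 1 / Lipschitz constant
    beta = np.zeros(Xb.shape[1])
    for _ in range(n_iter):
        grad = -Xb.T @ (y - Xb @ beta) / n             # smooth part only
        beta = beta - step * grad
        beta[1:] = soft_threshold(beta[1:], step * alpha)  # intercept unpenalized
    return beta

rng = np.random.default_rng(42)                  # synthetic data, not the lecture's
X = rng.normal(size=(200, 10))
y = 3 + 5 * X[:, 0] - 4 * X[:, 1] + rng.normal(scale=0.5, size=200)
beta = lasso_ista(X, y, alpha=0.1)
print(np.round(beta, 4))
```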

Limitations of gradient descent

  • GD is deterministic

  • GD finds local minima, which need not be global minima

  • GD is sensitive to starting location

  • GD can be computationally expensive because gradients are often expensive to calculate

  • GD is sensitive to the choice of learning rate

  • GD treats all directions in parameter space uniformly

Newton’s Method

Newton’s Method in 1d

Let's assume we have a 1d function \(f(x)\) we are trying to optimize, our current guess is \(x\), and we want to know how to choose our next step \(\Delta x\).

We start by constructing the second-order Taylor approximation of \(f\) around \(x\), evaluated at \(x+\Delta x\),

\[ f(x + \Delta x) \approx \widehat{f}(x + \Delta x) = f(x) + \Delta x f'(x) + \frac{1}{2} (\Delta x)^2 f''(x) \]

Finding the Newton step

Our optimal step then becomes the value of \(\Delta x\) that minimizes the quadratic given by our Taylor approximation.

\[ \begin{aligned} \frac{\partial}{\partial \Delta x} \widehat{f}(x+\Delta x) &= 0 \\ \frac{\partial}{\partial \Delta x} \left(f(x) + \Delta x f'(x) + \frac{1}{2} (\Delta x)^2 f''(x) \right) &= 0 \\ f'(x) + \Delta x f''(x) &= 0 \\ \Delta x &= -\frac{f'(x)}{f''(x)} \end{aligned} \]

which suggests an iterative update rule of

\[ x_{k+1} = x_{k} -\frac{f'(x_k)}{f''(x_k)} \]
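A minimal sketch of this update in plain Python (the helper `newton_1d` is ours), applied to the quartic \(f(x) = x^4 + x^3 - x^2 - x\) from the gradient descent examples. From \(x_0 = -1.5\) we have \(f'(-1.5) = -4.75\) and \(f''(-1.5) = 16\), so the first step lands at \(x_1 = -1.5 + 4.75/16 = -1.203125\). The iterates then converge rapidly to the local minimum at \(x = -1\).

```python
def newton_1d(x, df, d2f, max_iter=50, tol=1e-12):
    """Iterate x <- x - f'(x) / f''(x); returns the list of iterates."""
    xs = [x]
    for _ in range(max_iter):
        x = x - df(x) / d2f(x)
        xs.append(x)
        if abs(xs[-1] - xs[-2]) < tol:   # stop when the step becomes tiny
            break
    return xs

df  = lambda x: 4*x**3 + 3*x**2 - 2*x - 1   # f'(x) for f(x) = x^4 + x^3 - x^2 - x
d2f = lambda x: 12*x**2 + 6*x - 2           # f''(x)
xs = newton_1d(-1.5, df, d2f)
print(xs[:4], len(xs))
```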

Generalizing to \(n\)d

Based on the same argument, we obtain the following result for a function on \(\mathbb{R}^n\),

\[ f(\boldsymbol{x} + \boldsymbol{\Delta x}) \approx \widehat{f}(\boldsymbol{x} + \boldsymbol{\Delta x}) = f(\boldsymbol{x}) + \boldsymbol{\Delta x}^T \nabla f(\boldsymbol{x}) + \frac{1}{2} \boldsymbol{\Delta x}^T \, \nabla^2 f(\boldsymbol{x}) \,\boldsymbol{\Delta x} \]

Setting the gradient with respect to \(\boldsymbol{\Delta x}\) to zero then gives

\[ \begin{aligned} \frac{\partial}{\partial \boldsymbol{\Delta x}} \widehat{f}(\boldsymbol{x} + \boldsymbol{\Delta x}) &= 0 \\ \nabla f(\boldsymbol{x}) + \nabla^2 f(\boldsymbol{x}) \, \boldsymbol{\Delta x} &= 0 \\ \boldsymbol{\Delta x} &= -\left(\nabla^2 f(\boldsymbol{x})\right)^{-1} \nabla f(\boldsymbol{x}) \end{aligned} \]

where

  • \(\nabla f(\boldsymbol{x})\) is the \(n \times 1\) gradient vector

  • and \(\nabla^2 f(\boldsymbol{x})\) is the \(n \times n\) Hessian matrix.

Based on these results, our \(n\)d update rule is

\[ \boldsymbol{x}_{k+1} = \boldsymbol{x}_{k} - (\nabla^2 f(\boldsymbol{x}_k))^{-1} \, \nabla f(\boldsymbol{x}_k) \]
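Because the second-order Taylor model is exact for a quadratic, a single Newton step from any starting point lands on the minimizer. A small numpy check (the matrix, vector, and starting point are arbitrary choices of ours). Note that \((\nabla^2 f)^{-1}\nabla f\) is computed by solving a linear system rather than forming the inverse, which is also what `newtons_method` does with `np.linalg.solve`.

```python
import numpy as np

# a quadratic f(x) = 1/2 x^T A x - b^T x with A symmetric positive definite
A = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([1.0, -1.0])
grad = lambda x: A @ x - b      # gradient
hess = lambda x: A              # Hessian is constant

x0 = np.array([10.0, -7.0])
# Newton step: solve H dx = -grad instead of inverting H explicitly
x1 = x0 + np.linalg.solve(hess(x0), -grad(x0))
print(x1, grad(x1))
```

Here \(\boldsymbol{x}_1 = \boldsymbol{A}^{-1}\boldsymbol{b} = (0.6, -0.8)\) regardless of \(\boldsymbol{x}_0\).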

Implementation

def newtons_method(x, f, grad, hess, max_iter=100, tol=1e-8):
    x = np.array(x)
    s = x.shape
    res = {"x": [x], "f": [f(x)]}
    
    for i in range(max_iter):
      # Newton step: solve H dx = grad(x) instead of explicitly inverting H
      x = x - np.linalg.solve(hess(x), grad(x))
      x = x.reshape(s)

      if np.sqrt(np.sum((x - res["x"][-1])**2)) < tol:
        break

      res["x"].append(x)
      res["f"].append(f(x))
    
    return res

A basic example

\[ \begin{aligned} f(x) &= x^2 \\ \nabla f(x) &= 2x \\ \nabla^2 f(x) &= 2 \end{aligned} \]

f = lambda x: np.array(x**2)
grad = lambda x: np.array([2*x])
hess = lambda x: np.array([[2]])
opt = newtons_method(-2., f, grad, hess)
plot_1d_traj( (-2, 2), f, opt )

opt = newtons_method(1., f, grad, hess)
plot_1d_traj( (-2, 2), f, opt )

1d Quartic

\[ \begin{aligned} f(x) &= x^4 + x^3 -x^2 - x \\ \nabla f(x) &= 4x^3 + 3x^2 - 2x - 1 \\ \nabla^2 f(x) &= 12x^2 + 6x - 2 \end{aligned} \]

f = lambda x: x**4 + x**3 - x**2 - x 
grad = lambda x: np.array([4*x**3 + 3*x**2 - 2*x - 1])
hess = lambda x: np.array([[12*x**2 + 6*x - 2]])
opt = newtons_method(-1.5, f, grad, hess)
plot_1d_traj( (-1.5, 1.5), f, opt )

opt = newtons_method(1.5, f, grad, hess)
plot_1d_traj( (-1.5, 1.5), f, opt)
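Newton's method only looks for points where \(f'(x) = 0\); nothing in the update prefers minima over maxima. Started where \(f''(x) < 0\), the local quadratic model opens downward and the iterates can converge to the local maximum at \(x = (1-\sqrt{17})/8 \approx -0.39\). A plain Python sketch (`newton_1d` is our own helper, and the starting point \(-0.3\) is our choice):

```python
import math

def newton_1d(x, df, d2f, max_iter=50, tol=1e-12):
    """Pure Newton iteration x <- x - f'(x)/f''(x), with no safeguards."""
    for _ in range(max_iter):
        x_new = x - df(x) / d2f(x)
        if abs(x_new - x) < tol:
            return x_new
        x = x_new
    return x

df  = lambda x: 4*x**3 + 3*x**2 - 2*x - 1
d2f = lambda x: 12*x**2 + 6*x - 2

x_star = newton_1d(-0.3, df, d2f)
print(x_star, d2f(x_star))   # f''(x_star) < 0, so this is a local maximum
```

This is one motivation for the damped variant, which only accepts steps that decrease \(f\).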

2d quadratic cost function

f, grad, hess = mk_quad(0.7)
opt = newtons_method((1.6, 1.1), f, grad, hess)
plot_2d_traj((-1,2), (-1,2), f, traj=opt)

f, grad, hess = mk_quad(0.05)
opt = newtons_method((1.6, 1.1), f, grad, hess)
plot_2d_traj((-1,2), (-1,2), f, traj=opt)

Rosenbrock function

f, grad, hess = mk_rosenbrock()
opt = newtons_method((1.6, 1.1), f, grad, hess)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

f, grad, hess = mk_rosenbrock()
opt = newtons_method((-0.5, 0), f, grad, hess)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

Damped / backtracking implementation

def newtons_method_damped(
  x, f, grad, hess, max_iter=100, max_back=10, tol=1e-8,
  alpha=0.5, beta=0.75
):
    res = {"x": [x], "f": [f(x)]}
    
    for i in range(max_iter):
      grad_f = grad(x)
      step = - np.linalg.solve(hess(x), grad_f) 
      t = 1
      for j in range(max_back):
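        # Armijo (sufficient decrease) condition: f must drop by at least
        # alpha * t * grad^T step, otherwise shrink t by beta
        if f(x+t*step) < f(x) + alpha * t * grad_f @ step: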
          break
        t = t * beta
      
      x = x + t * step

      if np.sqrt(np.sum((x - res["x"][-1])**2)) < tol:
        break

      res["x"].append(x)
      res["f"].append(f(x))
    
    return res

2d quadratic cost function

f, grad, hess = mk_quad(0.7)
opt = newtons_method_damped((1.6, 1.1), f, grad, hess)
plot_2d_traj((-1,2), (-1,2), f, traj=opt)

f, grad, hess = mk_quad(0.05)
opt = newtons_method_damped((1.6, 1.1), f, grad, hess)
plot_2d_traj((-1,2), (-1,2), f, traj=opt)

Rosenbrock function

f, grad, hess = mk_rosenbrock()
opt = newtons_method_damped((1.6, 1.1), f, grad, hess)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

f, grad, hess = mk_rosenbrock()
opt = newtons_method_damped((-0.5, 0), f, grad, hess)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

Conjugate gradients

Conjugate gradients

Conjugate gradients is a general approach for solving a system of linear equations of the form \(\boldsymbol{A} \boldsymbol{x}=\boldsymbol{b}\), where \(\boldsymbol{A}\) is an \(n \times n\) symmetric positive definite matrix, \(\boldsymbol{b}\) is \(n \times 1\), and \(\boldsymbol{x}\) is the unknown vector of interest.


This type of problem can also be expressed as a quadratic minimization problem of the form,

\[ \underset{x}{\text{arg min}} \; f(x) = \underset{x}{\text{arg min}} \; \frac{1}{2} \boldsymbol{x}^T \, \boldsymbol{A} \, \boldsymbol{x} - \boldsymbol{x}^T \boldsymbol{b} \]

since the minimizer is the point where the gradient vanishes,

\[ \nabla f(\boldsymbol{x}) = \boldsymbol{A}\boldsymbol{x} - \boldsymbol{b} = 0 \]

A? Conjugate?

Taking things one step further, we can also see that the matrix \(\boldsymbol{A}\) is the Hessian of \(f(x)\),

\[ \nabla^2 f(x) = \boldsymbol{A} \]


Additionally, recall that two non-zero vectors \(\boldsymbol{u}\) and \(\boldsymbol{v}\) are conjugate with respect to \(\boldsymbol{A}\) if

\[\boldsymbol{u}^{T} \boldsymbol{A} \boldsymbol{v} = 0\]

Our goal then is to find \(n\) conjugate vectors \(P = \{\boldsymbol{p}_1, \ldots, \boldsymbol{p}_n\}\) to use as “optimal” step directions to traverse our objective function.
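Why conjugate directions? Mutually conjugate vectors are linearly independent, so the solution can be written as \(\boldsymbol{x}^* = \sum_i \alpha_i \boldsymbol{p}_i\). Multiplying \(\boldsymbol{A}\boldsymbol{x}^* = \boldsymbol{b}\) on the left by \(\boldsymbol{p}_k^T\) removes every term except the \(k\)th, giving \(\alpha_k = \boldsymbol{p}_k^T\boldsymbol{b} / \boldsymbol{p}_k^T\boldsymbol{A}\boldsymbol{p}_k\), so each coefficient can be found independently, one direction at a time. A numpy sketch that builds a conjugate set by Gram-Schmidt with respect to the \(\boldsymbol{A}\) inner product, starting from the standard basis (a choice of ours; CG instead builds its directions from the residuals):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
M = rng.normal(size=(n, n))
A = M @ M.T + n * np.eye(n)          # a random symmetric positive definite matrix
b = rng.normal(size=n)

# Gram-Schmidt in the A inner product <u, v>_A = u^T A v
P = []
for e in np.eye(n):
    p = e - sum((e @ A @ q) / (q @ A @ q) * q for q in P)
    P.append(p)
P = np.array(P)                      # rows are the conjugate directions

# P A P^T is diagonal, i.e. p_i^T A p_j = 0 for i != j
G = P @ A @ P.T
print(np.round(G, 10))

# expand the solution in the conjugate basis, one coefficient per direction
x = sum((p @ b) / (p @ A @ p) * p for p in P)
print(np.allclose(x, np.linalg.solve(A, b)))
```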

Algorithm Sketch

For the \(k\)th step:

  • Define the residual \(\boldsymbol{r}_k = \boldsymbol{b} - \boldsymbol{A} \boldsymbol{x}_k\)

  • Define conjugate vector \(\boldsymbol{p}_k\) using current residual and all previous search directions \[ \boldsymbol{p}_k = \boldsymbol{r}_k - \sum_{i<k} \frac{\boldsymbol{r}_k^T \boldsymbol{A} \boldsymbol{p}_i}{\boldsymbol{p}_i^T \boldsymbol{A} \boldsymbol{p}_i} \boldsymbol{p}_i \]

  • Define step size \(\alpha_k\) using \[ \alpha_k = \frac{\boldsymbol{p}_k^T\boldsymbol{r}_k}{\boldsymbol{p}_k^T \boldsymbol{A} \boldsymbol{p}_k} \]

  • Update \(\boldsymbol{x}_{k+1} = \boldsymbol{x}_k + \alpha_k \boldsymbol{p}_k\)

Algorithm in practice

Given \(x_0\) we set the following initial values, \[ \begin{align} r_0 &= \nabla f(x_0) \\ p_0 &= -r_0 \\ k &= 0 \end{align} \]

Note that here \(r_k = \nabla f(x_k)\) is the negative of the residual \(\boldsymbol{b} - \boldsymbol{A}\boldsymbol{x}_k\) used in the sketch above, which is why \(\alpha_k\) carries a minus sign.

while \(\|r_k\|_2 > \text{tol}\), \[ \begin{align} \alpha_k &= -\frac{r_k^T \, p_k}{p_k^T \, \nabla^2 f(x_k) \, p_k} \\ x_{k+1} &= x_k + \alpha_k \, p_k \\ r_{k+1} &= \nabla f(x_{k+1}) \\ \beta_{k+1} &= \frac{ r^T_{k+1} \, \nabla^2 f(x_k) \, p_{k} }{p_k^T \, \nabla^2 f(x_k) \, p_k} \\ p_{k+1} &= -r_{k+1} + \beta_{k+1} \, p_k \\ k &= k+1 \end{align} \]

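def conjugate_gradient(x, f, grad, hess, max_iter=100, tol=1e-8):
    res = {"x": [x], "f": [f(x)]}
    r = grad(x)   # r_0 is the gradient at x_0
    p = -r        # first direction is steepest descent
    
    for i in range(max_iter):
      H = hess(x)
      # exact line search step for the local quadratic model
      a = - r.T @ p / (p.T @ H @ p)
      x = x + a * p
      r = grad(x)
      # choose beta so the new direction is H-conjugate to the previous one
      b = (r.T @ H @ p) / (p.T @ H @ p)
      p = -r + b * p

      res["x"].append(x) 
      res["f"].append(f(x))
      
      # stop once the gradient is (numerically) zero
      if np.sqrt(np.sum(r**2)) < tol:
        break
  
    return res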
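On a quadratic \(f(\boldsymbol{x}) = \frac{1}{2}\boldsymbol{x}^T\boldsymbol{A}\boldsymbol{x} - \boldsymbol{x}^T\boldsymbol{b}\) the Hessian is constant, and in exact arithmetic these recurrences reach the minimizer in at most \(n\) iterations. A self-contained numpy sketch using the same updates as `conjugate_gradient` (with \(\alpha_k = -r_k^T p_k / p_k^T \boldsymbol{A} p_k\)) on a random 5-dimensional SPD problem; the problem and the helper name `cg_quadratic` are ours:

```python
import numpy as np

def cg_quadratic(A, b, x, tol=1e-10):
    """CG on f(x) = 1/2 x^T A x - x^T b, using the gradient r = A x - b."""
    r = A @ x - b
    p = -r
    for k in range(len(b)):               # at most n steps in exact arithmetic
        a = -(r @ p) / (p @ A @ p)
        x = x + a * p
        r = A @ x - b
        if np.linalg.norm(r) < tol:
            return x, k + 1
        beta = (r @ A @ p) / (p @ A @ p)
        p = -r + beta * p
    return x, len(b)

rng = np.random.default_rng(1)
M = rng.normal(size=(5, 5))
A = M @ M.T + 5 * np.eye(5)               # symmetric positive definite
b = rng.normal(size=5)
x, n_iter = cg_quadratic(A, b, np.zeros(5))
print(n_iter, np.allclose(x, np.linalg.solve(A, b)))
```

For the 2d quadratics above this means at most two steps; on the Rosenbrock function the Hessian changes with \(x\), so this finite termination no longer applies.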

Trajectory

f, grad, hess = mk_quad(0.7)
opt = conjugate_gradient((1.6, 1.1), f, grad, hess)
plot_2d_traj((-1,2), (-1,2), f, title="$\\epsilon=0.7$", traj=opt)

f, grad, hess = mk_quad(0.05)
opt = conjugate_gradient((1.6, 1.1), f, grad, hess)
plot_2d_traj((-1,2), (-1,2), f, title="$\\epsilon=0.05$", traj=opt)

Rosenbrock’s function

f, grad, hess = mk_rosenbrock()
opt = conjugate_gradient((1.6, 1.1), f, grad, hess)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)

f, grad, hess = mk_rosenbrock()
opt = conjugate_gradient((-0.5, 0), f, grad, hess)
plot_2d_traj((-2,2), (-2,2), f, traj=opt)