Lecture 15
array([ -36.2252, 9.6357, 66.4583, 48.9574, 24.1885, -13.2444,
18.1455, -135.047 , 116.5772, 60.2524, 30.9319, 107.148 ,
21.6209, 66.2401, -132.8878, 58.636 , 22.186 , 60.3852,
-85.0383, 55.1704, -31.3817, -57.0697, 67.3215, 2.878 ,
-29.5613, -41.3973, -30.3048, -41.5597, 52.7531, -63.5633,
3.5671, 63.712 , 9.9833, 78.4881, -76.126 , 13.4331,
122.6162, 79.0354, 91.2171, 48.7344, 103.6366, 52.5964,
35.0064, -65.8423, -47.3045, -25.6876, 1.8359, 35.4113,
28.0687, 56.3528, 3.6755, -72.3309, 57.143 , -16.9438,
54.1445, 72.6828, -5.0538, -180.6135, -44.6205, 9.2071,
-5.5324, -29.6013, 135.3656, 114.241 , -97.4878, 15.0648,
14.7958, 71.503 , -4.6583, -36.791 , -5.3845, -119.8073,
11.174 , 36.3008, 82.5499, -20.0869, 14.7146, -59.0765,
39.4171, 48.4013, -61.9613, -5.6247, 103.2374, 41.2613,
-129.2273, 10.5113, 32.4936, 78.6921, 5.2956, 64.4473,
88.8358, 39.4851, -11.4866, -52.5082, 112.7248, -9.7006,
13.8393, -36.4004, 68.4865, 19.5335, -75.447 , -87.9538,
79.4784, -75.094 , 25.6229, 84.9034, 71.2779, -66.4093,
77.6444, 40.8875, 31.3165, -22.7143, 84.562 , 6.8075,
9.778 , -65.9149, 106.6952, -3.1901, 41.1555, 32.6265,
-36.5738, 38.9966, -78.664 , -56.0434, 2.9191, 42.6286,
51.3644, -21.8072, -21.9779, -15.7102, -23.5586, 1.3801,
20.4269, 55.7188, -45.6388, -55.1542, 74.6067, -7.2716,
-31.1045, 48.1571, 14.7487, 41.6956, -59.6062, -33.0811,
81.0177, -9.4896, 164.1317, 25.3507, 6.0141, 46.3718,
84.2983, -63.2593, -17.4733, -26.2977, -56.4681, 17.003 ,
53.1867, -94.5398, -18.2541, -49.343 , 40.8724, -90.5986,
27.9392, 41.7287, 49.8082, -9.6384, -66.7551, 122.9159,
-41.3566, -98.6863, -45.0718, 9.9327, -22.0927, 10.6199,
-12.2831, 7.4184, 57.6091, -27.3456, -36.4045, -51.659 ,
28.8175, -23.9402, -51.0637, 4.3618, 10.8402, -11.087 ,
-29.9801, 113.6633, 66.5601, 1.3808, -19.4875, 40.812 ,
43.0652, 35.4802, 77.0732, -49.7352, 65.7192, 73.8539,
-59.4116, 72.9501])
array([[-0.6465, 2.0803, 0.1412, -0.8419, -0.1595, 1.3321, -0.4262,
-0.0351, -0.1938, -0.6093, -0.3433, 0.6126, 0.3777, -1.2062,
-0.2277, -0.8896, -0.4674, -1.3566, 1.4989, -0.7468],
[-0.3834, -0.3631, -1.2196, 0.6 , 0.3315, 1.1056, 0.2662,
-0.7239, 0.0259, -0.2172, -0.6841, 0.0991, 0.2794, -1.208 ,
-0.7818, -1.7348, -1.3397, -0.5723, -0.5882, 0.2717],
[-0.1637, -0.8118, 0.9551, 0.5711, 0.8719, -0.9619, 1.9846,
-1.1806, -1.1261, 0.297 , 1.2499, 0.7109, -0.1183, 0.6708,
0.6895, 1.4705, 0.0634, -0.3079, -2.2512, -0.0216],
[-0.9292, -0.4897, -2.1196, -1.142 , 1.266 , -0.2988, 1.0016,
-2.1969, -1.0739, -0.1149, 0.5122, 0.302 , -0.0974, 1.3461,
0.1909, 1.1223, 0.6268, 2.2035, -0.5135, 2.0118],
[ 0.1645, -0.5847, 0.2708, -3.5635, 0.1526, 0.5283, 0.7674,
1.392 , -0.0819, 1.3211, 0.4644, -1.0279, 0.9849, -1.069 ,
-0.4301, 0.0798, -0.5119, -0.3448, 0.8166, -0.4 ],
[ 0.4134, 1.9511, -0.5013, -1.4894, 0.4191, -1.4104, 0.2617,
-0.6981, 0.0368, -1.151 , 2.0752, 0.5001, -0.2428, 0.45 ,
0.7176, 1.3846, 0.5155, 0.4459, -0.2784, -0.2864],
[-0.0628, -1.424 , -1.1023, 0.1445, -0.4836, 1.4795, -0.5921,
1.6423, -0.5013, 0.4435, 2.0044, 0.6221, 0.0747, -1.4117,
-0.202 , -1.3071, -0.8656, -1.311 , 0.0424, 0.7255],
[-0.6642, 1.4317, -0.0658, -0.7379, -0.9153, 0.8653, 0.7143,
1.0912, -1.3773, -2.6022, -0.2955, -0.3985, 0.0918, 0.3851,
0.502 , -0.4665, 1.6432, -0.2438, -0.4943, 1.4753],
[ 1.5247, -1.3419, -0.4453, -0.6141, 2.0632, -1.0742, -1.4419,
-1.4923, 0.3135, 0.7691, 0.5383, -0.9741, 0.8457, -0.0014,
0.3895, 0.2118, 1.0977, -0.4036, -1.5496, -0.3672],
[ 1.5524, -1.1109, -0.5624, -0.9106, -0.0506, 0.8533, -0.5452,
-1.7836, -0.8365, 2.171 , -0.6158, 0.2523, -1.8707, 0.6142,
0.7962, 0.0706, -0.2386, -0.4144, -0.0898, -0.4745],
[-0.7206, -2.0213, 0.0157, -1.191 , -0.3127, 0.2891, 0.8596,
-2.2427, 0.0021, 1.4327, 0.4714, 0.9533, -0.6365, 1.3212,
0.8872, 1.15 , -1.5469, 0.4055, -0.3341, 0.9919],
[ 0.2328, 0.5523, -0.2356, -1.2547, 0.6686, -2.1204, -0.186 ,
1.4915, -1.1353, 2.3889, 0.3449, -0.6703, -0.2358, -2.1923,
-0.4635, -0.9962, -0.1116, 0.0605, 0.0027, -1.439 ],
[ 1.8092, -1.5857, -0.9765, 1.6171, 0.368 , -0.2947, 1.5897,
-0.8878, 1.0547, -0.0427, -1.187 , 0.7605, 1.2381, -0.5014,
1.0201, -0.5773, -0.632 , -0.502 , -1.6914, 0.803 ],
[ 0.1935, 0.5289, -0.7559, -0.1047, -0.3334, -1.0275, 1.0327,
-0.8811, 0.0483, 1.8504, 1.5727, 0.3325, -1.7398, -0.2383,
-0.4967, 0.3939, 1.9322, 0.062 , -1.1205, -0.95 ],
[ 0.6914, -2.0308, 1.1554, -0.4219, -1.6257, -0.1138, -0.9225,
-1.9216, 1.2995, -1.5084, -0.863 , 0.2528, 1.3636, 0.2059,
0.0381, 1.1124, 1.73 , 0.4496, -0.1806, 0.7681],
[-0.0862, -0.2131, -0.5343, -0.1066, -0.8403, 1.3862, 0.5885,
-1.089 , -0.8571, 2.0178, 2.6078, -0.5807, -0.3466, -0.5166,
-0.7863, 0.2918, -0.1904, -0.8012, -1.6868, 0.2538],
[-1.0714, -0.4582, 0.4255, 0.5657, -0.1743, 2.0978, -0.8453,
-0.9807, -0.0414, 0.5851, 0.2645, -0.3602, 0.4151, 1.2829,
-0.0485, -0.4278, 0.2703, 0.821 , -1.338 , 1.4986],
[ 0.2126, 2.1145, -0.1471, 1.7549, 0.9465, -1.3906, -1.0954,
-0.5224, 0.5338, 0.0591, -0.2671, 1.5731, 0.3903, 0.6137,
-0.5277, 0.6306, 0.7467, 1.7232, 0.62 , 2.0249],
[ 0.7085, 1.312 , -0.6134, 0.8665, -1.4706, 0.2597, -0.1606,
-0.7118, 0.2154, -0.7415, -0.608 , -0.3412, 1.0772, 0.4695,
-0.1285, 0.0654, 0.4922, -0.6707, -1.8229, -0.4215],
[ 1.2418, -1.9068, -0.6066, 0.1639, 0.986 , -0.1853, -0.0303,
1.152 , -0.161 , 0.0226, 0.8991, 0.9874, -0.802 , -0.7241,
0.2466, 0.747 , -0.9682, -1.1908, 0.4313, -1.2039],
[-0.6263, 0.2757, 0.9388, 1.3835, -0.5935, 0.4409, -1.4681,
0.0114, -0.3643, -0.3373, -1.3341, 0.0036, 0.5513, -0.1016,
0.6814, -1.4258, -1.3869, -2.0679, -1.6482, 1.0062],
[-0.2397, 0.6481, 1.7758, 0.0166, -1.7724, -0.1862, 1.118 ,
-0.8409, 0.6136, 0.5269, -0.2908, -0.2294, 0.1747, -0.3881,
-0.2667, -0.7601, 0.4313, -0.7488, -0.7594, -0.4084],
[-1.0989, 1.1887, -0.5288, 1.6782, 0.3827, 0.4309, -1.3949,
0.6801, -1.2572, 0.6585, 0.7674, -1.5397, 1.1786, 1.2429,
-1.1094, 1.2524, -0.7556, 0.4051, -0.3198, 0.6704],
[ 0.0582, 0.1247, 0.1058, -0.4947, -0.1381, 1.3226, 0.3375,
0.0445, 1.2923, 0.67 , -1.3132, -0.7997, -0.1669, 1.5938,
-0.7805, -0.3689, -2.5977, -1.2921, -1.2897, -0.074 ],
[-0.9515, -1.0973, 1.5675, 0.0103, -1.1347, 0.165 , 0.0289,
-0.6242, -1.3193, 0.2246, 0.7557, -0.9032, 2.1041, -0.6316,
-0.1271, -0.4006, -0.8671, -0.5601, -0.0713, -1.1371],
[ 0.4599, 0.5513, 1.6362, -1.2392, -0.3352, 1.0237, 1.7626,
-0.5441, 1.3217, -1.2237, 2.5112, -1.7501, -0.0857, 0.8239,
-0.6406, -1.05 , -0.635 , -2.1445, 1.4129, 0.2546],
[ 0.1871, -2.2206, 1.2475, 1.2345, -1.5021, 1.1434, -1.0406,
0.0709, 1.2826, 0.5946, -0.176 , 0.0639, -1.4364, -0.3326,
-0.4648, 0.0733, -1.5075, 0.7799, -0.6549, 0.2562],
[ 0.084 , 0.9564, 0.366 , -0.6843, -0.6239, 0.3233, -0.4753,
-0.7024, -0.8606, -0.8089, 1.7968, -0.9079, -0.1103, -0.8212,
1.328 , -1.2039, -2.1219, -0.8672, -1.345 , -1.0769],
[-0.807 , -0.037 , 0.7597, 0.9556, 0.1334, -0.0225, 1.9088,
-0.423 , 0.267 , 0.7138, 0.996 , 0.0679, 0.1559, 0.1314,
-0.342 , 0.1817, 0.4344, 1.383 , -0.1708, 0.2745],
[-0.3675, -0.93 , 1.2117, 1.0203, -1.1554, -0.0461, 0.827 ,
1.793 , -1.0029, -0.7901, 0.0797, 0.992 , -0.5725, 1.3592,
1.2639, 1.3791, 0.021 , -2.5727, -0.2494, 2.0499],
...,
[ 1.6243, -1.2413, -0.4177, 0.2389, -0.2734, -0.6785, -1.0147,
-0.2772, -1.711 , -1.1543, 0.2933, 1.487 , 0.7526, 0.6561,
0.4132, 0.1095, 0.1406, -0.6598, 1.2687, 1.2148],
[ 0.6854, -0.7399, -1.0681, 0.2991, 0.0382, -0.9321, -0.8341,
0.0215, 0.0612, -0.129 , 0.8795, -0.1681, 0.8851, 1.2921,
0.3478, 1.5717, 2.4181, -0.0638, 1.3938, 0.884 ],
[ 2.2326, -1.7645, 1.9779, -1.6875, -0.8401, 0.1057, 1.1688,
0.3301, -0.5216, 1.207 , -1.5042, 1.6341, -1.0896, -0.7015,
-1.7587, 1.4814, 0.6081, -0.7485, 2.1342, -0.4016],
[ 2.2171, -0.6177, 0.1949, -1.0798, 0.586 , -0.859 , 2.5508,
-0.8039, 0.1503, -0.1069, -0.6496, -0.2479, 0.1649, 0.765 ,
0.8986, -0.3648, 0.6722, -0.2408, 0.7112, 0.1551],
[ 1.1797, 2.0229, 0.2965, -0.4986, 0.6617, 0.8841, -0.8252,
-0.3799, -1.1173, -1.3918, 0.9206, -0.076 , -0.8812, -1.7954,
0.2918, 0.7677, 0.4183, 1.236 , -0.1036, -0.0952],
[-1.4325, 1.5241, 0.4914, 0.4466, -1.5217, -0.5697, -0.1623,
-0.0357, -1.3161, 1.733 , 0.7872, 1.4468, 1.8372, 1.0749,
-2.0308, -0.2996, -1.1323, 0.6271, 0.7217, 0.8836],
[ 0.4932, -1.2439, 0.5748, -0.1409, 1.7359, -1.1483, -0.4902,
-0.5052, -0.4267, -0.6533, 0.242 , 0.7283, 0.2963, -1.2347,
0.6998, 0.7025, -0.5894, -2.7557, -1.1078, -0.5546],
[ 0.7622, -1.2367, -0.9891, 1.7989, 0.1187, -1.8608, -0.559 ,
0.775 , 0.0616, -1.6046, 0.2385, -1.5886, -0.1833, -0.7817,
1.8364, -0.5933, -0.3687, 0.3881, 1.2738, 1.2086],
[-0.7038, -1.2434, 0.5626, 0.3224, 0.3713, 0.7815, 1.957 ,
0.4423, 0.6326, -1.956 , 1.1085, 0.1665, 0.841 , 2.1472,
0.5566, -0.2651, -0.9084, 2.0134, 0.3486, 1.2223],
[-0.8597, -1.2391, 0.9525, -0.7438, -0.9162, 0.1223, 0.6288,
0.9881, -0.223 , -0.3202, 0.5368, -0.9382, 0.1865, -1.4094,
0.226 , -0.0726, 1.423 , 2.1237, 0.1397, -0.5506],
[-0.0956, 0.2007, -0.3942, 0.812 , 0.6777, -2.5834, -1.3294,
-0.6009, -1.0962, -0.359 , 0.0455, -0.5706, 0.0263, -0.9308,
-0.0649, 0.6586, -0.0469, -1.0283, 0.3524, -0.3676],
[-1.4168, -0.0501, -0.5261, -0.3774, -0.9222, 0.253 , -0.2044,
0.0524, 0.8973, 0.0268, 1.6289, 0.4335, -2.1098, -0.488 ,
0.8635, -1.7442, 0.6374, -1.3713, -0.6505, 0.0552],
[-0.1955, -0.6314, 1.0877, -0.9238, -1.2701, 0.3528, 0.9894,
0.4388, -0.2405, 0.3558, -0.2266, 0.5029, 1.3886, -1.8156,
-0.4634, -0.9616, -0.9101, 0.5856, -0.7043, 1.2456],
[ 0.1021, -0.9809, 0.3537, 1.6762, -0.7037, -0.2284, -0.278 ,
-0.4083, 0.666 , 0.681 , -0.9055, 1.054 , -0.0522, 0.3645,
1.1951, -1.8104, -1.5148, 1.0655, 0.3521, -0.9033],
[ 0.2733, -0.0657, -0.5118, -0.9471, -0.4646, 1.4039, 0.5143,
0.0839, -1.1759, 0.5828, 1.7003, -0.4077, 0.5941, -0.7209,
-0.2797, 0.6193, 1.001 , 1.6007, -0.6754, 0.2706],
[ 1.0588, -2.1785, 0.2624, 0.1752, -0.0675, 0.0093, -1.8534,
-1.7246, 1.5724, -0.5295, 0.4144, 0.6353, -0.7018, -1.1284,
-0.1179, 0.2766, -1.2591, 2.0028, 0.312 , 1.073 ],
[ 0.428 , -0.2036, -1.3212, 1.1555, -0.439 , -1.6271, -0.1925,
0.622 , -1.5984, -0.5565, -0.5401, 0.0358, 0.0133, -1.4583,
2.0069, 0.1263, 0.4161, -0.9847, 0.7966, 0.486 ],
[-0.1173, 0.6381, -0.2812, -0.3578, 1.9299, 2.1058, 2.3049,
1.0974, 0.5727, 0.9601, -0.211 , 0.6298, 0.7401, -1.7709,
-1.0439, 0.8751, -1.5677, 0.3412, -0.1276, -0.8195],
[ 0.2991, -0.7227, -0.4858, 0.2055, 0.8665, 0.5538, 0.3632,
0.3877, 0.9835, 0.313 , 1.5294, -0.3187, 1.8937, 0.3538,
1.0765, 0.0236, -0.2756, 0.0235, 0.1774, -0.6602],
[ 0.2811, 0.94 , 0.2862, 0.2612, -0.3128, 1.8086, -0.1915,
0.5764, 0.5697, 0.8199, -2.1527, -1.3712, -0.6986, 0.3284,
-0.4288, 0.2109, -0.4925, -0.3614, 0.4904, -2.0268],
[-0.1443, -1.0635, -1.0286, -0.5993, -0.0703, 1.1204, -0.0037,
-0.4246, -0.3002, -0.7602, 1.4965, 2.5779, 0.7647, 0.1203,
0.1931, 0.7629, 1.2026, -0.8532, 0.1837, 0.5151],
[ 0.4857, 0.4299, 0.1458, 0.7444, -0.0979, 1.0103, 0.6701,
-0.2955, 0.0965, 1.4683, -1.7762, 0.9793, -0.2762, -0.0483,
0.0385, -0.5835, 0.6415, 0.0514, 0.3048, 1.6422],
[ 0.4843, 0.7492, 1.2246, 0.5665, 0.2853, -1.8185, -0.7811,
-1.2811, -0.1822, 0.5036, 0.2912, -0.4508, -0.468 , 0.0471,
1.3635, 0.8755, 0.3948, 0.6807, -0.2039, -1.7107],
[-1.0209, -0.6634, -0.9163, 1.3832, 0.0976, 1.3228, 0.1802,
0.2148, 1.6929, 0.4813, -0.2126, 0.7982, -1.1308, -0.6505,
0.0008, 1.4194, -0.0373, 0.3867, 0.2306, -0.5345],
[-0.7579, 1.397 , 1.337 , -1.8439, 2.0066, 1.6138, -1.5925,
-0.2433, 0.1118, 0.11 , -0.0658, 0.3186, 0.2924, -0.2974,
1.016 , -0.231 , 1.639 , 0.4316, -0.8798, -0.3389],
[ 0.7059, 1.752 , -0.0646, -0.0381, -0.2057, 0.6298, -0.0253,
-1.205 , -0.1709, -0.6655, -2.0803, 0.4152, -0.1783, 0.8111,
-2.6128, -3.8809, 2.1338, 0.7489, 0.485 , 0.9745],
[-1.1806, -0.3993, -0.2226, -2.3027, 1.3672, -0.1536, -1.9393,
0.2231, -0.0869, 0.9348, -0.4875, -0.8112, 0.6287, 0.2276,
-0.2909, 0.0516, 1.3023, 0.5602, 0.515 , 0.4992],
[-0.6427, -0.3821, 0.1691, 0.1282, 0.4937, 0.4471, -0.2522,
0.8716, -1.5433, 1.6954, -0.9341, -0.5208, -1.0103, 0.9605,
-1.877 , 0.8052, -1.7248, -1.3744, -0.4322, 1.1537],
[ 0.8806, -0.7055, -0.0158, 0.172 , -0.4213, 0.8811, 0.7577,
-0.3878, -0.6382, -1.365 , 0.1341, 1.7316, -0.6366, -0.6532,
-1.4726, 0.8897, -1.32 , 0.7008, -1.2858, 1.1342],
[-0.5563, -1.0968, -0.9132, -1.2012, 1.8192, -1.1012, 0.805 ,
2.0475, 2.122 , 0.0836, -0.1508, -0.6059, 1.1388, -1.2726,
-1.0138, -0.8948, 0.8435, 1.1215, 0.2176, 1.0591]],
shape=(200, 20))
def grad_desc_lm(X, y, beta, step, max_step=50):
    """Fit a linear model by full-batch gradient descent on the SSE loss.

    Parameters
    ----------
    X : (n, k-1) design matrix *without* an intercept column (one is prepended).
    y : (n,) response vector.
    beta : (k,) initial coefficient vector (intercept first).
    step : learning rate / step size.
    max_step : number of gradient steps to take.

    Returns
    -------
    dict with keys "x" (list of coefficient iterates), "loss" (SSE at each
    iterate), and "iter" (iteration counter, 0..max_step).
    """
    # Prepend an intercept column of ones
    X = jnp.c_[jnp.ones(X.shape[0]), X]

    # Sum-of-squared-error loss and its exact (full-data) gradient
    f = lambda beta: jnp.sum((y - X @ beta)**2)
    grad = lambda beta: 2*X.T @ (X@beta - y)

    res = {"x": [beta], "loss": [f(beta).item()], "iter": [0]}
    for _ in range(max_step):
        beta = beta - grad(beta) * step
        res["x"].append(beta)
        res["loss"].append(f(beta).item())
        res["iter"].append(res["iter"][-1]+1)
    return res
Let's take a quick look at the linear regression loss function and gradient descent and think a bit about their costs. We can define the loss function and its gradient as follows:
\[ \begin{aligned} f(\underset{k\times 1}{\boldsymbol{\beta}}) &= (\underset{n \times 1}{y} - \underset{n\times k}{\boldsymbol{X}} \, \underset{k \times 1}{\boldsymbol{\beta}})^T (\underset{n \times 1}{y} - \underset{n\times k}{\boldsymbol{X}} \, \underset{k \times 1}{\boldsymbol{\beta}}) \\ \\ \nabla f(\underset{k\times 1}{\boldsymbol{\beta}}) &= \underset{k\times n}{2 \boldsymbol{X}^T}(\underset{n \times k}{\boldsymbol{X}}\,\underset{k\times 1}{\boldsymbol{\beta}} - \underset{n \times 1}{\boldsymbol{y}})\\ %&= %\left[ % \begin{matrix} % 2 \boldsymbol{X}_{\cdot 1}^T(\boldsymbol{X}_{\cdot 1}\boldsymbol{\beta}_1 - \boldsymbol{y}) \\ % 2 \boldsymbol{X}_{\cdot 2}^T(\boldsymbol{X}_{\cdot 2}\boldsymbol{\beta}_2 - \boldsymbol{y}) \\ % \vdots \\ % 2 \boldsymbol{X}_{\cdot k}^T(\boldsymbol{X}_{\cdot k}\boldsymbol{\beta}_k - \boldsymbol{y}) % \end{matrix} %\right] \end{aligned} \]
What are the costs of calculating the loss function and gradient respectively in terms of \(n\) and \(k\)?
Calculating the loss costs \({O}(nk)\)
Calculating the gradient also costs \({O}(nk)\) — computing \(X\beta\) is \(O(nk)\) and multiplying the residual by \(X^T\) is another \(O(nk)\)
This is a variant of gradient descent where rather than using all \(n\) data points we randomly sample one at a time and use that single point to make our gradient step.
Sampling of observations can be done with or without replacement
Will take more steps to converge but each step is now cheaper to compute
SGD has slower asymptotic convergence than GD, but is often faster in practice in terms of runtime
Generally requires the learning rate to shrink as a function of iteration to guarantee convergence
def sto_grad_desc_lm(X, y, beta, step, max_step=50, seed=1234, replace=True):
    """Fit a linear model by stochastic gradient descent (one point per update).

    Each outer "step" is a full epoch of n single-observation updates. The
    visit order is n draws with replacement (replace=True) or a random
    permutation of the data (replace=False). Loss is recorded once per epoch.
    """
    # Prepend the intercept column
    X = jnp.c_[jnp.ones(X.shape[0]), X]
    n, k = X.shape

    # Full-data SSE (tracking only) and the single-observation gradient
    f = lambda beta: jnp.sum((y - X @ beta)**2)
    grad = lambda beta, i: 2*X[i,:] * (X[i,:]@beta - y[i])

    rng = np.random.default_rng(seed)
    res = {"x": [beta], "loss": [f(beta).item()], "iter": [0]}

    for _ in range(max_step):
        if replace:
            order = rng.integers(0, n, n)
        else:
            order = np.arange(n)
            rng.shuffle(order)

        for obs in order:
            beta = beta - step * grad(beta, obs)

        res["x"].append(beta)
        res["loss"].append(f(beta).item())
        res["iter"].append(res["iter"][-1] + 1)

    return res
array([ 3.0616, -0.0121, -0.0096, 0.096 , 9.6955, 43.406 , 0.0253,
0.0284, 0.0962, 0.1069, 34.4884, 9.3445, -0.0165, -0.0147,
-0.0396, 0.0969, -0.1057, -0.0943, 0.11 , -0.0096, -0.0875])
# SGD sampling observations WITH replacement: 20 epochs from a zero start
# (X.shape[1]+1 coefficients: intercept plus one slope per column of X)
sgd_lm_rep = sto_grad_desc_lm(
    X, y, np.zeros(X.shape[1]+1),
    step = 0.001, max_step=20, replace=True
)
sgd_lm_rep["x"][-1]  # final coefficient estimates
Array([ 3.1032, -0.0554, 0.0073, 0.0165, 9.7246, 43.4316, 0.075 ,
0.0785, 0.1037, 0.0195, 34.4093, 9.3187, -0.0244, 0.0348,
0.0063, 0.0499, -0.0332, -0.0741, 0.1323, -0.091 , -0.1617], dtype=float64)
# Same SGD run but sampling WITHOUT replacement (shuffled pass per epoch)
sgd_lm_worep = sto_grad_desc_lm(
    X, y, np.zeros(X.shape[1]+1),
    step = 0.001, max_step=20, replace=False
)
sgd_lm_worep["x"][-1]  # final coefficient estimates
Array([ 3.0677, -0.0049, -0.0261, 0.0937, 9.7143, 43.3926, 0.0592,
0.038 , 0.1038, 0.0831, 34.4311, 9.3172, -0.0477, 0.0014,
-0.0651, 0.06 , -0.1149, -0.0543, 0.0993, -0.0569, -0.1273], dtype=float64)
Generally, rather than thinking in iterations we use epochs instead - an epoch is one complete pass through the data.
Array([ 3.01 , 0.0118, 0.0029, 0.0033, -0.0014, 0.0028, 0.0252,
-0.0005, 0.0009, 12.2793, 44.4961, 3.6409, 0.0165, 61.3964,
0.0011, 0.0005, 0.011 , -0.0134, -0.0045, 0.0028, 0.0244], dtype=float64)
# SGD with replacement, only 3 epochs this time (for the second example data)
sgd_lm_rep = sto_grad_desc_lm(
    X, y, np.zeros(X.shape[1]+1),
    step = 0.001, max_step=3, replace=True
)
sgd_lm_rep["x"][-1]  # final coefficient estimates
Array([ 2.9968, 0.0452, 0.0346, 0.0205, -0.0286, -0.0518, -0.0534,
-0.0142, 0.0371, 12.2115, 44.5628, 3.6736, 0.0249, 61.4223,
-0.0094, -0.0877, 0.0134, -0.0324, -0.0178, 0.0222, 0.0089], dtype=float64)
# SGD without replacement, 3 epochs (compare against the with-replacement run)
sgd_lm_worep = sto_grad_desc_lm(
    X, y, np.zeros(X.shape[1]+1),
    step = 0.001, max_step=3, replace=False
)
sgd_lm_worep["x"][-1]  # final coefficient estimates
Array([ 2.9756, 0.0149, -0.0011, -0.0047, 0.011 , -0.0176, 0.0251,
-0.0244, -0.0328, 12.2937, 44.5281, 3.6333, -0.0158, 61.4162,
0.0541, 0.0064, -0.0231, -0.0014, -0.0144, -0.0299, 0.0141], dtype=float64)
This is a further variant of stochastic gradient descent where a mini batch of \(m\) data points is selected for each gradient update,
The idea is to find a balance between the cost of increasing the data size vs the speed-up of vectorized calculations.
More updates per epoch than GD, but less than SGD
Mini batch composition can be constructed by sampling data points with or without replacement
def mb_grad_desc_lm(X, y, beta, step, batch_size = 10, max_step=50, seed=1234, replace=True):
    """Fit a linear model by mini batch stochastic gradient descent.

    Each epoch draws n observation indices (with replacement, or a shuffled
    permutation without) and performs one gradient update per batch of
    `batch_size` indices. Loss is recorded once per epoch.

    Bug fix: batches are formed by slicing consecutive chunks rather than
    `js.reshape(-1, batch_size)`, which raised a ValueError whenever n was
    not an exact multiple of batch_size. When it does divide evenly the
    batches are identical to the original's.
    """
    X = jnp.c_[jnp.ones(X.shape[0]), X]
    f = lambda beta: jnp.sum((y - X @ beta)**2)
    # Gradient over the batch of rows indexed by i
    grad = lambda beta, i: 2*X[i,:].T @ (X[i,:]@beta - y[i])
    n, k = X.shape
    res = {"x": [beta], "loss": [f(beta).item()], "iter": [0]}
    rng = np.random.default_rng(seed)
    for i in range(max_step):
        if replace:
            js = rng.integers(0, n, n)
        else:
            js = np.arange(n)
            rng.shuffle(js)
        # Last batch may be short if batch_size does not divide n
        for start in range(0, n, batch_size):
            j = js[start:start + batch_size]
            beta = beta - grad(beta, j) * step
        res["x"].append(beta)
        res["loss"].append(f(beta).item())
        res["iter"].append(res["iter"][-1]+1)
    return res
array([ 3.0081, 0.0088, 0.0002, 0.0021, 0.0037, 0.0033, 0.026 ,
-0.0006, 0.0005, 12.2771, 44.4939, 3.6423, 0.0168, 61.3938,
-0.0012, -0.0056, 0.014 , -0.0093, -0.0056, 0.0024, 0.0217])
Batch size: 10
[ 2.9754 0.0154 0.0004 -0.0038 0.0118 -0.0171 0.0248 -0.0242 -0.0336
12.2937 44.5285 3.6334 -0.0156 61.417 0.0546 0.0075 -0.023 -0.0004
-0.0135 -0.031 0.0138]
Batch size: 50
[ 2.9761 0.0107 -0.001 -0.0029 0.0119 -0.0161 0.0238 -0.0243 -0.0374
12.2943 44.5304 3.6337 -0.0172 61.4199 0.0557 0.0068 -0.0246 -0.0015
-0.0134 -0.0326 0.0127]
Batch size: 100
[ 2.973 0.0094 -0.0001 -0.0038 0.0123 -0.0171 0.0223 -0.0263 -0.0416
12.2923 44.5302 3.635 -0.0155 61.4176 0.0565 0.009 -0.0258 -0.002
-0.014 -0.0353 0.0045]
We’ve talked a bit about the computational side of things, but why do these approaches work at all?
In statistics and machine learning many of our problems have a form that looks like,
\[ \underset{\theta}{\text{arg min}} \; \ell(\boldsymbol{X}, \theta) = \underset{\theta}{\text{arg min}} \; \frac{1}{n} \sum_{i=1}^n \ell(\boldsymbol{X}_i, \theta) \]
which means that the gradient of the loss function is given by
\[ \nabla \ell(\boldsymbol{X}, \theta) = \frac{1}{n} \sum_{i=1}^n \nabla \ell(\boldsymbol{X}_i, \theta) \]
For SGD and mini batch GD we approximate this gradient using a randomly sampled subset (batch) \(B\) of the data, \[ \nabla \ell(\boldsymbol{X}, \theta) \approx \frac{1}{|B|} \sum_{i \in B} \nabla \ell(\boldsymbol{X}_i, \theta) \]
Because we are sampling \(B\) randomly, our SGD and mini batch GD approximations are unbiased estimates of the full gradient,
\[ E\left[ \frac{1}{|B|} \sum_{i \in B} \nabla \ell(\boldsymbol{X}_i, \theta) \right] = \frac{1}{n} \sum_{i=1}^n \nabla \ell(\boldsymbol{X}_i, \theta) = \nabla \ell(\boldsymbol{X}, \theta) \]
Each update can be viewed as a noisy gradient descent step (gradient + zero mean noise).
As mentioned previously we need to be a bit careful with learning rates and convergence for both of these methods. So far, our approach has been naive and runs for a fixed number of epochs.
If we want to use a convergence criterion we need to keep the following in mind:
Let \(\theta^*\) be a global / local minimizer of our loss function \(\ell(\boldsymbol{X},\theta)\), then by definition \(\nabla \ell(\boldsymbol{X},\theta^*) = 0\)
The issue is that our gradient approximation, \[ \frac{1}{|B|} \sum_{i \in B} \nabla \ell(\boldsymbol{X}_i, \theta^*) \ne 0 \] in general, as \(B\) is only a subset of the data — therefore our algorithm will keep taking steps and never converge.
The practical solution to this is to implement a learning rate schedule, which generally shrinks the learning rate / step size over time to ensure convergence.
The choice of the exact learning schedule is problem specific, and is usually about finding the balance of how quickly to shrink the step size.
Some common examples:
Piecewise constant - \(\eta_t = \eta_i \text{ if } t_i \leq t \leq t_{i+1}\)
Exponential decay - \(\eta_t = \eta_0 e^{-\lambda t}\)
Polynomial decay - \(\eta_t = \eta_0 (\beta t+1)^{-\alpha}\)
There are many more approaches including more exotic techniques that allow the learning rate to increase and decrease to help the optimizer better explore the objective function and in some cases escape local optima.
This approach was proposed by Duchi, Hazan, & Singer in 2011 and is based on the idea of scaling the learning rate for the current step by the sum of the squared gradients of previous steps - this has the effect of shrinking the step size for dimensions with large previous gradients.
\[ \begin{aligned} \boldsymbol{\theta}_{t+1} &= \boldsymbol{\theta}_t - \eta_t \frac{1}{\sqrt{\boldsymbol{s}_t + \epsilon}} \odot \nabla \ell(\boldsymbol{X}, \boldsymbol{\theta}_t)\\ \boldsymbol{s}_t &= \sum_{i=1}^t \left(\nabla \ell(\boldsymbol{X}, \boldsymbol{\theta}_i)\right)^2 \end{aligned} \]
here \(\epsilon\) is a small constant (i.e. \(10^{-7}\)) to avoid division by zero.
def adagrad_lm(X, y, beta, step, batch_size = 10, max_step=50, seed=1234, replace=True, eps=1e-8):
    """Fit a linear model with mini batch SGD using AdaGrad step scaling.

    Per-coordinate step sizes are scaled by 1/sqrt(S + eps), where S is the
    running sum of squared gradients — coordinates that have seen large
    gradients take smaller steps. `eps` avoids division by zero.

    Bug fix: batches are sliced as consecutive chunks instead of
    `js.reshape(-1, batch_size)`, which raised a ValueError when n was not
    a multiple of batch_size (identical batches when it does divide).
    """
    X = jnp.c_[jnp.ones(X.shape[0]), X]
    f = lambda beta: jnp.sum((y - X @ beta)**2)
    grad = lambda beta, i: 2*X[i,:].T @ (X[i,:]@beta - y[i])
    n, k = X.shape
    res = {"x": [beta], "loss": [f(beta).item()], "iter": [0]}
    rng = np.random.default_rng(seed)
    S = np.zeros(k)  # running sum of squared gradients (per coordinate)
    for i in range(max_step):
        if replace:
            js = rng.integers(0, n, n)
        else:
            js = np.arange(n)
            rng.shuffle(js)
        # Last batch may be short if batch_size does not divide n
        for start in range(0, n, batch_size):
            j = js[start:start + batch_size]
            G = grad(beta, j)
            S += G**2
            beta = beta - step * (1/np.sqrt(S + eps)) * G
        res["x"].append(beta)
        res["loss"].append(f(beta).item())
        res["iter"].append(res["iter"][-1]+1)
    return res
AdaGrad - SGD
[ 3.1134 0.0657 0.1134 -0.0746 -0.0023 -0.0545 -0.0544 30.8796 -0.1221
0.0449 -0.1069 0.0407 40.5307 0.0046 44.6237 -0.0401 72.2858 0.0184
-0.0153 0.0663 0.0114]
AdaGrad - MBGD (25)
[ 3.0676 0.0346 0.0862 -0.0847 0.0706 -0.0527 0.0161 30.9695 -0.0555
0.0156 -0.0922 -0.0061 40.5614 -0.0126 44.6231 -0.0221 72.2704 -0.0215
-0.0182 0.0601 0.0787]
AdaGrad - MBGD (50)
[ 3.0519 0.0268 0.0942 -0.0765 0.0593 -0.0484 0.0193 30.9786 -0.0188
0.0246 -0.0656 -0.0208 40.5611 0.0059 44.6133 -0.0297 72.2234 -0.0276
-0.0198 0.0425 0.0926]
AdaGrad - GD
[ 2.6765 1.1926 -0.7393 1.2202 -0.0051 1.5508 0.3784 27.3744 -1.4125
3.0303 -0.2343 2.4958 30.5815 2.4433 32.5644 1.4608 35.3517 -0.7034
-0.5997 -0.6608 -1.1752]
With AdaGrad the denominator involving \(\boldsymbol{s}_t\) gets larger as \(t\) increases, but in some cases it gets too large too fast to effectively explore the loss function. An alternative is to use a moving average of the past squared gradients instead.
RMSProp replaces AdaGrad’s \(\boldsymbol{s}_t\) with the following,
\[ \boldsymbol{s}_t = \beta \, \boldsymbol{s}_{t-1} + (1-\beta) \, (\nabla \ell(\boldsymbol{X},\boldsymbol{\theta}_t))^2 \\ \boldsymbol{s}_0 = \boldsymbol{0} \]
in practice a value of \(\beta \approx 0.9\) is often used.
def rmsprop_lm(X, y, beta, step, batch_size = 10, max_step=50, seed=1234, replace=True, eps=1e-8, b=0.9):
    """Fit a linear model with mini batch SGD using RMSProp step scaling.

    Like AdaGrad, but S is an exponential moving average of squared gradients
    (decay `b`, typically ~0.9) rather than a running sum, so the effective
    step size does not shrink monotonically. `eps` avoids division by zero.

    Bug fix: batches are sliced as consecutive chunks instead of
    `js.reshape(-1, batch_size)`, which raised a ValueError when n was not
    a multiple of batch_size (identical batches when it does divide).
    """
    X = jnp.c_[jnp.ones(X.shape[0]), X]
    f = lambda beta: jnp.sum((y - X @ beta)**2)
    grad = lambda beta, i: 2*X[i,:].T @ (X[i,:]@beta - y[i])
    n, k = X.shape
    res = {"x": [beta], "loss": [f(beta).item()], "iter": [0]}
    rng = np.random.default_rng(seed)
    S = np.zeros(k)  # EMA of squared gradients (per coordinate)
    for i in range(max_step):
        if replace:
            js = rng.integers(0, n, n)
        else:
            js = np.arange(n)
            rng.shuffle(js)
        # Last batch may be short if batch_size does not divide n
        for start in range(0, n, batch_size):
            j = js[start:start + batch_size]
            G = grad(beta, j)
            S = b*S + (1-b) * G**2
            beta = beta - step * (1/np.sqrt(S + eps)) * G
        res["x"].append(beta)
        res["loss"].append(f(beta).item())
        res["iter"].append(res["iter"][-1]+1)
    return res
RMSProp - SGD
[ 2.9718 -0.0486 0.1521 0.0179 0.1236 -0.0449 -0.0614 30.8114 0.0474
0.074 -0.0083 -0.0569 40.4583 -0.0975 44.679 -0.1505 72.2751 -0.0066
-0.0618 -0.1117 -0.0201]
RMSProp - MBGD (25)
[ 2.9707 -0.2066 0.2688 0.0507 0.1513 0.0343 -0.0799 30.7114 0.0451
0.1576 -0.0103 -0.0894 40.338 -0.1608 44.6958 -0.2083 72.2931 -0.143
-0.0646 -0.0814 -0.0219]
RMSProp - MBGD (50)
[ 3.0954 -0.181 0.4018 0.2723 0.2013 -0.0766 -0.0231 30.7481 0.1931
0.4516 -0.0279 -0.1569 40.3096 -0.2517 44.6712 -0.349 72.3357 -0.1481
-0.0044 -0.1344 0.0849]
RMSProp - GD
[ 2.2162 -0.6382 -2.1805 2.4839 0.9998 0.9919 0.0923 25.8598 -1.7242
1.7694 0.0141 3.112 28.4576 -1.6735 29.3372 0.484 31.0173 0.0241
1.0046 -1.2184 -2.8247]
Rather than just using the gradient information at our current location, it may be beneficial to use information from our previous steps as well. A general setup for this type of approach looks like,
\[ \boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta \, \boldsymbol{m}_t \\ \boldsymbol{m}_t = \beta \, \boldsymbol{m}_{t-1} + (1-\beta) \, \nabla \ell(\boldsymbol{X}, \boldsymbol{\theta}_t) \]
where \(\eta\) is our step size and \(\beta\) determines the weighting of the current gradient and the previous gradients.
If you have taken a course on time series, this has a flavor that looks a lot like moving average models,
\[ \boldsymbol{m}_t = (1-\beta) \, \nabla \ell(\boldsymbol{X}, \boldsymbol{\theta}_t) + \beta(1-\beta) \, \, \nabla \ell(\boldsymbol{X}, \boldsymbol{\theta}_{t-1}) + \beta^2(1-\beta) \, \, \nabla \ell(\boldsymbol{X}, \boldsymbol{\theta}_{t-2}) + \cdots \]
The “adaptive moment estimation” algorithm is a combination of momentum with RMSProp,
\[ \begin{aligned} \theta_{t+1} &= \theta_t - \eta_t \frac{\boldsymbol{m_t}}{\sqrt{\boldsymbol{s}_t + \epsilon}} \\ \boldsymbol{m}_t &= \beta_1 \, \boldsymbol{m}_{t-1} + (1-\beta_1) \, \nabla \ell(\boldsymbol{X}, \theta_t) \\ \boldsymbol{s}_t &= \beta_2 \, \boldsymbol{s}_{t-1} + (1-\beta_2) \, (\nabla \ell(\boldsymbol{X},\boldsymbol{\theta}_t))^2 \\ \end{aligned} \]
Note that RMSProp is a special case of Adam when \(\beta_1 = 0\).
Adam is widely used in practice and is commonly available within tools like Torch for fitting NN models.
In typical use \(\beta_1=0.9\), \(\beta_2=0.999\), \(\epsilon=10^{-6}\), and \(\eta_t=0.001\) are used. As the learning rate is not guaranteed to decrease over time, the algorithm is not guaranteed to converge.
One small alteration that was suggested by the original authors and is commonly used is to correct for the bias towards small values in the initial estimates of \(\boldsymbol{m}_t\) and \(\boldsymbol{s}_t\). In this case they are replaced with,
\[ \begin{aligned} {\hat{\boldsymbol{m}}}_t &= \boldsymbol{m}_t / (1-{\beta_1}^t) \\ {\hat{\boldsymbol{s}}}_t &=\boldsymbol{s}_t / (1-{\beta_2}^t) \\ \end{aligned} \]
def adam_lm(X, y, beta, step=0.001, batch_size = 10, max_step=50, seed=1234, replace=True, eps=1e-6, b1=0.9, b2=0.999):
    """Fit a linear model with mini batch SGD using the Adam optimizer.

    Combines momentum (EMA of gradients, decay `b1`) with RMSProp-style
    scaling (EMA of squared gradients, decay `b2`), including the standard
    bias correction of both moment estimates by 1/(1 - b**t).

    Note: unlike the other *_lm functions, res["iter"] records the total
    number of parameter updates (batches), not the epoch count.

    Bug fix: batches are sliced as consecutive chunks instead of
    `js.reshape(-1, batch_size)`, which raised a ValueError when n was not
    a multiple of batch_size (identical batches when it does divide).
    """
    X = jnp.c_[jnp.ones(X.shape[0]), X]
    f = lambda beta: jnp.sum((y - X @ beta)**2)
    grad = lambda beta, i: 2*X[i,:].T @ (X[i,:]@beta - y[i])
    n, k = X.shape
    res = {"x": [beta], "loss": [f(beta).item()], "iter": [0]}
    rng = np.random.default_rng(seed)
    S = np.zeros(k)  # EMA of squared gradients (second moment)
    M = np.zeros(k)  # EMA of gradients (first moment / momentum)
    t = 0            # total update count, used for bias correction
    for i in range(max_step):
        if replace:
            js = rng.integers(0, n, n)
        else:
            js = np.arange(n)
            rng.shuffle(js)
        # Last batch may be short if batch_size does not divide n
        for start in range(0, n, batch_size):
            j = js[start:start + batch_size]
            t += 1
            G = grad(beta, j)
            S = b2*S + (1-b2) * G**2
            M = b1*M + (1-b1) * G
            M_hat = M / (1-b1**t)  # bias-corrected first moment
            S_hat = S / (1-b2**t)  # bias-corrected second moment
            beta = beta - step * (M_hat / np.sqrt(S_hat + eps))
        res["x"].append(beta)
        res["loss"].append(f(beta).item())
        res["iter"].append(t)
    return res
Adam - SGD
[ 2.9752 -0.0641 0.1981 0.0515 0.0914 -0.053 -0.0403 30.8483 0.0677
0.1179 0.0005 -0.0296 40.4603 -0.0836 44.6719 -0.1786 72.2796 -0.0657
-0.0142 -0.1159 -0.0386]
Adam - MBGD (25)
[ 2.9886 0.0628 0.0822 -0.0225 0.0747 -0.1306 -0.0269 30.9399 0.1265
0.0126 0.0624 0.0686 40.5911 -0.0319 44.6542 -0.0513 72.2608 0.0044
0.0609 -0.0892 0.0563]
Adam - MBGD (50)
[ 2.9901 0.0812 0.0418 -0.0795 0.039 -0.0832 0.0236 30.9293 0.1255
-0.0052 0.0245 0.014 40.6152 -0.0599 44.6557 -0.0594 72.2556 -0.0312
0.0611 -0.0655 0.1096]
Adam - GD
[ 2.7747 -1.783 -3.6167 4.6093 2.4556 0.8484 0.0964 22.7301 -0.3594
1.9632 0.5982 2.0708 23.4988 -0.3894 23.818 0.4661 24.1671 -0.9899
-0.1128 -1.777 -4.0147]
Sta 663 - Spring 2025