Lecture 06
We will not spend much time on this as most data you will encounter is more likely to be in a tabular format (e.g. a data frame), and tools like Pandas are more appropriate.
For basic saving and loading of NumPy arrays there are the save() and load() functions, which use a custom binary format.
Additional saving functions (savez(), savez_compressed(), savetxt()) exist for saving multiple arrays or saving a text representation of an array.
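A minimal sketch of these functions in use (the file names here are just illustrative):

import numpy as np

x = np.arange(10)
y = np.random.rand(3, 3)

# save() / load() use the binary .npy format (one array per file)
np.save("x.npy", x)
x2 = np.load("x.npy")

# savez() / savez_compressed() bundle multiple arrays into a .npz archive
np.savez("arrays.npz", x=x, y=y)
arrays = np.load("arrays.npz")
arrays["y"]

# savetxt() writes a plain-text representation (1d or 2d arrays only)
np.savetxt("y.csv", y, delimiter=",")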
While not particularly recommended, if you need to read delimited (csv, tsv, etc.) data into a NumPy array you can use genfromtxt():
import numpy as np

with open("data/mtcars.csv") as file:
    mtcars = np.genfromtxt(file, delimiter=",", skip_header=1)

mtcars
array([[ 6. , 160. , 110. , 3.9 , 2.62 ,
16.46 , 0. , 1. , 4. , 4. ],
[ 6. , 160. , 110. , 3.9 , 2.875,
17.02 , 0. , 1. , 4. , 4. ],
[ 4. , 108. , 93. , 3.85 , 2.32 ,
18.61 , 1. , 1. , 4. , 1. ],
[ 6. , 258. , 110. , 3.08 , 3.215,
19.44 , 1. , 0. , 3. , 1. ],
[ 8. , 360. , 175. , 3.15 , 3.44 ,
17.02 , 0. , 0. , 3. , 2. ],
[ 6. , 225. , 105. , 2.76 , 3.46 ,
20.22 , 1. , 0. , 3. , 1. ],
[ 8. , 360. , 245. , 3.21 , 3.57 ,
15.84 , 0. , 0. , 3. , 4. ],
[ 4. , 146.7 , 62. , 3.69 , 3.19 ,
20. , 1. , 0. , 4. , 2. ],
[ 4. , 140.8 , 95. , 3.92 , 3.15 ,
22.9 , 1. , 0. , 4. , 2. ],
[ 6. , 167.6 , 123. , 3.92 , 3.44 ,
18.3 , 1. , 0. , 4. , 4. ],
[ 6. , 167.6 , 123. , 3.92 , 3.44 ,
18.9 , 1. , 0. , 4. , 4. ],
[ 8. , 275.8 , 180. , 3.07 , 4.07 ,
17.4 , 0. , 0. , 3. , 3. ],
[ 8. , 275.8 , 180. , 3.07 , 3.73 ,
17.6 , 0. , 0. , 3. , 3. ],
[ 8. , 275.8 , 180. , 3.07 , 3.78 ,
18. , 0. , 0. , 3. , 3. ],
[ 8. , 472. , 205. , 2.93 , 5.25 ,
17.98 , 0. , 0. , 3. , 4. ],
[ 8. , 460. , 215. , 3. , 5.424,
17.82 , 0. , 0. , 3. , 4. ],
[ 8. , 440. , 230. , 3.23 , 5.345,
17.42 , 0. , 0. , 3. , 4. ],
[ 4. , 78.7 , 66. , 4.08 , 2.2 ,
19.47 , 1. , 1. , 4. , 1. ],
[ 4. , 75.7 , 52. , 4.93 , 1.615,
18.52 , 1. , 1. , 4. , 2. ],
[ 4. , 71.1 , 65. , 4.22 , 1.835,
19.9 , 1. , 1. , 4. , 1. ],
[ 4. , 120.1 , 97. , 3.7 , 2.465,
20.01 , 1. , 0. , 3. , 1. ],
[ 8. , 318. , 150. , 2.76 , 3.52 ,
16.87 , 0. , 0. , 3. , 2. ],
[ 8. , 304. , 150. , 3.15 , 3.435,
17.3 , 0. , 0. , 3. , 2. ],
[ 8. , 350. , 245. , 3.73 , 3.84 ,
15.41 , 0. , 0. , 3. , 4. ],
[ 8. , 400. , 175. , 3.08 , 3.845,
17.05 , 0. , 0. , 3. , 2. ],
[ 4. , 79. , 66. , 4.08 , 1.935,
18.9 , 1. , 1. , 4. , 1. ],
[ 4. , 120.3 , 91. , 4.43 , 2.14 ,
16.7 , 0. , 1. , 5. , 2. ],
[ 4. , 95.1 , 113. , 3.77 , 1.513,
16.9 , 1. , 1. , 5. , 2. ],
[ 8. , 351. , 264. , 4.22 , 3.17 ,
14.5 , 0. , 1. , 5. , 4. ],
[ 6. , 145. , 175. , 3.62 , 2.77 ,
15.5 , 0. , 1. , 5. , 6. ],
[ 8. , 301. , 335. , 3.54 , 3.57 ,
14.6 , 0. , 1. , 5. , 8. ],
[ 4. , 121. , 109. , 4.11 , 2.78 ,
18.6 , 1. , 1. , 4. , 2. ]])
Broadcasting is an approach for deciding how to generalize operations between arrays with differing shapes.
Using broadcasting can also be more efficient, as it does not copy the broadcast data.
When operating on two arrays, NumPy compares their shapes element-wise. It starts with the trailing (i.e. rightmost) dimensions and works its way left. Two dimensions are compatible when
they are equal, or
one of them is 1
If these conditions are not met, a ValueError: operands could not be broadcast together exception is thrown, indicating that the arrays have incompatible shapes. The size of the resulting array is the size that is not 1 along each axis of the inputs.
Why does the first addition below work but not the second?
x (2d array): 4 x 3
y (1d array): 3
----------------------
x+y (2d array): 4 x 3
x (2d array): 3 x 4
y (1d array): 3
----------------------
x+y (2d array): Error
x (2d array): 3 x 4
y (2d array): 3 x 1
----------------------
x+y (2d array): 3 x 4
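A minimal sketch reproducing the three cases above with ones() arrays:

import numpy as np

x = np.ones((4, 3))
y = np.ones(3)
(x + y).shape      # (4, 3) - y is broadcast across each row of x

x = np.ones((3, 4))
y = np.ones(3)
# x + y            # ValueError: operands could not be broadcast together

x = np.ones((3, 4))
y = np.ones((3, 1))
(x + y).shape      # (3, 4) - the length-1 axis of y is stretched to 4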
For each of the following combinations, determine what the resulting dimensions will be using broadcasting:
A [128 x 128 x 3] + B [3]
A [8 x 1 x 6 x 1] + B [7 x 1 x 5]
A [2 x 1] + B [8 x 4 x 3]
A [3 x 1] + B [15 x 3 x 5]
A [3] + B [4]
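Answers can be checked with np.broadcast_shapes(), which applies the broadcasting rules to shapes without creating any arrays (incompatible shapes raise a ValueError). For example,

import numpy as np

# check the first combination; the rest work the same way
np.broadcast_shapes((128, 128, 3), (3,))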
Below we generate a data set with 3 columns of random normal values. Each column has a different mean and standard deviation, which we can check with mean() and std().
Let's use broadcasting to standardize all three columns to have mean 0 and standard deviation 1.
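A sketch of what this might look like (the sample size, means, and standard deviations here are just illustrative):

import numpy as np

rng = np.random.default_rng(1234)

# 3 columns of normal values with different means and standard deviations
d = rng.normal(loc=[0, 10, -5], scale=[1, 5, 0.5], size=(1000, 3))
d.mean(axis=0), d.std(axis=0)

# broadcasting: the (3,) arrays of means / sds are applied to every row of d
d_std = (d - d.mean(axis=0)) / d.std(axis=0)
d_std.mean(axis=0), d_std.std(axis=0)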
In addition to arithmetic operators, broadcasting can be used with assignment via array indexing.
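A small sketch of broadcast assignment (the array and values are illustrative):

import numpy as np

x = np.zeros((3, 4))

# a scalar is broadcast to every element of the selected row
x[0, :] = 1

# a length-4 row is broadcast down both of the selected rows
x[1:, :] = np.array([1, 2, 3, 4])

x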
JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.
JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.
JAX features built-in Just-In-Time (JIT) compilation via Open XLA, an open-source machine learning compiler ecosystem.
JAX functions support efficient evaluation of gradients via its automatic differentiation transformations.
JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.
array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
1. , 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9,
2. , 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9,
3. , 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9,
4. , 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9,
5. ])
Array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
1. , 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9,
2. , 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9,
3. , 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9,
4. , 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9,
5. ], dtype=float32)
Array([ 0. , 0.19867, 0.38942, 0.56464, 0.71736,
0.84147, 0.93204, 0.98545, 0.99957, 0.97385,
0.9093 , 0.8085 , 0.67546, 0.5155 , 0.33499,
0.14112, -0.05837, -0.25554, -0.44252, -0.61186,
-0.7568 , -0.87158, -0.9516 , -0.99369, -0.99616,
-0.95892, -0.88345, -0.77276, -0.63127, -0.4646 ,
-0.27942, -0.08309, 0.11655, 0.31154, 0.49411,
0.65699, 0.79367, 0.89871, 0.96792, 0.99854,
0.98936, 0.94073, 0.8546 , 0.7344 , 0.58492,
0.41212, 0.22289, 0.02478, -0.17433, -0.36648,
-0.54402], dtype=float32)
As we've just seen, a JAX array is very similar to a NumPy array, but there are some important differences.
The default JAX array dtypes are 32-bit, not 64-bit (i.e. float32 not float64, and int32 not int64).
UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
Array([1., 2., 3.], dtype=float32)
64-bit dtypes can be enabled by setting jax_enable_x64=True in the JAX configuration.
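One way to do this from Python (this should generally be done at startup, before any arrays are created):

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

jnp.array([1.0, 2.0, 3.0]).dtype   # now float64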
Using JAX interactively allows for the use of standard Python control flow (if, while, for, etc.) but this is not supported for some of JAX's more advanced operations (e.g. jit and grad).
There are replacements for most of these constructs in JAX, but they are beyond the scope of today.
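As a small illustration of the issue, a Python if on a value being traced by jit fails, while the jnp.where() replacement works (the function here is made up for illustration):

import jax
import jax.numpy as jnp

@jax.jit
def f_bad(x):
    if x > 0:                       # error - a traced value cannot be used in a Python if
        return x
    return -x

@jax.jit
def f_good(x):
    return jnp.where(x > 0, x, -x)  # branch-free replacement works under jit

f_good(jnp.array(-3.0))             # 3.0
# f_bad(jnp.array(-3.0))            # raises an error at trace time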
Pseudorandom number generation in JAX is a bit different than in NumPy - the latter depends on a global state that is updated each time a random function is called.
NumPy's PRNG guarantees something called sequential equivalence, which means that sampling N numbers one at a time gives the same values as sampling a vector of N numbers at once.
Sequential equivalence can be problematic when using parallelization across multiple devices; consider the following code:
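A sketch of the kind of code in question (using the bar() and baz() names referenced below):

import numpy as np

np.random.seed(0)

def bar():
    return np.random.uniform()

def baz():
    return np.random.uniform()

def foo():
    # the value depends on which of bar() and baz() draws from the global state first
    return bar() + 2 * baz()

foo()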
How do we guarantee that we get consistent results if we don't know the order that bar() and baz() will run?
JAX makes use of random keys, which are just a fancier version of random seeds - all of JAX's random functions require a key as their first argument.
Since a key is essentially a seed, we do not want to reuse them (unless we want identical output). Therefore, to generate multiple different pseudorandom values we can split a key to deterministically generate two (or more) new keys.
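A minimal sketch of creating and splitting keys (the seed is arbitrary):

import jax

key = jax.random.PRNGKey(1234)

# using the same key always reproduces the same draw
jax.random.normal(key, shape=(3,))

# split the key to deterministically derive new, independent keys
key, subkey = jax.random.split(key)
jax.random.normal(subkey, shape=(3,))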
1.09 ms ± 92.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
3.42 ms ± 122 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
When it works the jit tool is fantastic, but it does have a number of limitations (see the sketch after this list):
Must use pure functions (no side effects)
Must primarily use JAX functions
jnp.minimum() not np.minimum() or min()
Must generally avoid conditionals / control flow
Issues around concrete values when tracing (static values)
Check performance - there are not always gains + there is the initial cost of compilation
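A minimal sketch of a jit-friendly function that respects these constraints (pure, uses only jnp functions, no Python control flow on traced values):

import jax
import jax.numpy as jnp

@jax.jit
def clipped_square(x):
    # pure function: no printing, no global state, only jnp operations
    return jnp.minimum(x * x, 1.0)

x = jnp.linspace(-2, 2, 5)
clipped_square(x)   # first call compiles; subsequent calls reuse the compiled code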
The grad() function takes a numerical function returning a scalar and returns a function for calculating the gradient of that function.
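A quick sketch (the function is chosen just for illustration):

import jax

def f(x):
    return x ** 3     # scalar in, scalar out

df = jax.grad(f)      # df(x) computes 3 * x**2
df(2.0)               # 12.0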
vmap()
I would like to plot h() and jax.grad(h)() - let's see what happens:
TypeError: Gradient only defined for scalar-output functions. Output had shape: (101,).
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
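The usual fix is to vectorize the gradient with vmap(), which maps the scalar-output gradient function over each element of the array. A minimal sketch of that pattern (the h() below is just a stand-in example):

import jax
import jax.numpy as jnp

def h(x):
    return jnp.maximum(x, 0.0)   # stand-in for h()

xs = jnp.linspace(-1, 1, 101)

# grad(h) handles one scalar at a time; vmap() maps it across all of xs
dh = jax.vmap(jax.grad(h))
dh(xs)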
Array([0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,
1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,
1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,
1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,
1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,
1. ], dtype=float64)
Sta 663 - Spring 2025