Fundamentals of machine learning and training neural networks
It’s an exciting time for machine learning and AI. Models such as ChatGPT and its ilk appear nothing short of magical. The natural question arises — how is any of this possible?
The answer essentially boils down to two things:
- Iterative mathematical methods to “follow” a function’s slope towards a minimum (similar in spirit to Newton’s method for finding the roots of a function)
- A LOT of computational power
ChatGPT is essentially a very complicated mathematical function that takes a text input and tries to predict the next blob of text in the sequence.
I’ve had so many wonderful discussions with people who want to learn, but a common thread that comes up is how intimidating the maths required to understand it can feel.
I’m here to say that the good news is: you don’t need that much maths to get into it! Honestly, the main thing you need is a solid grasp of the concepts, rather than the inner calculations.
But how do we arrive at that conceptual understanding? I want to help people make that connection, de-mystify the calculus that happens under the hood, and show how to do machine learning in practice (with the widely used PyTorch library).
I’ve written a tutorial on basic PyTorch before, but today I want to take a step back and focus more on connecting the code with the underlying theory and concepts.
In a previous job, I worked with David Norrish, who was kind enough to assemble a series of notebooks explaining convolutional neural networks from first principles. I’ll be drawing on his work for this explanation, as well as the following PyTorch tutorials:
- https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html
- https://pytorch.org/tutorials/beginner/introyt/autograft_tutorial.html
When training a model, we want it to get better at its task over time. In PyTorch, this is done by:
- Measuring how wrong the model is (with a loss function), and
- Adjusting the model’s weights (with an optimizer) to reduce that error.
Today we will
- Get a brief overview of the ideas from calculus we’ll need
- Establish a simple problem of fitting a line through a set of points on a 2d plane
- Define how to measure how wrong the model is
- Connect the concepts with the code used to fit this model
Demystifying the calculus
At the heart of training a neural network is calculus — specifically, gradients, and how they tell us which direction to move the model’s weights to make better predictions.
The core idea is this:
A gradient measures how much a change in one variable (like a weight in your model) will affect the output (like the loss).
We want to minimize the loss, so we compute the gradient of the loss with respect to each parameter. This tells us how to change each parameter to decrease the loss.
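As a tiny worked example (my own illustration, not from the tutorials above), take a one-parameter loss:

$$
L(w) = (w - 3)^2, \qquad \frac{dL}{dw} = 2(w - 3)
$$

At w = 5 the gradient is +4, so stepping w in the opposite (negative) direction moves it towards the minimum at w = 3, where the gradient is zero.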
To compute these gradients, we use the chain rule from calculus. The chain rule says:
knowing the gradient of z with respect to y and the gradient of y with respect to x allows us to calculate the gradient of z with respect to x as the product of the two gradients (paraphrasing Wikipedia).
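In symbols, if z depends on y and y depends on x:

$$
\frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx}
$$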
Neural networks are just composed functions — data flows through layers of computation. So when we compute the gradient of the loss with respect to a weight deep in the model, we’re chaining together many of these derivative computations.
Doing this by hand is painful. Fortunately, the autograd engine in PyTorch automatically tracks operations on tensors, builds a “computation graph”, and applies the chain rule backward through that graph when you call the backward() method on a tensor. This method will fill in the grad attribute on each parameter with the right derivative.
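Here’s a tiny illustration of autograd applying the chain rule for us (my own example, separate from the fitting problem below): with y = x ** 2 and z = 3 * y, the chain rule gives dz/dx = 3 * 2x = 6x.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2    # dy/dx = 2x
z = 3 * y     # dz/dy = 3

z.backward()  # applies the chain rule: dz/dx = dz/dy * dy/dx = 6x

print(x.grad)  # tensor(12.) since 6 * x = 6 * 2 = 12
```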
Great, so we have the gradients… What do we do with them?
There’s a variety of optimization methods out there, but we won’t go into detail with any of them today — we’ll just use the helpful abstractions in PyTorch.
Step 0: Establish the problem
Let’s say that we have a dataset of points on the 2d plane, represented by {(x_0, y_0), (x_1, y_1), (x_2, y_2), ..., (x_N, y_N)}. We want to find a model, given by a function of x that outputs y, that best matches this dataset. Keeping with the ChatGPT analogy, let’s pretend that x represents some blob of text, and y is the next blob of text.
The simplest neural network we can use for this problem is a linear regression f(x) = W * x + b, where
- f(x) is our model (the “neural network”)
- x is the input to the model
- W is a 1 x 1 matrix (1 parameter)
- b is a 1-dimensional vector (1 parameter)
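To make this concrete, here’s one way you could generate a toy dataset of this shape (the “true” slope, intercept, and noise level below are arbitrary choices for illustration; the snippets later in this post just use small random tensors instead):

```python
import torch

# Points scattered around the line y = 2x + 1 (values chosen arbitrarily)
true_w, true_b = 2.0, 1.0
x = torch.linspace(-1.0, 1.0, 100)
y = true_w * x + true_b + 0.1 * torch.randn(100)  # add a little noise
```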
So we have a dataset, and we have a model — and we want to find the “best” values for W and b such that f(x) has the best “fit” to this dataset.
What does “fit” mean in this context?
We can measure how good of a fit we have by calculating the difference between our model’s predictions f(x_i) and the targets y_i, squaring those differences, then averaging them. This is called the mean squared error (MSE) loss:

```python
L = sum((y_i - f(x_i)) ** 2 for x_i, y_i in zip(x, y)) / len(x)
```
The lower this value L is, the better our model is. So how do we find the best f that gets the lowest possible L?
- “guess” what the parameters of f should be
- calculate L
- compute the gradients of L with respect to the parameter values
  - that is, we’ll calculate the gradient of L with respect to W and the gradient of L with respect to b
- adjust the parameters to follow the gradients of L down, so that we can pick new values of W and b that achieve a lower value of L (sketched in code below)
- repeat!
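To make the recipe concrete, here’s a minimal by-hand sketch of a single step on a single data point, using plain gradient descent with a hand-picked learning rate (the point, starting guesses, and learning rate are arbitrary; the rest of the post does the same thing with autograd and an optimizer):

```python
# One hand-computed gradient-descent step on a single point (x_i, y_i)
x_i, y_i = 2.0, 5.0    # an arbitrary example data point
w, b = 0.0, 0.0        # initial "guess" for the parameters
lr = 0.01              # hand-picked learning rate

pred = w * x_i + b     # the model's prediction f(x_i)
L = (y_i - pred) ** 2  # squared error for this point

# Gradients of L with respect to w and b (via the chain rule)
dL_dw = -2 * x_i * (y_i - pred)
dL_db = -2 * (y_i - pred)

# Follow the gradients "down": step each parameter against its gradient
w -= lr * dL_dw
b -= lr * dL_db
```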
After several iterations of this, our model f will improve, and “learn”.
Conceptually, that’s all there is to the fundamentals of machine learning.
Step 1: Track Gradients with autograd
PyTorch automatically keeps track of operations on tensors that require gradients. This means that if you define a model weight with requires_grad=True, PyTorch will remember how it was used and compute its gradient when you call the backward() method on the output of the calculations.
```python
import torch

# Some example inputs and model weights
x = torch.randn(5)  # Input vector
y = torch.randn(5)  # Target output

w = torch.randn(1, requires_grad=True)  # Model weights
b = torch.randn(1, requires_grad=True)  # Bias term

# Model output
z = x * w + b  # Linear layer

# Loss: mean squared error
loss = torch.nn.functional.mse_loss(z, y)

# Compute gradients
loss.backward()

# Now w.grad and b.grad contain the gradients
```
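As a quick sanity check (my own addition, not part of the tutorials above): for this model the analytic gradient of the MSE loss with respect to w is the mean of 2 * (z - y) * x, and autograd’s answer agrees.

```python
# Compare autograd's gradients with the hand-derived ones
manual_dw = (2 * (z - y) * x).mean().detach()
manual_db = (2 * (z - y)).mean().detach()

print(torch.allclose(w.grad, manual_dw))  # True
print(torch.allclose(b.grad, manual_db))  # True
```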
Step 2: Update Weights with an Optimizer
After computing gradients, we use an optimizer to update the model’s weights. The Adam optimizer is commonly used because it adjusts learning rates automatically for each parameter and generally works well out of the box.
We’ll pick a learning rate (lr) of 0.001 to adjust the magnitude of these steps. There’s not really a hard science for choosing a good learning rate aside from experimentation, but we want to err on the side of “small, but not too small”.
```python
# Create the optimizer
optimizer = torch.optim.Adam([w, b], lr=0.001)

# Apply the gradient update
optimizer.step()

# Reset the tracked gradients
optimizer.zero_grad()
```
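For intuition, a plain SGD update is just “step each parameter a little against its gradient”; Adam does the same thing but also keeps running statistics of the gradients to scale each parameter’s step. Here’s a rough sketch of what the SGD version of optimizer.step() boils down to (not what Adam actually computes):

```python
lr = 0.001

# Roughly what a plain SGD step does with the gradients from backward()
with torch.no_grad():
    w -= lr * w.grad
    b -= lr * b.grad
```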
Important: before calling the backward() method each time, we shouldn’t forget to reset the gradients with optimizer.zero_grad(). Otherwise, PyTorch will accumulate gradients across steps, and the model won’t fit correctly!
Putting It Together
A basic training step looks like this:
```python
# Forward pass
z = x * w + b
loss = torch.nn.functional.mse_loss(z, y)

# Backward pass
loss.backward()        # Compute new gradients
optimizer.step()       # Update weights
optimizer.zero_grad()  # Clear old gradients
```
And that’s all there is to it! We can then put this into a for loop for a fixed number of iterations, and we’ll end up with a pretty good model for our dataset.
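For completeness, here’s one way that loop might look, reusing the tensors and optimizer defined above (the number of iterations and the print interval are arbitrary choices):

```python
for step in range(1000):  # fixed number of iterations, chosen arbitrarily
    # Forward pass
    z = x * w + b
    loss = torch.nn.functional.mse_loss(z, y)

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Occasionally print the loss so we can watch it decrease
    if step % 100 == 0:
        print(f"step {step}: loss = {loss.item():.4f}")
```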
There are many more techniques and tricks used in machine learning practice, but it all comes back to these fundamentals. With this understanding, we can swap f(x) = W * x + b for more complicated models, and the same recipe will give us a good fit for our data.
As I’ve shown, the maths is not the hard part of machine learning. The hard parts are more often:
- Working with messy real world data
- Selecting an appropriate model (not too big, not too small)
- Selecting hyperparameters such as the learning rate
As we get more comfortable with these concepts, we’ll come back to this topic and expand on these difficulties in future posts.