Stanford CME295 Transformers & LLMs | Autumn 2025 | Lecture 4 - LLM Training
Cool.
Hello everyone and welcome to lecture 4
of CME 295. So today is Friday, October
the 17th, which means that the midterm
is one week away. So before we start,
I'm just going to go over some logistics
to make sure you know we're all aligned
on what to expect.
So the midterm will take place next
week, same time. Instead of an hour and
50 minutes, it will be an hour and 30
minutes. So it's 3:30 to 5 in this
classroom.
So it's business as usual. In terms of
topics, the midterm will cover lectures
one, two and three, which we've had, and
this one, which is lecture four.
So just to give you like an overview of
what you can expect in the midterm.
There are going to be some
multiple-choice questions along with
some free-form questions, but they're
mainly going to be about the things that
we've seen in class. So if you watch the
recordings or attend the lectures, go
through the slides, and know the
important formulas, I think you'll be
fine.
So I know that you may have questions
until uh next week. So that's why after
this lecture with Shervin we will be
holding office hours. So feel free to
you know come to us and ask us any
questions. And uh of course we'll be
fully available between now and next
week. So in case you have any questions,
feel free to uh ping us on Ed. Um and
yeah, we'll make sure to respond.
Cool. I also know that a number of
you are auditing this class. In case
you're still interested in taking the
midterm for some reason, maybe
because you have an upcoming interview,
just tell us so that we know how many
copies to print. We'll be printing on
Monday, so just let us know over the
weekend in case you're interested.
Cool. So that's for the midterm. And
then the second piece of news is the
final. So, we said we were working on
the dates. So, we finally finalized the
dates, which did not change. So, it's
Wednesday, December the 10th. Okay. So,
a little bit late, 7:00 p.m. to 8:30
p.m. Uh, so it's a slot that we have.
Uh, the location is different from this
one. So, it's in this room.
And the final will only cover the second
part of the class which is basically
lectures five to 9.
Any questions on this? Yeah.
Oh yeah, good question. So is it closed
notes? Yes. Yeah. Yes.
So the question is, what is the format
of the multiple choice? We have not
finished writing the exam, but it's
going to be something like: you have a
question with, let's say, three or four
possible answers, and you choose the one
that's correct. Something like this. And
you'll also have some free-form
questions that you'll have to answer in
your own words. Yeah.
>> Yeah, thanks. So the question is, are
we allowed to bring anything? It's
closed book, so nothing, just a pen.
Yeah.
The question is about calculators: no
calculator, you will not need one.
But speaking of the cheat sheet, I'm
not sure if we mentioned it, I think we
did. There is a cheat sheet for this
class, which you cannot bring to the
exam but which you can use for your
studying. It's on the class website.
I recommend looking at it.
Cool. So, super clear for everyone.
Very cool. Well, okay.
As always, we'll be starting the class
just recapping what we saw in the
previous lecture. So if you remember, we
studied a new kind of architecture
called the mixture of experts (MoE).
The idea is that, given an input, you
do not necessarily want to activate all
the parameters. So you are in a setting
where you have multiple experts, and in
the forward pass you only activate some
of them. That's a sparse MoE. You also
have the dense MoE, which weights the
experts' outputs as a function of the
output of the gate. We saw that this
architecture was used in LLMs, mainly to
be able to scale them without incurring
an expensive cost at inference time,
because you don't want to activate all
the parameters.
The second thing that we saw was
defining what an LLM is and, in
particular, how you decide on the next
token to predict. We saw three methods.
The first one we called greedy decoding,
which always takes the most probable
token.
The second method we saw was beam search,
where we keep track of the k most
probable sequences.
And then the third one was sampling.
Here we're not taking the most probable
token and we're not keeping track of the
most probable sequences. What we do is
sample the next token with respect to
the distribution that we get as output.
And we saw that there's a hyperparameter
called temperature that lets you tweak
how spiky or how flat you want that
distribution to be.
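To make the recap concrete, here is a minimal NumPy sketch of greedy decoding versus temperature sampling. The logits and the helper names (`softmax_with_temperature`, `sample_next_token`) are made up for illustration; this is not code from the lecture.

```python
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    """Turn raw logits into probabilities, after dividing by the temperature."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled = scaled - scaled.max()      # subtract the max for numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum()

def sample_next_token(logits, temperature=1.0, rng=None):
    """Sample one token id from the temperature-scaled distribution."""
    if rng is None:
        rng = np.random.default_rng()
    probs = softmax_with_temperature(logits, temperature)
    return int(rng.choice(len(probs), p=probs))

logits = [2.0, 1.0, 0.1]                # made-up scores for a 3-token vocabulary
greedy_token = int(np.argmax(logits))   # greedy decoding: always the top token
spiky = softmax_with_temperature(logits, temperature=0.5)  # low T: more peaked
flat = softmax_with_temperature(logits, temperature=2.0)   # high T: flatter
sampled_token = sample_next_token(logits, temperature=1.0,
                                  rng=np.random.default_rng(0))
```

A low temperature concentrates the probability mass on the top token, so sampling behaves more like greedy decoding; a high temperature spreads the mass out.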
And we also saw some inference
optimization techniques that are used
in practice to avoid a big cost at
decoding time. I'm not going to list
everything, but I would say the KV
cache, for instance, is an important
method. So I recommend knowing what it
is, along with the other ones.
And with that we're going to start
lecture four and actually I was really
looking forward to today because lecture
one we saw what self attention was what
a transformer was. Second lecture, we
saw
some of the tricks that people use today
and some of the variations from the
transformer. We introduced what an LLM
was last lecture and this lecture we're
finally going to see how these LLMs are
trained. So today we're going to focus
on LLM training.
And the first thing that I'm going to
say is if you've been in the ML field
for let's say more than a few years now
uh you may have noticed that
traditionally
if you had a task what you would do is
train a model specifically for that
task.
So let's suppose that, 10 years ago,
we had a task around detecting spam. You
would train a model specifically to
detect spam. You would train on the
training set, evaluate on the validation
set, and then test on the test set. If
you had another use case, say sentiment
extraction, you would train a model
specifically for that,
and so on and so forth.
But one could argue that these tasks,
they're not completely disjoint.
They're all involving just understanding
the text. So one could argue we could
find a way to somehow leverage the
knowledge that we acquired during
training for let's say one task
and reuse that for another task.
So this method has a name. It's been
around for some time. It's called
transfer learning.
So the goal of transfer learning is to
not always start from scratch. If you
have a new task, the idea is to start
from some pre-trained model (we're going
to see what pre-training is) and then
tune it for your task.
Well, it's basically the paradigm on
which LLMs are trained. So the idea here
is that all these tasks, they involve
understanding language. So, what we're
going to do is have what we call a
pre-training stage, which involves
training your LLM on vast amounts of
data to just understand what language,
what code is
and then have a second stage of quote
unquote tuning.
And we're going to see a little bit what
that tuning is. But in that second
stage, we're going to take our
pre-trained model and somehow find a way
to tune the weights to adapt to a
specific task.
So as an example here, we would
pre-train a huge model and then, for
spam detection, we would somehow tune it
for that. For sentiment extraction, same
thing, we would tune it for that, and so
on.
So this is just to take the example from
before. And the idea here is in order to
obtain these models, we're not going to
start from scratch.
Cool. So now we're going to see
what pre-training is. Pre-training is
by far the most expensive part of
training, in terms of compute, cost,
everything.
What it does is take a huge amount
of data and train your LLM to just
predict the next token.
And here, by data, what I mean is
basically everything you can find. It
can be text in English, text in other
languages, it can even be code, in
different programming languages. It can
be basically the whole internet.
We're going to see some of the data sets
that people use for that. But you can
think of this as just training your
model to try to predict anything that's
written.
And as I mentioned, the objective here
is to predict the next token. If you
remember, our LLM is a text-to-text
model, and most likely a decoder-only
model, in more than 90% of the cases. So
what it does is take some input text and
always try to predict the next token,
iteratively.
So in terms of the data sets that are
used, you will see the term Common Crawl
a lot in papers. It's basically a data
set composed of anything you can find on
the internet. I think they have
something like three billion pages per
month, and if you go on their website,
they have a huge archive. There's a
bunch of other sources that you can find
in these data sets as well: for instance
Wikipedia articles, and social media
like Reddit. I know there are a lot of
Reddit conversations in those data sets.
You also have a lot of code, and of
course you have a bunch of places for
that: you have GitHub, you have Stack
Overflow, all these forums that talk
about code. All of this is meant for
your model to understand the structure
of language and code.
And in terms of size, it's measured in
number of tokens. The order of magnitude
that I want you to remember is hundreds
of billions, or trillions, or even
tens of trillions of tokens.
I'll give you an example. GPT-3 was
trained on 300 billion tokens, and
Llama 3, which I believe was published
last year, was trained on 15 trillion
tokens.
So these are huge data sets.
So before we go further, I want to
introduce two notations. I think I
introduced one of them last lecture.
The reason I want to talk about these
notations is that they are used
everywhere to describe how much compute
a model needs. The first notation is
FLOPs, which stands for floating-point
operations.
It's a unit of compute. The higher the
FLOPs, the more operations are involved,
because by definition FLOPs is the
number of operations that involve
floating-point numbers. You can think of
floating-point numbers as just numbers
with decimal points.
In terms of order of magnitude,
training an LLM is on the order of
10^25 FLOPs.
As for how you obtain the FLOPs: usually
it's a complicated formula, but you can
think of it as a function of the size of
your data, so the number of tokens that
you train on, and the number of
parameters of your model.
There's no universal formula because it
also depends on the architecture. For
instance, you can think of MoE-based
LLMs as requiring less compute than
dense LLMs, because only some parts are
activated.
But you can just think of it as a
function of the number of tokens and
parameters. It's like O of the product
of the two, more or less.
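As a rough illustration of "O of the product", here is a small sketch that uses a commonly quoted rule of thumb for dense transformers: about 6 FLOPs per parameter per training token (forward plus backward pass). The constant 6 is an assumption brought in for the example, not something stated in the lecture, and it ignores architecture details such as MoE sparsity. The GPT-3 parameter count (175 billion) is the publicly reported figure.

```python
def approx_training_flops(num_params, num_tokens, flops_per_param_per_token=6):
    """Rough training compute for a dense transformer.

    The default constant of 6 is a rule of thumb (an assumption here),
    so treat the result as an order of magnitude only.
    """
    return flops_per_param_per_token * num_params * num_tokens

# GPT-3: about 175 billion parameters, trained on 300 billion tokens.
gpt3_flops = approx_training_flops(num_params=175e9, num_tokens=300e9)
print(f"GPT-3 training compute, rough estimate: {gpt3_flops:.2e} FLOPs")
```

The point is only the scaling: doubling either the number of parameters or the number of tokens doubles the estimate.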
And then there's a second notation that
I want to introduce, which is also
written FLOPS, but it's different. Here
FLOPS stands for floating-point
operations per second.
So it's a measure of compute speed:
how fast your hardware can execute these
operations.
You also have some orders of magnitude
here. If you're into, let's say, GPUs,
you will see that the description of a
GPU always indicates its FLOPS, and we
will see that in a second.
I just want to call out that FLOPS here
is usually written in all caps, whereas
the count of operations is FLOPs with a
lowercase s.
You may see some papers that use one for
the other, which is confusing. So I
recommend reading this notation in the
context of the sentence it is in,
because sometimes people switch the two.
But this is the common notation.
So far so good.
Cool.
Okay. So now we know that we have a
pre-training step. We know it involves a
lot of compute. We know it involves a
lot of data. We know that our model is
large. So what people did was try to
see how the performance evolves as a
function of model size and training set
size.
There's one paper called Scaling Laws
for Neural Language Models, published in
2020, that performed a bunch of
experiments by varying these parameters.
What they found was that the more
compute you have, the better your model
gets at predicting the next token. Same
for data set size: the bigger your
training set, the better. And the bigger
your model, the better.
So for some time, I think between 2019
and 2024,
you were seeing models that were larger
and larger, just people just building
things that were bigger and bigger
because according to these experiments,
um the performance was just getting
better.
So something else that they noticed was
bigger models tend to be more what they
call sample efficient.
So what that means is for an equal
amount of tokens that is processed
you will have a better performance with
a bigger model compared to a smaller
one.
But then you can wonder: we don't have
unlimited compute. Compute is expensive
and has a lot of drawbacks. So you have
a fixed compute budget, and people also
tried to answer the question: given a
certain amount of compute, how can you
choose your training set size and your
model size in an optimal way?
Because here you need to decide how big
your model is. What they did was fix a
compute budget, which is the color of
these curves, and train models of
different sizes with different training
set sizes. What they saw was that there
was always a sweet spot here, which
followed some kind of relationship.
In particular, this is a table that
summarizes the, quote unquote, optimal
combinations of number of parameters and
training set size, which is sometimes
called the Chinchilla scaling law.
What they realized was that if your
training set size, in tokens, is about
20 times the model size, in parameters,
then you're spending your compute in a,
quote unquote, optimal way. And in
particular, GPT-3, for instance, was I
think 175 billion parameters, if I
remember correctly, but it was only
trained on 300 billion tokens. So
according to this, that one is really,
quote unquote, undertrained.
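The arithmetic behind that remark can be sketched in a few lines. The ratio of 20 tokens per parameter is the rule of thumb quoted above; the function name is made up for the example.

```python
TOKENS_PER_PARAM = 20   # rule of thumb quoted in the lecture

def suggested_training_tokens(num_params, ratio=TOKENS_PER_PARAM):
    """Training set size (in tokens) suggested by the ~20x rule of thumb."""
    return ratio * num_params

gpt3_params = 175e9     # parameters
gpt3_tokens = 300e9     # tokens it was actually trained on

actual_ratio = gpt3_tokens / gpt3_params                  # tokens per parameter
suggested_tokens = suggested_training_tokens(gpt3_params)

print(f"actual tokens per parameter: {actual_ratio:.2f}")
print(f"tokens suggested by the rule: {suggested_tokens:.2e}")
```

GPT-3 saw fewer than 2 tokens per parameter, against roughly 20 suggested by the rule, which is the sense in which it is "undertrained".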
I think there's a question. Yeah.
So the question is, do they fix
the neural architecture? I think by now
everyone agrees that LLMs are
transformer-based, decoder-only models,
so everyone uses the same kind of model.
You can assume that when I say LLM
here, it basically means a decoder-only,
transformer-based model.
Yeah. The question is whether a change
of architecture does not play a big
role. That's actually what they say in
their paper: the things that matter the
most are the number of tokens on which
you train and the size of your model.
Cool. Any other questions?
Yeah.
Oh yeah, good question. So the question
is, is there some kind of transfer
learning between different versions of
models? A lot of these models are
actually closed source, so they don't
exactly reveal these things. But I guess
it's an interesting question, one that I
cannot answer in a general way. I think
that's the best answer I can give you.
In any case, when you look at some of
these papers, they always state how much
it costs to train the model, and it's
always on the order of millions. It's
always an expensive step regardless.
Cool. Speaking of that, pre-training
has a lot of challenges. One of them is
cost. When I say millions of dollars,
that's a minimum. I think it can even
cost tens of millions of dollars, or
sometimes hundreds of millions of
dollars. It takes a lot of time,
and people have been mindful of the
impact on the environment, so they've
also been including the ecological
cost.
The other challenge is that the
pre-training step uses data that only
goes up to the time at which you
pre-train your model. What that means is
that the knowledge you acquire from
training on this data set can only go up
to the date at which you cut your data
set.
This date is called the knowledge cutoff
date. So your pre-trained base model has
no way to know, by itself, about things
that happened after this date.
And speaking of that, a lot of papers
have tried to edit knowledge or inject
knowledge. It's always tricky because
there's no clean way to change the
weights without penalizing some other
parts. What people want is to inject
knowledge without regressing in other
domains, and this is a very hard
problem. And of course, these models try
to predict the next token, so there's
this question of what happens if the
model just generates something that it
has seen at training time, what we call
plagiarism. So there's always a risk.
So these are all the challenges. I just
want to illustrate what I said about the
knowledge cutoff date. If you go on,
let's say, the OpenAI website or
Google's website to look at the model
cards, you will always see (I'm not sure
if you can see from here) a line on the
knowledge cutoff date, which tells you
when the pre-training of this model was
done. Here, for instance, GPT-5 was
released a few weeks ago, and it says
the knowledge cutoff date is
September 30th. So you can guess that
they did their pre-training around
that date.
Cool.
Any questions on the first part?
Everyone good?
Perfect. So in this first part, we've
seen that pre-training was a crucial
step of the LLM training process and
we've seen all these big numbers
and one could wonder well how can you
train such a big model on such a big
amount of data like how do people do
that?
So this is what we're going to see here.
So, as I mentioned, you can think of
LLMs as decoder-only, transformer-based
models. In order to train your model,
you need a lot of data. But then if you
look at your architecture, you see that
a lot of the operations involve matrix
multiplications.
And I guess I have a question for you.
What is the kind of hardware that loves
matrix multiplications?
GPUs. Yes. So you also need GPUs.
Actually more than one. Yeah. You had a
question.
Oh, the question is about GPUs for
inference. In this part we're going to
focus on training. The requirements for
GPUs differ a little bit between
training and inference, but here we're
solely focused on training.
And speaking of GPUs, it's not GPUs
everywhere, because Google, for
instance, has developed its own hardware
called TPUs. But any non-Google models
have most likely been trained on GPUs.
Cool. So in order to train your model,
what do you do? First of all, you have
your LLM, which is this model, but now
we're representing it with a box just
for simplicity. You initialize it. It
has a lot of parameters: you can think
of the scale as being somewhere around
billions to hundreds of billions of
parameters. A huge model.
And what are the steps involved in
training a model? Well, what you're
trying to do is tune the weights so that
the model learns how to generate the
next token.
So you have one step called the forward
pass, where you pass a batch of data
through the network. While we go through
this, I want to call out the things that
we need to somehow save in memory.
When you do this forward pass, you
have what are called activations, which
are basically the values at each layer
that are needed in order to compute the
loss.
The loss tells you how far off you are
compared to the label that you want to
train on. The amount of memory that you
use here depends on a lot of things. It
depends on the model size, which impacts
the number of activations. It depends on
how big your batch of data is. And it
depends on how long your context length
is because, if you remember, we have
O(n²) complexity here because of the
self-attention operation, where n is the
sequence length. So you have all these
parameters that come into play.
So once you do the forward pass and
compute the loss, you know how far off
you are compared to your label.
The next step is to somehow tweak
the weights in a way that minimizes the
loss. How do you do that? There's
another pass called the backward pass.
What this pass does is quantify
the direction in which the loss is going
to be minimized.
That's called the gradient: you take the
gradient of the loss with respect to
each parameter.
Well, gradients also need to be saved
somewhere in memory.
And then finally you have the weight
update. Now you know the direction in
which your loss is going to be
minimized, so you apply that update to
your weights. You typically use one of
the optimizers out there. Have you heard
of the Adam optimizer? Yeah. The Adam
optimizer is just a fancy version that
keeps track of some additional
quantities, which are basically
functions of the gradient. You have the
first moment and the second moment,
which are basically moving averages of
the gradient and of the squared
gradient.
All these quantities, the first moment
and the second moment, also need to be
saved somewhere in memory.
So it's a lot of things to save.
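To see where the two extra quantities come from, here is a minimal NumPy sketch of one Adam step on a toy problem. The hyperparameter defaults and the bias-correction terms follow the usual presentation of Adam and are assumptions as far as this lecture goes; the toy loss (sum of squares) is made up for the example.

```python
import numpy as np

def adam_step(param, grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update. m and v are the first and second moments."""
    m = beta1 * m + (1 - beta1) * grad        # moving average of the gradient
    v = beta2 * v + (1 - beta2) * grad ** 2   # moving average of the squared gradient
    m_hat = m / (1 - beta1 ** step)           # bias correction for early steps
    v_hat = v / (1 - beta2 ** step)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

# Toy problem: minimize f(w) = sum(w^2), whose gradient is 2w.
w = np.array([1.0, -2.0, 3.0])
m = np.zeros_like(w)    # first moment: same shape as the weights
v = np.zeros_like(w)    # second moment: same shape as the weights
initial_loss = float(np.sum(w ** 2))

for step in range(1, 201):
    grad = 2 * w
    w, m, v = adam_step(w, grad, m, v, step, lr=0.05)

final_loss = float(np.sum(w ** 2))
```

Note that `m` and `v` each have the same shape as the weights, which is why the optimizer states add two more full-size copies to what has to live in memory.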
Well, okay, breaking news: memory is
not unlimited. Memory is limited. What
we have in front of us here is the
description of a GPU, I think the H100,
which is a very good GPU. You will see
that in that description there's a line
on GPU memory. GPU memory is the amount
of memory per GPU. It's 80 GB for this
one. It's quite large, but it's still on
the order of tens of gigabytes.
So you need to store all these things in
80 GB,
which is not a lot.
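A back-of-the-envelope sketch shows why 80 GB is not a lot. It assumes, for simplicity, that weights, gradients, and both Adam moments are all stored in 32-bit floats (4 bytes each), and it ignores activations entirely. The 7-billion-parameter model is a hypothetical example size.

```python
BYTES_PER_FP32 = 4
GPU_MEMORY_GB = 80      # per-GPU memory quoted in the lecture

def training_state_gb(num_params, bytes_per_value=BYTES_PER_FP32):
    """Memory for weights + gradients + Adam first and second moments, in GB.

    Assumes every quantity is stored at the same precision and ignores
    activations, so this is a lower bound under those assumptions.
    """
    weights = num_params * bytes_per_value
    gradients = num_params * bytes_per_value
    adam_first_moment = num_params * bytes_per_value
    adam_second_moment = num_params * bytes_per_value
    total_bytes = weights + gradients + adam_first_moment + adam_second_moment
    return total_bytes / 1e9

needed_gb = training_state_gb(num_params=7e9)   # hypothetical 7B-parameter model
fits_on_one_gpu = needed_gb <= GPU_MEMORY_GB
print(f"{needed_gb:.0f} GB needed, fits on one GPU: {fits_on_one_gpu}")
```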
So what will you be doing?
I guess the idea is to leverage not
one but several GPUs in order to
distribute the load across GPUs. And in
order to do that, you have several
methods,
which we will see in a second.
So the first set of methods is called
data parallelism
also known as DP.
So what this set of methods does is it
distributes
data across GPUs
so that this forward pass and backward
pass they can all be done kind of
independently.
And so the idea here is to divide the
batch of data across devices.
In order to do that, of course, you
need to have a copy of the model on each
device, because you need to compute the
activations and all these things. But
when you do that, you're able to reduce
the memory that is linked to the batch
size.
So that's called data parallelism. Yes.
The question is, how about the gradient
updates? It's a great question. What do
you do when you have independent
computations here and there?
Well, the gradient is just the average
of the gradients computed on each
device. So you have some communication
between the GPUs that aggregates the
gradients for the update.
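The claim that averaging per-device gradients recovers the full-batch gradient can be checked on a toy example. This sketch simulates four "devices" as list entries in plain NumPy, with a linear model and a mean-squared-error loss chosen for illustration; a real setup would use a collective such as all-reduce across GPUs.

```python
import numpy as np

def mse_gradient(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y)^2) with respect to w."""
    residual = X @ w - y
    return X.T @ residual / len(y)

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))     # batch of 8 examples, 3 features
y = rng.normal(size=8)
w = rng.normal(size=3)          # every device holds the same copy of the weights

num_devices = 4
X_shards = np.split(X, num_devices)    # each device gets 2 examples
y_shards = np.split(y, num_devices)

# Each device computes a gradient on its own shard, independently.
local_grads = [mse_gradient(w, Xs, ys) for Xs, ys in zip(X_shards, y_shards)]

# Communication step: average the local gradients.
averaged_grad = np.mean(local_grads, axis=0)

full_batch_grad = mse_gradient(w, X, y)
```

The equality holds here because the shards have equal size and the loss is a mean over examples.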
So, I have a question for you.
Is this the answer to everything? If we
just scale up like this to, I don't
know, a lot of GPUs, is it always great,
or do we have cons?
Oh yeah, great point. You still have to
fit one model on each device. The second
point that I will add is that you have
an additional cost, called the
communication cost, because you need to
communicate between your GPUs in order
to aggregate some quantities.
So your training is going to be slower.
It's good that you can scale up the
memory. Of course you need to fit a
model on a device, and we will see what
we can do about that, but you will be
incurring those communication costs. So
it's not all great.
So speaking of the memory, and the fact
that we want to at least be able to
store a model per device: people have
realized that there's actually a lot of
duplication, and there's been a paper on
deduplicating this redundant
information. The method is called ZeRO,
the Zero Redundancy Optimizer.
The idea is that on each GPU you store
the same parameters, you store the same
gradients, you store the same optimizer
states.
So how about we shard, we partition,
those quantities across GPUs? The first
variation is about sharding the
optimizer states, meaning we partition
those states across the GPUs. This
reduces the memory by a lot. We can also
partition the gradients,
and we can also partition the
parameters.
At that point we have no redundant
information. Things are just
partitioned. The problem is that you're
going to have even more communication
costs, but at least it lets us decrease
the memory load on each GPU.
So this is ZeRO, and there's ZeRO-1,
ZeRO-2, ZeRO-3.
The variation that you choose will be a
function of how sensitive you are to
training time, how big your model is,
and whether memory is an actual problem
or you're fine with just storing
everything.
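Here is a small sketch of how the per-GPU memory shrinks with each stage. The byte counts (2 bytes per parameter for weights, 2 for gradients, 12 for optimizer states) follow the mixed-precision Adam accounting used in the ZeRO paper; they are an assumption for this example and were not given in the lecture. The model size and GPU count are example values.

```python
def zero_memory_per_gpu_gb(num_params, num_gpus, stage):
    """Approximate per-GPU memory for model states under ZeRO stage 0 to 3.

    Stage 0 is plain data parallelism (everything replicated).
    Byte counts per parameter are assumptions (mixed-precision Adam).
    """
    param_bytes, grad_bytes, optim_bytes = 2.0, 2.0, 12.0
    if stage >= 1:
        optim_bytes /= num_gpus     # ZeRO-1: shard the optimizer states
    if stage >= 2:
        grad_bytes /= num_gpus      # ZeRO-2: also shard the gradients
    if stage >= 3:
        param_bytes /= num_gpus     # ZeRO-3: also shard the parameters
    return num_params * (param_bytes + grad_bytes + optim_bytes) / 1e9

memory_by_stage = {
    stage: zero_memory_per_gpu_gb(num_params=7.5e9, num_gpus=64, stage=stage)
    for stage in range(4)
}
for stage, gb in memory_by_stage.items():
    print(f"stage {stage}: {gb:.1f} GB per GPU")
```

This only counts model states; activations and the extra communication that each stage introduces are not modeled.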
So that's one set of methods. Again,
this set of methods is data parallelism:
you have independent sets of data that
are handled by different GPUs.
You have another set of methods called
model parallelism.
Model parallelism tries to parallelize
the operations even within one batch.
There's a bunch of methods. I don't
want to sound too much like a catalog,
so we're not going to go through them
one by one, but I will call out a few
that are worth noting.
If you remember, last lecture we
talked about MoE-based LLMs and how
sequences were being sent to different
experts.
There is a way to distribute that
across GPUs via expert parallelism,
which means having, let's say, one
expert on one device and another one on
another device.
So that's one thing worth noting.
Another one is tensor parallelism: when
you have big matrix multiplications, you
cut them up in a way that decreases the
memory required for that.
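One simple way to "cut" a matrix multiplication is to split the weight matrix by columns, so that each device holds only a slice of the weights and produces a slice of the output. The sketch below simulates two devices in NumPy; the column split is one possible choice for illustration, not necessarily the scheme any particular framework uses.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 8))     # activations: 4 tokens, hidden size 8
W = rng.normal(size=(8, 6))     # full weight matrix

num_devices = 2
# Each "device" stores only its own block of columns of W.
W_shards = np.split(W, num_devices, axis=1)

# Each device multiplies the same input by its own shard.
partial_outputs = [X @ W_k for W_k in W_shards]

# Gathering the column blocks recovers the full output.
Y_parallel = np.concatenate(partial_outputs, axis=1)
Y_reference = X @ W
```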
Okay. And maybe the last one I will
mention is pipeline parallelism.
That's when you view a forward pass as
involving several layers. You say that
one GPU is only going to be responsible
for, let's say, layers 1, 2, 3, then
another one for layers 4, 5, 6, and so
on and so forth.
So you also have that kind of
parallelism. Anyway, there's a bunch of
techniques, and the ones that I
mentioned fall in the bucket of model
parallelism.
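For pipeline parallelism, the layer-to-GPU assignment described above can be sketched as a simple contiguous split. The function name is hypothetical, and real systems also have to schedule micro-batches across stages, which is not shown here.

```python
def assign_layers_to_stages(num_layers, num_stages):
    """Split layers 1..num_layers into contiguous chunks, one per pipeline stage."""
    base, extra = divmod(num_layers, num_stages)
    stages = []
    start = 1
    for stage_index in range(num_stages):
        # The first `extra` stages take one additional layer when the split is uneven.
        size = base + (1 if stage_index < extra else 0)
        stages.append(list(range(start, start + size)))
        start += size
    return stages

stages = assign_layers_to_stages(num_layers=6, num_stages=2)
for gpu_index, layers in enumerate(stages):
    print(f"GPU {gpu_index}: layers {layers}")
```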
Make sense?
No need to know the details there, but I
think knowing that there are several
methods, and having a rough idea of
each, is a good thing to have in mind.
Cool.
So, what did we do? We realized that
during the training process, we had to
save a lot of things in memory. So what
we saw were techniques that reduce the
memory burden on each GPU by
distributing it across GPUs. We saw data
parallelism, then the ZeRO method, which
has some extra optimizations, and we saw
model parallelism as well.
Now we're going to see another technique
that leverages the structure of the GPU.
You may have heard of this technique:
it's called FlashAttention.
It was actually developed here at
Stanford, in 2022.
In order to talk about this technique,
I want to tell you more about what a GPU
is composed of.
If you look under the hood, a GPU is
very complicated, and I for sure don't
know everything either, but what I know
is that we have two kinds of memory in a
GPU. You have one kind of memory that's
big but relatively slow: that's the HBM
(high-bandwidth memory).
And then another kind of memory that is
fast but much, much smaller, which is on
chip, next to where the compute happens:
that's called the SRAM.
So you have HBM and SRAM. HBM has
something around tens of gigabytes, so
it's the GPU memory that you saw in the
description.
SRAM is much smaller, something around
tens of megabytes, let's say. So it's
much smaller, but then it is like 10
times faster. HBM is a few terabytes per
second, let's say, and SRAM is tens of
terabytes per second.
So it's a noticeable difference in
speed.
So what we want is to leverage the
strengths of these two kinds of memory
in order to speed up the attention
computation in an exact way. What do I
mean by exact? I mean that we're not
making any approximations to the
computation.
We're just leveraging the strengths of
these components and scheduling the
computation in a clever way.
So if you remember, the self-attention
computation is done with this very
important formula: softmax of the
queries times the keys transposed, over
some scaling factor, times V, that is,
softmax(QK^T / sqrt(d_k)) V.
This allows each query to interact with
everyone else.
In matrix form, you can think of the
query matrix as having a number of rows
equal to the sequence length and a
number of columns equal to the dimension
of the query, and the same for keys and
values. So you have these big matrix
multiplications.
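As a reference point, here is the standard, non-tiled version of that formula in NumPy, for a single attention head without masking. The shapes and the random inputs are made up for the example.

```python
import numpy as np

def softmax_rows(scores):
    """Row-wise softmax: each row is normalized to sum to one."""
    shifted = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def standard_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for one head, no masking."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # (n, n): one score per query/key pair
    weights = softmax_rows(scores)      # (n, n): the full matrix is materialized
    return weights @ V                  # (n, d_v)

rng = np.random.default_rng(0)
n, d = 6, 4                             # sequence length and head dimension
Q = rng.normal(size=(n, d))
K = rng.normal(size=(n, d))
V = rng.normal(size=(n, d))

out = standard_attention(Q, K, V)
attention_weights = softmax_rows(Q @ K.T / np.sqrt(d))
```

The intermediate `scores` and `weights` matrices are both n by n, which is where the quadratic cost in the sequence length comes from.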
If you do this computation the
standard, vanilla way, what you would do
is store these matrices in the big but
slow memory component of the GPU.
So you would store them in the HBM.
Here is what you would do without any
optimization. You would take those
matrices from the big but slow HBM,
perform the computation,
and then write the result back to the
HBM.
Then you would read that result again
from the HBM, compute the softmax, and
write it back to the HBM.
And then you would again load this plus
the value matrix, multiply them, and
write the result to the HBM.
See, there are a lot of reads and writes
to the HBM. So it's a lot of data
transfer,
which actually becomes the bottleneck.
A GPU is very, very fast, but then you
spend a lot of time just loading your
matrices from memory.
The reason why you do that is the
softmax operation. Do you remember what
a softmax does? It normalizes the
quantities so that they sum to one, but
it's row dependent, meaning that each
row needs to sum to one.
So in a sense, you need the whole score
computation to happen first, before you
do your softmax. If you just look at it
like that, you would think you need to
do the whole thing first.
Well, it turns out that you don't need
to do everything at once.
And this is the core idea behind
FlashAttention.
What FlashAttention does is try to
minimize the amount of reads and writes
from and to the HBM.
Instead, it takes small blocks, and the
method is called tiling. It takes small
blocks that it sends to the SRAM, so
that they get computed end to end before
being sent back to the HBM.
Does that make sense? The idea is: let's
send small matrices into the SRAM, do
the full end-to-end computation there,
and just send the result back to the
HBM, because we want to minimize the
amount of reads and writes from the
HBM.
So here's how how you would do it. So
you remember the softmax
uh computation with the query and the
key and then the value. Well, what you
would do is to cut your matrices
and then proceed step by step.
But then there's a cool trick that I
want to talk to you about which is that
you don't need to compute the whole
matrix inside a softmax
in order to achieve the whole softmax
computation
because if you think about it, let's suppose
you have a whole matrix and then you
have different, let's say, columns or
submatrices S1 to SN. Well, the softmax of
this huge matrix
is equal to this matrix where the
softmax is taken with respect to each of
these submatrices,
up to some scaling factor.
So this is the core trick
and if you want to be convinced of it,
just look at the softmax formula: it's
the exponential of something over some
quantity which is shared across the row.
So this scaling factor just corrects
each block's normalization to match that shared row-wise quantity.
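To make the trick concrete, here is a minimal sketch in plain Python for a single row of scores. The function names (`softmax`, `softmax_by_blocks`) and the choice of scaling factor (each block's share of the row-wise sum of exponentials) are illustrative assumptions, not the formula from the slides or the paper.

```python
import math

def softmax(xs):
    """Numerically stable softmax over one list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_by_blocks(row, block_size):
    """Softmax of a full row, rebuilt from per-block softmaxes.

    Each block is normalized on its own, then multiplied by a scaling
    factor alpha = (block's sum of exponentials) / (row's sum of exponentials).
    """
    blocks = [row[i:i + block_size] for i in range(0, len(row), block_size)]
    m = max(row)  # shared shift, only for numerical stability
    block_sums = [sum(math.exp(x - m) for x in b) for b in blocks]
    row_sum = sum(block_sums)
    out = []
    for block, block_sum in zip(blocks, block_sums):
        alpha = block_sum / row_sum  # the per-block scaling factor
        out.extend(alpha * p for p in softmax(block))
    return out

row = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
full = softmax(row)
tiled = softmax_by_blocks(row, block_size=2)
print(max(abs(a - b) for a, b in zip(full, tiled)))  # difference is at floating-point noise level
```

The per-block softmax only needs the scores of that block, which is what lets each block be processed independently; the scaling factor is the only piece that depends on the rest of the row.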
So with this in mind, what we will do is
take the respective slices of these
matrices,
do the whole computation, and then
populate the corresponding
entry in the output matrix.
So we will do that between let's say the
first slice of the query and the first
slice of the keys and the values and
then we will repeat for the other slice
until the end
and then we will repeat for the other
queries as well until the end.
So what the paper explains is how this
scaling factor is being computed. So
this one is some formula that I did not
put on the slide. So it's not necessary
for you to memorize the formula. It's
just the idea
and the idea is exactly this trick.
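As a rough sketch of how the slices and the scaling factor can fit together, the plain-Python example below processes the keys and values block by block for one query vector, keeping a running maximum and a running sum so earlier partial results can be rescaled. This running-rescale scheme is an assumption about how the scaling is typically handled, not the formula from the slides; the function names and the 1/sqrt(d) score scaling are also illustrative.

```python
import math

def score(q, k):
    """Scaled dot-product score between one query and one key."""
    return sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(len(q))

def attention_row(q, keys, values):
    """Reference: softmax over all scores at once, then weighted sum of values."""
    scores = [score(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    dim = len(values[0])
    return [sum(w * v[c] for w, v in zip(weights, values)) / z for c in range(dim)]

def attention_row_tiled(q, keys, values, block_size):
    """Same result, but keys/values are visited one block at a time."""
    dim = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * dim  # unnormalized output accumulated so far
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [score(q, k) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old maximum (0.0 on the first block).
        rescale = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * rescale + sum(weights)
        acc = [a * rescale + sum(w * v[c] for w, v in zip(weights, v_blk))
               for c, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]

q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [3.0, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(attention_row(q, keys, values))
print(attention_row_tiled(q, keys, values, block_size=2))
```

In this sketch only one block of keys and values is needed at a time, which mirrors the idea of keeping the working set small enough for fast memory.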
So once you do that,
you basically end up with only one read
from the HBM
and these tiled quantities are
stored in the SRAM
and then they're read from the SRAM,
which is very fast, and then computed and
then written back to the SRAM, and then at the
end, in order to accumulate the results,
they're sent back to the HBM.
So, just to make sure we're clear:
in green is basically when it's read from
the SRAM, and in blue is from the
HBM. You have a question? Yeah.
>> Yeah. The question is, do you take the
whole row or a portion of it? So you can
take a portion of it; what's shown
here is just for illustrative
purposes. You can think of your
matrix as being completely uh you know
like a grid and then uh you just like
Key Summary
Covers the LLM training pipeline: the transfer learning paradigm, pre-training data and objectives, scaling laws, distributed training techniques (data/model parallelism, ZeRO, Flash Attention), mixed precision training, and post-training (SFT, RLHF).
Key Concepts
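Transfer Learning Paradigm 4:30
- Traditional approach: train a separate model from scratch for each task (e.g., spam detection and sentiment analysis individually)
- Transfer learning: learn a shared foundation of language understanding through pre-training, then tune per task
- Pre-training: learn the structure of language and code from massive data (next-token prediction)
- Tuning: adapt the pre-trained model to a specific task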
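Pre-training Data 7:00
- Data sources: Common Crawl (3 billion pages per month), Wikipedia, Reddit, GitHub, Stack Overflow
- Scale: hundreds of billions to tens of trillions of tokens. GPT-3 (300B), Llama 3 (15T)
- Objective: learn the structure of language and code through next-token prediction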
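FLOPs vs FLOPS 11:00
- FLOPs (lowercase s): floating-point operations, a unit for the amount of computation. LLM training takes ~10^25 FLOPs
- FLOPS (uppercase S): floating-point operations per second, a measure of compute speed. Used as a GPU performance metric
- FLOPs estimate: O(number of tokens × number of parameters)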
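As a worked example of the two units, the sketch below turns a FLOPs estimate into wall-clock time at a given FLOPS rate. The constant factor of 6 FLOPs per parameter per token is a commonly used approximation and an assumption here (the summary only states the O(tokens × parameters) scaling); the GPU count, peak FLOPS, and utilization in the usage line are hypothetical.

```python
def training_flops(num_params, num_tokens, flops_per_param_token=6):
    """Total training compute (FLOPs), assuming a constant cost per parameter per token."""
    return flops_per_param_token * num_params * num_tokens

def training_days(total_flops, num_gpus, peak_flops_per_gpu, utilization):
    """Wall-clock days needed at a given aggregate throughput (FLOPS)."""
    seconds = total_flops / (num_gpus * peak_flops_per_gpu * utilization)
    return seconds / 86400

# GPT-3 sized run from the summary: 175B parameters, 300B tokens.
total = training_flops(175e9, 300e9)
print(f"{total:.2e} FLOPs")  # about 3.15e+23 under the factor-of-6 assumption

# Hypothetical cluster: 1000 GPUs, 1e15 peak FLOPS each, 40% utilization.
print(training_days(total, num_gpus=1000, peak_flops_per_gpu=1e15, utilization=0.4))
```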
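Scaling Laws 13:30
- Kaplan et al. (2020): increasing model size, dataset size, and compute improves performance
- Sample efficiency: larger models reach better performance with the same number of tokens
- Chinchilla optimal: for a fixed compute budget, the optimal setting is number of data tokens ≈ 20 × number of parameters
- The problem with GPT-3: 175B parameters but trained on only 300B tokens (by the Chinchilla rule it would need 3.5T tokens)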
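The Chinchilla rule of thumb above is simple arithmetic; a small sketch, with hypothetical helper names, reproduces the GPT-3 numbers:

```python
def chinchilla_optimal_tokens(num_params, tokens_per_param=20):
    """Token count suggested by the ~20 tokens per parameter rule of thumb."""
    return tokens_per_param * num_params

def tokens_per_param(num_tokens, num_params):
    """How many training tokens a model actually saw per parameter."""
    return num_tokens / num_params

print(chinchilla_optimal_tokens(175e9))   # 3.5e12, i.e. 3.5T tokens
print(tokens_per_param(300e9, 175e9))     # under 2 tokens per parameter for GPT-3
```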
Memory Requirements 20:00
- What must be stored: model parameters, gradients, optimizer states, activations
- Adam optimizer: 12 bytes per parameter (FP32 weight + FP32 momentum + FP32 variance)
- Example: training a 7B model needs ~84GB (more than a single GPU holds)
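The 84GB figure follows directly from the 12 bytes per parameter accounting above. A minimal sketch (the dictionary and function names are illustrative; gradients and activations are left out, matching that accounting):

```python
# Bytes per parameter, following the summary's Adam accounting.
BYTES_PER_PARAM = {
    "weight_fp32": 4,
    "adam_momentum_fp32": 4,
    "adam_variance_fp32": 4,
}

def training_state_gb(num_params):
    """Memory for weights plus Adam states, in GB (1 GB = 1e9 bytes)."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9

print(training_state_gb(7e9))  # 84.0
```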
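Data Parallelism 28:00
- Principle: replicate the model on every GPU and split only the data
- Process: each GPU computes gradients on its mini-batch → average them with All-Reduce → apply the same update everywhere
- Pros: simple to implement, increases the effective batch size
- Cons: the entire model must fit in each GPU's memory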
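The process above can be simulated on one machine. The sketch below uses a toy one-parameter model y = w·x with squared error; the "GPUs" are just list entries and `all_reduce_mean` is a stand-in for a real collective, so every name here is illustrative. It assumes the batch divides evenly across workers.

```python
def local_gradient(w, shard):
    """Gradient of mean squared error for y_hat = w * x on one worker's shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    """Stand-in for All-Reduce: every worker ends up with the same average."""
    mean = sum(values) / len(values)
    return [mean] * len(values)

def data_parallel_step(w, batch, num_workers, lr):
    """One SGD step with the batch split evenly across identical model replicas."""
    shard_size = len(batch) // num_workers
    shards = [batch[i * shard_size:(i + 1) * shard_size] for i in range(num_workers)]
    grads = [local_gradient(w, shard) for shard in shards]  # computed independently
    synced = all_reduce_mean(grads)                          # gradients averaged
    return [w - lr * g for g in synced]                      # identical update per replica

batch = [(float(x), 3.0 * x) for x in range(1, 9)]
replicas = data_parallel_step(0.0, batch, num_workers=4, lr=0.01)
print(replicas)  # all replicas hold the same updated weight
```

With equal-sized shards, the averaged gradient matches the gradient a single device would compute on the whole batch, which is why the replicas stay in sync.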
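ZeRO (Zero Redundancy Optimizer) 32:00
- ZeRO-1: optimizer states are partitioned across GPUs
- ZeRO-2: + gradients are also partitioned
- ZeRO-3: + parameters are also partitioned (All-Gather when needed)
- Effect: large reduction in memory usage, making it possible to train bigger models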
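To get a feel for the savings at each stage, here is a back-of-the-envelope calculator. The byte counts are assumptions for illustration (FP32 parameters and gradients at 4 bytes each, Adam momentum plus variance at 8 bytes), activations are ignored, and the function name is hypothetical.

```python
def zero_bytes_per_gpu(num_params, num_gpus, stage,
                       param_bytes=4, grad_bytes=4, optim_bytes=8):
    """Approximate per-GPU memory (bytes) for parameters, gradients, optimizer states.

    stage 0: plain data parallelism, everything replicated
    stage 1: optimizer states partitioned
    stage 2: + gradients partitioned
    stage 3: + parameters partitioned
    """
    params = num_params * param_bytes
    grads = num_params * grad_bytes
    optim = num_params * optim_bytes
    if stage >= 1:
        optim /= num_gpus
    if stage >= 2:
        grads /= num_gpus
    if stage >= 3:
        params /= num_gpus
    return params + grads + optim

for stage in range(4):
    gb = zero_bytes_per_gpu(7e9, num_gpus=8, stage=stage) / 1e9
    print(f"stage {stage}: {gb:.1f} GB per GPU")
```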
Model Parallelism 35:00
- Tensor parallelism: a single operation (a matrix multiplication) is split across GPUs
- Pipeline parallelism: layers are distributed across GPUs (GPU 1: layers 1-3, GPU 2: layers 4-6)
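One way to picture tensor parallelism is to split the weight matrix by columns, let each GPU multiply the input by its own slice, and concatenate the partial outputs. The sketch below simulates that in plain Python; the column-wise split is one possible choice, the helper names are illustrative, and it assumes the column count divides evenly across GPUs.

```python
def matmul(a, b):
    """Plain matrix multiplication on lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def split_columns(w, num_parts):
    """Split a matrix into equal column slices, one per simulated GPU."""
    size = len(w[0]) // num_parts
    return [[row[p * size:(p + 1) * size] for row in w] for p in range(num_parts)]

def tensor_parallel_matmul(x, w, num_gpus):
    """Each simulated GPU computes x @ its column slice; results are concatenated."""
    partials = [matmul(x, shard) for shard in split_columns(w, num_gpus)]
    out = []
    for i in range(len(x)):
        row = []
        for partial in partials:  # gather the column slices back in order
            row.extend(partial[i])
        out.append(row)
    return out

x = [[1, 2], [3, 4]]
w = [[1, 2, 3, 4], [5, 6, 7, 8]]
print(tensor_parallel_matmul(x, w, num_gpus=2))  # [[11, 14, 17, 20], [23, 30, 37, 44]]
```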
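Flash Attention 37:00
- GPU memory hierarchy: HBM (large but slow, ~tens of GB) vs SRAM (small but fast, ~tens of MB)
- Bottleneck: standard attention performs many reads and writes between HBM and SRAM
- Core idea: tiling. Split into small blocks, compute end to end in SRAM, then write the result to HBM
- Softmax trick: softmax([S1, S2, ...]) = [α1·softmax(S1), α2·softmax(S2), ...] (corrected by scaling factors)
- Recomputation: instead of storing activations, recompute them during the backward pass. This increases compute but saves memory and also reduces runtime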
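Mixed Precision Training 50:00
- Principle: keep the weights in FP32, run the forward/backward passes in FP16
- Reason: in FP16, gradients that are too small underflow to 0
- Loss scaling: multiply the loss by a large number to prevent gradient underflow, then divide it back out at update time
- Effect: memory savings plus faster computation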
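The underflow problem and the loss-scaling fix can be reproduced with the standard library, since `struct` supports IEEE half precision through the `"e"` format. In this sketch the gradient is multiplied by the scale directly, which mimics what scaling the loss does through the chain rule; the specific gradient value and scale of 1024 are arbitrary choices for illustration.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest FP16 value and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

tiny_grad = 1e-8        # smaller than anything FP16 can represent
loss_scale = 1024.0

naive = to_fp16(tiny_grad)                 # stored directly in FP16
scaled = to_fp16(tiny_grad * loss_scale)   # scaled up before the FP16 cast
recovered = scaled / loss_scale            # unscaled in higher precision

print(naive)      # 0.0: the gradient is lost
print(recovered)  # close to 1e-8: the gradient survives
```

Dividing the scale back out happens in higher precision, which is consistent with keeping an FP32 copy of the weights for the update.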
Post-training 1:20:00
- SFT (Supervised Fine-Tuning): train on high-quality (instruction, response) pairs
- RLHF (Reinforcement Learning from Human Feedback): train a reward model on human preferences, then optimize with PPO
- DPO (Direct Preference Optimization): optimize preferences directly, without a reward model
- Goal: tune the pre-trained model to be helpful, harmless, and honest
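To illustrate what "without a reward model" means for DPO, here is a sketch of the per-example DPO loss as it is commonly written: the negative log-sigmoid of a scaled difference between how much the policy and a frozen reference model prefer the chosen response over the rejected one. This formula is not given in the summary above, and the argument names and default beta of 0.1 are assumptions for illustration.

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Per-example DPO loss from sequence log-probabilities.

    The inputs are log-probabilities of the chosen and rejected responses
    under the policy being trained and under a frozen reference model.
    """
    chosen_gain = policy_chosen_logp - ref_chosen_logp
    rejected_gain = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_gain - rejected_gain)
    return math.log1p(math.exp(-margin))  # equals -log(sigmoid(margin))

# Policy identical to the reference: loss is log(2).
print(dpo_loss(-1.0, -2.0, -1.0, -2.0))
# Policy shifts probability toward the chosen response: loss drops.
print(dpo_loss(-0.5, -3.0, -1.0, -2.0))
```

The preference signal enters directly through the log-probabilities, so no separately trained reward model appears in the loss.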
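Key Insights
- The Chinchilla scaling law shifted the paradigm from "bigger no matter what" to "bigger, but efficiently"
- Flash Attention is a good example of an optimization that exploits hardware characteristics
- The pre-training → SFT → RLHF pipeline is the standard for modern LLMs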