Transformer Memory Requirements

Categories: ML

Written on July 22, 2022

Working out how much memory it takes to train a Transformer GPT2 Model.


There has been recent discussion on StackOverflow and Twitter on the full memory requirements of training a Transformer.

Because I am in the process of training Transformers myself and scaling to multiple GPUs, I became interested in this question myself. Misha Laskin provides some back of the envelope calculations for why batch size and sequence length dominate over model size that are interesting but off by approximately 4x for the model parameters and 2x for activations.

I have thrown together a more detailed calculator as a Colab notebook here. And outline my reasoning below. I’ve tested this on the “small” 124M parameter and “medium” 345M parameter GPT2 models and get close to the real values.

Here is the GPT2 model architecture (image taken from my paper):

See the Illustrated GPT2 for a full explanation of how GPT2 works.

Here are the equations and notation in full:

L = 12 # number of blocks
N = 12 # number of heads
E = 768 # embedding dimension
B = 8 # batch size
T = 1024 # sequence length
TOKS = 50257 # number of tokens in the vocab
param_bytes = 4 # float32 uses 4 bytes
bytes_to_gigs = 1_000_000_000 # 1 billion bytes in a gigabyte

model_params = (TOKS*E)+ L*( 4*E**2 + 2*E*4*E + 4*E)
act_params = B*T*(2*TOKS+L*(14*E + N*T ))
backprop_model_params = 3*model_params
backprop_act_params = act_params
total_params = model_params+act_params+backprop_model_params+backprop_act_params=4*model_params+2*act_params
gigabytes_used = total_params*param_bytes/bytes_to_gigs

For the “small” GPT2 model with 124M parameters (that uses the above values for each parameter) we get:

model_params = 123,568,896
act_params = 3,088,334,848
gigabytes_used = 26.6 Gb

While running the Hugging Face GPT2 we get 27.5Gb.

If our batch size is 1 then we undershoot again where memory is predicted to be 5.1Gb but in reality it is 6.1Gb.

For the medium sized 345M parameter model and a batch size of 1 our equation predicts that there it will use 12.5Gb while empirically it is: 13.4Gb. The 1Gb gap remains.

Let me know if you can think of where the missing 1Gb is coming from. There may be additional tensors in memory from for example by not doing in place operations for the activation function? It may also be due to memory used from concatenating the Attention head outputs or splitting them up?

The model parameter equation comes from:

(TOKS*E) [embedding layer ]+ L [number of blocks]*( 4*E**2 [Q,K,V matrices and the linear projection after Attention] + 2*E*4*E [the MLP layer that projects up to 4*E hidden neurons and then back down again] + 4*E [Two layer norms and their scale and bias terms])

Where we ignore the bias terms and positional embedding.

The activation parameter equation comes from:

B[batch]*T[seq. length]*(2*TOKS [one hot vectors at input and output]+L[number of blocks]*(3E [K,Q,V projections] + N*T [Attention Heads softmax weightings] + E [value vector] + E [linear projection] + E [residual connection] + E [LayerNorm] +4E [MLP activation]+E [MLP projection down]+E[residual]+E[LayerNorm] ))

Aside from the 1Gb gap in my memory calculations, when I turn on floating point 16 the memory for 1 batch only drops from 6.1Gb to 5.8Gb. Meanwhile for a model with a batch of 8, it goes from 27.5 to 21.8Gb. Why are there not larger memory savings? Is this because it is mixed precision and the model decides that it needs high precision for many of its components?


Thanks to Miles Turpin and Misha Laskin for motivating this piece. All remaining errors are mine and mine alone.