[–] lightSpeedBrick@alien.top 1 points 11 months ago

Ah, I hadn’t thought of that. I’ll look into it. Thank you for the suggestion!

 

TL;DR: Why does GPU memory usage spike during the gradient update step (I can't account for ~10 GB of it) and then drop back down?

I've been working on fine-tuning some of the larger LMs available on HuggingFace (e.g. Falcon40B and Llama-2-70B), and so far my estimates for memory requirements don't add up. I have access to 4 A100-80GB GPUs and was fairly confident that I should have enough memory to fine-tune Falcon40B with LoRA, but I keep getting CUDA OOM errors. I have figured out ways to get things running, but this made me realize I don't really understand how memory is allocated during training.

Here's my understanding of where memory goes when you want to train a model:

Setting

-> Defining TOTAL_MEMORY = 0 (in MB), which I'll update as I move through each step that adds memory.

-> Checking memory usage by "watching" nvidia-smi with a refresh every 2 seconds (plus the small helper sketched after this list).

-> Model is loaded in fp16

-> Using Falcon7B with ~7B parameters (it's more like 6.9B, but close enough)

-> Running on a single A100-80GB GPU in a Jupyter notebook
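For the numbers below, I also read PyTorch's own counters with a small helper along these lines (a minimal sketch; the `log_mem` name is just for illustration):

```python
import torch

def log_mem(tag: str) -> None:
    # allocated = memory held by live tensors;
    # reserved = what the caching allocator has claimed from the driver
    # (roughly what nvidia-smi reports for the process).
    alloc_mb = torch.cuda.memory_allocated() / 1024**2
    reserved_mb = torch.cuda.memory_reserved() / 1024**2
    print(f"{tag}: allocated={alloc_mb:,.0f} MB, reserved={reserved_mb:,.0f} MB")
```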

Loading The Model:

  • CUDA kernels/context for torch and so on (on my machine I'm seeing about 900 MB per GPU). TOTAL_MEMORY + 900 -> TOTAL_MEMORY = 900
  • Model weights (duh). Say you have a 7B-parameter model loaded in float16: 2 bytes * 7B parameters = 14B bytes ≈ 14 GB of GPU VRAM. TOTAL_MEMORY + 14_000 -> TOTAL_MEMORY = 15_000 (rounding)

With that, the model should load on a single GPU.
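As a sketch, the loading step itself is just this (assuming the standard `transformers` API and the `tiiuae/falcon-7b` checkpoint name; `log_mem` is the helper from above):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"  # checkpoint name assumed for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # 2 bytes/param -> ~14 GB of weights for ~7B params
).to("cuda")

log_mem("after model load")
```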

Training (I am emulating a single forward and backward step by running each part separately)

  • The data. I am passing in a single small batch of a dummy input (random ints) so I will assume this does not add a substantial contribution to the memory usage.
  • Forward pass. For some reason memory jumps by about 1_000 MB. Perhaps this is due to cached intermediate activations, though I feel like that should be way larger. TOTAL_MEMORY + 1_000 -> TOTAL_MEMORY = 16_000.
  • Compute the cross-entropy loss. The loss tensor will utilize some memory, but that doesn't seem to be a very high number, so I assume it does not contribute.
  • Computing gradients with respect to the parameters by calling `loss.backward()`. This results in a substantial memory spike (goes up by ~15_000 MB). I imagine this is a result of storing a gradient value for every parameter in the model? TOTAL_MEMORY + 15_000 -> TOTAL_MEMORY = 30_000
  • Updating model parameters by calling `optimizer.step()`. This results in yet another memory spike, where GPU memory usage goes up by more than 38_000 MB. I'm not really sure why. My best guess is that this is where AdamW starts storing 2 momentum-style state values for each parameter. If we do the math (assuming the optimizer states are in fp16): 2 bytes * 2 states * 7B = 28B bytes ≈ 28 GB. TOTAL_MEMORY + 38_000 -> TOTAL_MEMORY = 68_000 (the full sequence is sketched right after this list)
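Putting the emulated step together, it looks roughly like this (the batch shape and learning rate are arbitrary; `log_mem` is the helper from above):

```python
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=1e-5)

# A single tiny batch of random token ids; its footprint is negligible.
input_ids = torch.randint(0, tokenizer.vocab_size, (1, 128), device="cuda")

# Forward pass; passing labels makes the HF model compute the cross-entropy loss.
out = model(input_ids=input_ids, labels=input_ids)
log_mem("after forward")          # + cached intermediate activations

out.loss.backward()
log_mem("after backward")         # + one fp16 gradient per parameter (~14 GB for 7B params)

optimizer.step()                  # AdamW lazily allocates its exp_avg / exp_avg_sq states here
log_mem("after optimizer.step")

optimizer.zero_grad(set_to_none=True)
```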

LoRA would reduce this number by cutting down what's needed during the optimizer step, but I have not yet run any tests on that, so I don't have numbers.

I believe that's all the major components.

So where do the extra 10 GB come from? Maybe it's one of those "torch reserved that memory but isn't actually using it" situations. So I check by inspecting the output of `torch.cuda.memory_allocated` and `torch.cuda.max_memory_allocated`, and perhaps there's something there.

memory allocated (after the backward step): 53 GB

max memory allocated: 66 GB

Meaning at some point an extra 13 GB were needed, but were then freed up.
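Concretely, the counters I'm comparing (all standard `torch.cuda` calls):

```python
gb = 1024**3
print(f"allocated:     {torch.cuda.memory_allocated() / gb:.1f} GB")     # live tensors right now
print(f"max allocated: {torch.cuda.max_memory_allocated() / gb:.1f} GB")  # peak since start (or last reset)
print(f"reserved:      {torch.cuda.memory_reserved() / gb:.1f} GB")       # allocator pool, roughly what nvidia-smi shows
```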

My question for you folks: does anybody know where the extra ~10 GB that I can't find in my math is coming from? What happens after the backward pass that frees up 13 GB? Are there any additional steps that require memory that I've missed?

This has been bothering me for a while and I'd love to get a better sense of it, so any expert input, resources, or other suggestions you may have will be greatly appreciated!

Edit: I also know that when you train with the `Trainer` class you can enable gradient checkpointing to reduce memory usage by recomputing some of the intermediate activations during the backward pass. At which part of the whole process above would this reduce memory usage?
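For reference, a rough sketch of how I'd turn it on, either directly on the model or through `TrainingArguments` (the `output_dir` and batch size are placeholders):

```python
# Directly on the model...
model.gradient_checkpointing_enable()

# ...or via the Trainer's arguments.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",             # placeholder path
    gradient_checkpointing=True,  # recompute most activations during the backward pass
    per_device_train_batch_size=1,
)
```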

[–] lightSpeedBrick@alien.top 1 points 11 months ago (2 children)

My understanding is that with LoRA you reduce the number of trainable parameters and therefore the memory needed to track optimizer states (e.g. for Adam, which tracks 2 state values for each trainable parameter). This means you need far less GPU memory to fine-tune the model. Imagine 70B parameters * 4 bytes for fp32 training, plus 70B * 8 bytes for Adam; LoRA reduces that second part to, say, 1% of 70B * 8 bytes.
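Roughly, with the `peft` library (the rank, alpha, and `target_modules` values here are just illustrative and depend on the model):

```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,                                # adapter rank (illustrative)
    lora_alpha=32,
    target_modules=["query_key_value"],  # module names depend on the architecture
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # typically well under 1% of the base model's parameters
```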

You can also use gradient checkpointing, which isn't specific to LoRA, to reduce memory consumption at the expense of training time. Here you cache only some of the intermediate activations and recompute the rest during back-prop.

Can you explain what you mean by “caching intermediate gradients during backprop”? I’m not familiar with what that is.

[–] lightSpeedBrick@alien.top 1 points 11 months ago

Oh, don’t get me wrong, the dominant sentiment on r/singularity is not for me and I am no fan of the reverence certain public figures get from members of that community. I was going for polite understatement with my comment, but perhaps failed 😅

[–] lightSpeedBrick@alien.top 1 points 11 months ago (6 children)

What’s wrong with r/singularity? Folks over there are optimistic, perhaps a little too eager. In fact, most opinions that aren’t optimistic get downvoted pretty quickly.

[–] lightSpeedBrick@alien.top 1 points 11 months ago

I threaten to quit too. I don’t work at OpenAI, but I’ll quit my job and happily accept Microsoft’s offer in solidarity.

[–] lightSpeedBrick@alien.top 1 points 11 months ago

Largely unrelated, but this has a similar vibe. I wonder what happened to that high school kid who invented the transformer even before Vaswani et al., and then, a year later, the other guy who claimed to have invented a brand-new neural network architecture that was supposed to break the internet.