VRAM usage scales quadratically with sequence length, and I'm not aware of any complete solution. Even efficient long-context fine-tuning methods such as LongLoRA only improve speed and quality; memory usage stays roughly the same as standard LoRA.
I recommend ensuring you're reducing memory in every other way first (a combined configuration sketch follows the list):

- Ensure you're using 4-bit QLoRA.
- Ensure the batch size is 1.
- Ensure you're using FlashAttention-2.
- Ensure your optimizer state can be paged to CPU memory by using a paged optimizer.
- Use gradient checkpointing.
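Here's a minimal sketch of what that combined setup might look like with the Hugging Face transformers/peft/bitsandbytes stack. The model name, LoRA hyperparameters, and accumulation steps below are placeholders, not recommendations:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder; use your base model

# 4-bit QLoRA quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",  # FlashAttention-2
    torch_dtype=torch.bfloat16,
)

# Enable gradient checkpointing and prepare the quantized model for training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# LoRA adapters on top of the frozen 4-bit base model
lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,    # batch size 1
    gradient_accumulation_steps=16,   # recover a larger effective batch size
    gradient_checkpointing=True,      # trade compute for activation memory
    optim="paged_adamw_8bit",         # paged optimizer: state spills to CPU under pressure
    bf16=True,
)
# Pass model and training_args to your Trainer of choice as usual.
```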
You could also try something more experimental, such as using Mistral with a 1024-token sliding attention window to capture roughly 2048 tokens of effective context while only paying the attention-memory cost of 1024 tokens.
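As a rough sketch of that idea (the 1024 value is illustrative, and the memory saving only applies when the attention implementation actually enforces the window, as FlashAttention-2 does):

```python
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "mistralai/Mistral-7B-v0.1"  # placeholder

# Shrink the sliding window so each token attends to at most the previous
# 1024 tokens; stacked layers still propagate information beyond the window,
# so the effective context reaches further than 1024 tokens.
config = AutoConfig.from_pretrained(model_id)
config.sliding_window = 1024

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    attn_implementation="flash_attention_2",  # window enforced inside FlashAttention-2
)
```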
Or you could just summarize or prune your long examples.
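For the pruning option, a minimal sketch assuming a Hugging Face tokenizer and a dataset with a `text` column (both placeholders):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder

MAX_TOKENS = 2048  # hard cap on example length

def prune(example):
    # Truncate each example to MAX_TOKENS tokens instead of letting long
    # sequences blow up activation memory at training time.
    ids = tokenizer(example["text"], truncation=True, max_length=MAX_TOKENS)["input_ids"]
    return {"input_ids": ids}

# dataset = dataset.map(prune)  # with a datasets.Dataset holding a "text" column
```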