I’ve been playing around with methods such as prompt tuning and LoRA, which are parameter-efficient because they only fine-tune a very small fraction (that is, <1%) of all parameters.
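
To put the “<1%” in concrete terms for LoRA: a frozen weight matrix W of shape d_out × d_in gets a trainable low-rank update B·A, with B of shape d_out × r and A of shape r × d_in, so only r·(d_in + d_out) parameters are trained for that matrix. Below is a minimal sketch of that count. The hidden size of 4096 and rank r = 8 are illustrative assumptions, not numbers from my model:

```python
def lora_trainable_fraction(d_in: int, d_out: int, rank: int) -> float:
    """Fraction of one d_out x d_in weight matrix's parameter count that LoRA trains.

    W (d_out x d_in) stays frozen; only B (d_out x rank) and A (rank x d_in) are learned.
    """
    full = d_in * d_out
    lora = rank * (d_in + d_out)
    return lora / full


# Assumed example: a square 4096 x 4096 projection with rank 8.
print(f"{lora_trainable_fraction(4096, 4096, 8):.2%}")
```

The whole-model fraction depends on which matrices get adapters and on the rank chosen.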

But for both methods, you still have to cache the intermediate gradients during backprop, meaning that you don’t save much GPU memory during fine-tuning (at most a small amount, from not having to store optimizer states for the frozen layers). For instance, LoRA only reduced the GPU memory footprint of my custom model from 8.5 GB to 8.1 GB, which is minimal. The reduction in fine-tuning time isn’t a major advantage either: the per-batch time for the same model only dropped by 20 ms, from 210 ms to 190 ms.

This raises the question: what is the practical reason for the popularity of parameter-efficient fine-tuning (e.g. prompt tuning, with 1.6k+ citations) if it doesn’t really save GPU memory or training time?

I can see two possible reasons, but I’m not convinced they explain the ‘hype’ around parameter-efficient fine-tuning:

  1. The fine-tuned checkpoint for each downstream task is dramatically smaller. For example, in prompt tuning we only need to save the tiny trained soft prompt (a few megabytes at most) on our hard disk/SSD, rather than the entire set of updated model weights (many GBs).
    1. But from a practical point of view, I feel that most people are more constrained by compute (e.g. GPU memory) than by disk space. In other words, training time and GPU memory consumption seem like more pressing concerns than checkpoint storage.
  2. The second is robustness to domain shift (since we preserve most of the original model’s weights rather than destructively re-learning them), which the prompt tuning paper mentions but the LoRA paper doesn’t really discuss.
    1. I could see this being a reason, but the out-of-distribution gains reported in the prompt tuning paper are marginal at best.
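
For the checkpoint-size point in reason 1, the arithmetic is simple: a soft prompt stores prompt_len × d_model values, while a full fine-tune stores every weight. The sketch below assumes a 100-token soft prompt, a hidden size of 4096, a model of roughly 11B parameters and fp32 (4-byte) storage. These are illustrative numbers, not figures taken from the prompt tuning paper:

```python
def soft_prompt_bytes(prompt_len: int, d_model: int, bytes_per_value: int = 4) -> int:
    """Prompt-tuning checkpoint: one d_model-sized vector per soft-prompt token."""
    return prompt_len * d_model * bytes_per_value


def full_checkpoint_bytes(n_params: int, bytes_per_value: int = 4) -> int:
    """Full fine-tuning checkpoint: every weight is saved."""
    return n_params * bytes_per_value


prompt = soft_prompt_bytes(prompt_len=100, d_model=4096)  # assumed prompt length / width
full = full_checkpoint_bytes(n_params=11_000_000_000)     # assumed ~11B-parameter model
print(f"soft prompt: {prompt / 1e6:.2f} MB per task, full fine-tune: {full / 1e9:.0f} GB per task")
```

Both costs are paid once per downstream task, so the gap grows with the number of tasks.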

(EDIT: I’m also wondering if there’s something else I’m missing that would reduce GPU memory and runtime. I’ve heard of QLoRA, which adds 4-bit quantization of the frozen base model on top of LoRA, so perhaps that’s one way to make LoRA more memory-efficient. But is there anything similar to reduce the memory footprint of prompt tuning?)
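
On the QLoRA idea: the frozen base weights usually take most of the memory, so the bit width they are stored at matters a lot. A rough sketch, assuming a hypothetical 70B-parameter model and ignoring the extra quantization constants (scales) that real 4-bit schemes store, as well as activations, adapters and optimizer states:

```python
def weight_memory_gb(n_params: int, bits_per_param: int) -> float:
    """Memory (GB, 1e9 bytes) needed just to hold the weights at a given bit width."""
    return n_params * bits_per_param / 8 / 1e9


N_PARAMS = 70_000_000_000  # hypothetical 70B-parameter model
for bits in (32, 16, 4):
    print(f"{bits:>2}-bit weights: {weight_memory_gb(N_PARAMS, bits):.0f} GB")
```

Nothing in this arithmetic is specific to LoRA; it applies to any frozen weights.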

  • lightSpeedBrick · 10 months ago

    My understanding is that LoRA reduces the number of trainable parameters and therefore the memory needed to track optimizer states (e.g. Adam tracks two state values for each trainable parameter). This means you need far less GPU memory to fine-tune the model. Imagine 70B parameters * 4 bytes for fp32 training (280 GB), plus 70B * 8 bytes for Adam (560 GB). LoRA reduces that second part to, say, 1% of 70B * 8 bytes (5.6 GB).
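
    The arithmetic above as a sketch. It keeps the 4-byte fp32 weights and the 8 bytes of Adam state per trainable parameter from the comment, and additionally assumes one 4-byte fp32 gradient per trainable parameter (frozen weights get no gradient buffer). Activation memory is deliberately left out:

```python
BYTES_PER_WEIGHT = 4      # fp32 weights, as in the comment
BYTES_PER_GRAD = 4        # assumption: one fp32 gradient per *trainable* parameter
BYTES_PER_ADAM_STATE = 8  # Adam keeps two fp32 moments (m, v) per trainable parameter


def training_memory_gb(n_params: int, n_trainable: int) -> dict:
    """Rough per-component memory in GB (1e9 bytes); activations are not included."""
    parts = {
        "weights": n_params * BYTES_PER_WEIGHT,
        "grads": n_trainable * BYTES_PER_GRAD,
        "adam_states": n_trainable * BYTES_PER_ADAM_STATE,
    }
    parts["total"] = sum(parts.values())  # summed as ints before converting to GB
    return {name: n_bytes / 1e9 for name, n_bytes in parts.items()}


N = 70_000_000_000
full = training_memory_gb(N, n_trainable=N)
lora = training_memory_gb(N, n_trainable=N // 100)  # "say 1%" trainable
for name in full:
    print(f"{name:>12}: full {full[name]:7.1f} GB | LoRA {lora[name]:7.1f} GB")
```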

    You can also use gradient checkpointing, which isn’t specific to LoRA, to reduce memory consumption at the expense of training time. Here you cache only some intermediate activations during the forward pass and recompute the rest during backprop.
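
    A toy, pure-Python sketch of the idea (not PyTorch’s `torch.utils.checkpoint`): a chain of scalar layers a_{i+1} = tanh(w_i · a_i), where the checkpointed version keeps only every `every`-th activation during the forward pass and recomputes each segment from its checkpoint during backprop. The layer form, the weights and the way “activations kept” is counted are illustrative assumptions:

```python
import math


def layer(w, a):
    # One scalar "layer": a_{i+1} = tanh(w_i * a_i)
    return math.tanh(w * a)


def layer_grad(w, a):
    # d/da tanh(w * a) = w * (1 - tanh(w * a)^2)
    t = math.tanh(w * a)
    return w * (1.0 - t * t)


def grad_store_all(ws, x):
    """Plain backprop: keep every activation from the forward pass."""
    acts = [x]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    g = 1.0
    for i in reversed(range(len(ws))):
        g *= layer_grad(ws[i], acts[i])  # needs the input activation of layer i
    return g, len(acts)


def grad_checkpointed(ws, x, every):
    """Keep only every `every`-th activation; recompute the rest during backprop."""
    checkpoints = {0: x}
    a = x
    for i, w in enumerate(ws):
        a = layer(w, a)
        if (i + 1) % every == 0 and i + 1 < len(ws):
            checkpoints[i + 1] = a  # segment boundary: keep this activation
    g = 1.0
    peak = 0
    # Walk segments from the last to the first, as backprop does.
    for start in sorted(checkpoints, reverse=True):
        end = min(start + every, len(ws))
        # Recompute the inputs of layers start .. end-1 from the checkpoint.
        seg = [checkpoints[start]]
        for w in ws[start:end - 1]:
            seg.append(layer(w, seg[-1]))
        peak = max(peak, len(checkpoints) + len(seg))  # rough count (seg[0] counted twice)
        for j in reversed(range(len(seg))):
            g *= layer_grad(ws[start + j], seg[j])
    return g, peak


ws = [0.5 + 0.1 * i for i in range(16)]
g_full, kept_full = grad_store_all(ws, 0.3)
g_ckpt, kept_ckpt = grad_checkpointed(ws, 0.3, every=4)
print(f"store-all:    grad={g_full:.6e}, activations kept={kept_full}")
print(f"checkpointed: grad={g_ckpt:.6e}, peak activations kept={kept_ckpt}")
```

    With 16 layers and `every=4`, the checkpointed version keeps roughly half as many activations and gets the same gradient, at the cost of running most layers’ forward pass a second time.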

    Can you explain what you mean by “caching intermediate gradients during backprop”? I’m not familiar with what that is.

    • patricky168 (OP) · 10 months ago

      Yeah, what I mean is that even though LoRA only updates the adapters on the attention weights, we still need to compute gradients through the downstream layers that aren’t being updated, and that takes GPU memory. So the only memory saved is from the optimizer states, if I’m not mistaken.
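
      A toy illustration of this point, using two hypothetical scalar layers: a trainable one (standing in for a LoRA-adapted layer) followed by a frozen one, with the frozen layer’s output treated as the loss. The gradient for the trainable weight has to pass through the frozen layer, even though the frozen weight’s own gradient is never computed:

```python
import math


def forward(w_adapter, w_frozen, x):
    a1 = math.tanh(w_adapter * x)  # trainable layer (stand-in for an adapted layer)
    y = math.tanh(w_frozen * a1)   # frozen layer that comes after it
    return a1, y


def grad_wrt_adapter(w_adapter, w_frozen, x):
    """dy/dw_adapter by hand, treating y itself as the loss."""
    a1, y = forward(w_adapter, w_frozen, x)
    # Backprop *through* the frozen layer: needs its output y and its weight w_frozen.
    dy_da1 = w_frozen * (1.0 - y * y)
    # Local gradient of the trainable layer: needs its input x and output a1.
    da1_dw = x * (1.0 - a1 * a1)
    # Never computed: dy/dw_frozen = a1 * (1 - y^2), nor any optimizer state for w_frozen.
    return dy_da1 * da1_dw


print(grad_wrt_adapter(0.7, 1.3, 0.5))
```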