LoRA Done RITE: Robust Invariant Transformation Equilibration for LoRA Optimization

October 19, 2025 ยท View on GitHub

This is a pytorch reimplementation of the original LoRA-RITE in Jax.

Usage

Please copy lora_rite.py to your directory or install it as a module.

Then you can do the following to create a normal pytorch optimizer object.

from lora_rite import LoRARite

lora_params = [p for n, p in model.named_parameters() if "lora" in n]
optimizer = LoRARite(lora_params, lr=learning_rate, betas=(0.9,0.999), clip_unmagnified_grad=max_grad_norm)

Here we assume the lora parameters will be in an alternating order lora_a_1, lora_b_1, lora_a_2, lora_b_2, ... as in the huggingface peft LoRA implementation. In the rare case where this assumption is not satisfied, one can manually reorder it so that the assumption is met.

To correctly adopt gradient clipping, please set the clip_unmagnified_grad parameter for LoRARite and disable it elsewhere (e.g., by setting max_grad_norm=0 for the huggingface trainer).

Commonsense Reasoning Evaluation

This setting is significantly different from what is used in the paper due to the potentially high amount of effort needed to align the environments of pytorch and JAX. We adopt the recipe from the LLM-adapter paper, where the datasets are highly overlapped with our original experiments.

Gemma-2B Result

OptimizerBOOLQPIQASIQAHellaSwagWinograndeARC-EARC-COBQAAverage
LoRARite62.9174.8667.5069.3062.0478.4562.2968.8068.27
Adam62.2075.4665.3567.3855.8076.6058.7068.0066.19

Running the Experiments

cd LLM-Adapters

# evaluate existing lora_rite checkpoint
bash eval_commonsense.sh

# finetune with lora_rite
bash finetune_commonsense.sh