README.md

February 23, 2024

Fine-tuning on MMedBench Trainset

We provide the code for further fine-tuning on the MMedBench Trainset, along with all the hyperparameters used in our experiments. We used two fine-tuning methods: full model fine-tuning and PEFT fine-tuning.

Full Model Fine-tuning

Full model fine-tuning typically yields better results. However, a 7B LLM is too large to fully fine-tune on a single A100 80GB GPU, so FSDP (Fully Sharded Data Parallel) is needed to shard the model across multiple GPUs. In our experiments, we used 4 A100 80GB GPUs.
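As a rough back-of-the-envelope check of why sharding is needed, the sketch below estimates the memory taken by weights, gradients, and optimizer state alone. The figure of 16 bytes per parameter is an assumption (mixed-precision training with AdamW: 16-bit weights and gradients plus fp32 master weights and two fp32 moment buffers), not a setting confirmed by this repository, and activations are ignored entirely.

```python
def training_state_gb(n_params, bytes_per_param=16):
    """Memory (in GB) for weights, gradients, and optimizer state.

    bytes_per_param=16 is an ASSUMPTION: 2 (weights) + 2 (gradients)
    + 4 (fp32 master weights) + 4 + 4 (AdamW moments).
    """
    return n_params * bytes_per_param / 1e9


def per_gpu_state_gb(n_params, n_gpus, bytes_per_param=16):
    """Same estimate when the state is sharded evenly across n_gpus (FSDP-style)."""
    return training_state_gb(n_params, bytes_per_param) / n_gpus


total = training_state_gb(7e9)       # a 7B-parameter model
per_gpu = per_gpu_state_gb(7e9, 4)   # sharded across 4 GPUs
print(f"unsharded: {total:.0f} GB, per GPU with 4-way sharding: {per_gpu:.0f} GB")
```

Under these assumptions the unsharded state alone exceeds 80 GB, while a 4-way shard leaves headroom for activations on each GPU.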

| Method | global_batch_size | learning_rate | fsdp_transformer_layer_cls_to_wrap |
| --- | --- | --- | --- |
| BLOOMZ | 128 | 1e-6 | BloomBlock |
| InternLM | 128 | 1e-6 | InternLMDecoderLayer |
| Llama 2 | 128 | 1e-6 | LlamaDecoderLayer |
| MedAlpaca | 128 | 1e-6 | LlamaDecoderLayer |
| ChatDoctor | 128 | 1e-6 | LlamaDecoderLayer |
| PMC-LLaMA | 128 | 1e-6 | LlamaDecoderLayer |
| Mistral | 128 | 1e-6 | MistralDecoderLayer |
| InternLM 2 | 128 | 1e-6 | InternLM2DecoderLayer |
| MMedLM | 128 | 1e-6 | InternLMDecoderLayer |
| MMedLM 2 | 128 | 1e-6 | InternLM2DecoderLayer |
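The global batch size of 128 in the table above is the product of the number of GPUs, the per-device batch size, and the number of gradient accumulation steps. The sketch below works out the accumulation steps for a given setup; the per-device batch size of 4 is a hypothetical value chosen for illustration, not one taken from the provided scripts.

```python
def grad_accum_steps(global_batch_size, n_gpus, per_device_batch_size):
    """Gradient accumulation steps needed to reach the global batch size."""
    samples_per_step = n_gpus * per_device_batch_size
    if global_batch_size % samples_per_step != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by "
            f"{n_gpus} GPUs x {per_device_batch_size} samples per device"
        )
    return global_batch_size // samples_per_step


# 4 GPUs as in the experiments; per-device batch size of 4 is HYPOTHETICAL.
steps = grad_accum_steps(global_batch_size=128, n_gpus=4, per_device_batch_size=4)
print(steps)  # 8
```

If you train on a different number of GPUs, adjusting the accumulation steps this way keeps the effective batch size at 128.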

For full model fine-tuning, download the model weights first, then change the following parameters in fullmodel_finetuning.sh:

  • MODEL_NAME_OR_PATH
  • OUTPUT_DIR: Directory to save the checkpoints
  • TRANSFORMER_LAYER: fsdp_transformer_layer_cls_to_wrap, used for FSDP.
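To illustrate how TRANSFORMER_LAYER relates to the table, the sketch below looks up the layer class for a model and builds the corresponding FSDP command-line arguments in the form accepted by the Hugging Face Trainer (`--fsdp` and `--fsdp_transformer_layer_cls_to_wrap`). Whether fullmodel_finetuning.sh forwards the variable in exactly this way is an assumption, and the helper name `fsdp_args` is hypothetical.

```python
# Layer classes copied from the full model fine-tuning table.
FSDP_LAYER_CLS = {
    "BLOOMZ": "BloomBlock",
    "InternLM": "InternLMDecoderLayer",
    "Llama 2": "LlamaDecoderLayer",
    "MedAlpaca": "LlamaDecoderLayer",
    "ChatDoctor": "LlamaDecoderLayer",
    "PMC-LLaMA": "LlamaDecoderLayer",
    "Mistral": "MistralDecoderLayer",
    "InternLM 2": "InternLM2DecoderLayer",
    "MMedLM": "InternLMDecoderLayer",
    "MMedLM 2": "InternLM2DecoderLayer",
}


def fsdp_args(method):
    """Build Trainer-style FSDP arguments for one method (hypothetical helper)."""
    layer_cls = FSDP_LAYER_CLS[method]  # raises KeyError for unknown methods
    return [
        "--fsdp", "full_shard auto_wrap",
        "--fsdp_transformer_layer_cls_to_wrap", layer_cls,
    ]


print(" ".join(fsdp_args("MMedLM 2")))
```

The wrapped class must match the decoder block of the model being trained, which is why the Llama-based models share LlamaDecoderLayer.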

Then start full model fine-tuning with:

sbatch fullmodel_finetuning.sh

PEFT Fine-tuning (LoRA)

PEFT fine-tuning is generally more efficient (faster and less memory-intensive) because it freezes most of the LLM's parameters and updates only a small number of them. With LoRA, the model fits on a single A100 80GB GPU. In our experiments, we used the DDP (Distributed Data Parallel) strategy with 4 A100 80GB GPUs.
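To give a sense of how small the trainable set is, the sketch below counts the parameters LoRA adds. For a linear layer of shape d_out x d_in, LoRA trains two low-rank matrices with r * (d_in + d_out) parameters in total. The rank r = 8 is a hypothetical value, not one taken from this repository; the hidden size of 4096 and the 32 layers correspond to a Llama 2 7B style architecture.

```python
def lora_params_per_module(d_in, d_out, r):
    """Parameters LoRA adds to one linear layer: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


def lora_trainable_params(hidden_size, n_layers, n_target_modules, r):
    """Total LoRA parameters when every target module is a square projection."""
    per_module = lora_params_per_module(hidden_size, hidden_size, r)
    return per_module * n_target_modules * n_layers


# Llama 2 7B style shapes; two target modules (q_proj, v_proj); r=8 is HYPOTHETICAL.
trainable = lora_trainable_params(hidden_size=4096, n_layers=32, n_target_modules=2, r=8)
print(f"{trainable:,} trainable parameters")  # 4,194,304
print(f"{100 * trainable / 7e9:.3f}% of a 7B-parameter model")
```

Because gradients and optimizer state are only kept for these parameters, the frozen base model can stay on a single GPU.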

| Method | global_batch_size | learning_rate | target_modules |
| --- | --- | --- | --- |
| BLOOMZ | 128 | 1e-6 | ["query_key_value"] |
| InternLM | 128 | 1e-6 | ["q_proj", "v_proj"] |
| Llama 2 | 128 | 1e-6 | ["q_proj", "v_proj"] |
| MedAlpaca | 128 | 1e-6 | ["q_proj", "v_proj"] |
| ChatDoctor | 128 | 1e-6 | ["q_proj", "v_proj"] |
| PMC-LLaMA | 128 | 1e-6 | ["q_proj", "v_proj"] |
| Mistral | 128 | 1e-6 | ["q_proj", "v_proj"] |
| InternLM 2 | 128 | 1e-6 | ["wqkv"] |
| MMedLM | 128 | 1e-6 | ["q_proj", "v_proj"] |
| MMedLM 2 | 128 | 1e-6 | ["wqkv"] |

For LoRA fine-tuning, download the model weights first, then change the following parameters in LoRA_finetuning.sh:

  • MODEL_NAME_OR_PATH
  • OUTPUT_DIR: Directory to save the checkpoints
  • TARGET_MODULES: target_modules, used for LoRA.
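To make the role of target_modules concrete, the sketch below assembles the keyword arguments one might pass to `peft.LoraConfig` for a given method. Only the target module lists come from the table; the rank, alpha, and dropout values are hypothetical placeholders, and how LoRA_finetuning.sh actually builds its configuration is not shown here. The block avoids importing peft so that it runs on its own.

```python
# Target modules copied from the LoRA table.
TARGET_MODULES = {
    "BLOOMZ": ["query_key_value"],
    "InternLM": ["q_proj", "v_proj"],
    "Llama 2": ["q_proj", "v_proj"],
    "MedAlpaca": ["q_proj", "v_proj"],
    "ChatDoctor": ["q_proj", "v_proj"],
    "PMC-LLaMA": ["q_proj", "v_proj"],
    "Mistral": ["q_proj", "v_proj"],
    "InternLM 2": ["wqkv"],
    "MMedLM": ["q_proj", "v_proj"],
    "MMedLM 2": ["wqkv"],
}


def lora_kwargs(method, r=8, lora_alpha=16, lora_dropout=0.05):
    """Keyword arguments for a LoRA config (r, alpha, dropout are HYPOTHETICAL)."""
    return {
        "r": r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "target_modules": list(TARGET_MODULES[method]),
        "task_type": "CAUSAL_LM",
    }


kwargs = lora_kwargs("MMedLM 2")
print(kwargs["target_modules"])  # ['wqkv']
# With peft installed, this could be used as: LoraConfig(**kwargs)
```

BLOOMZ and the InternLM 2 based models list a single module because their attention layers use one fused query/key/value projection, whereas the other models expose separate q_proj and v_proj layers.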

Then start LoRA fine-tuning with:

sbatch LoRA_finetuning.sh