Training Configuration

December 19, 2025

This document provides a comprehensive reference for all training and evaluation parameters available in ERNIEKit. It covers:

  • Basic model configuration and training setup
  • Evaluation metrics and strategies
  • Performance optimization techniques
  • Distributed training configurations
  • Memory optimization options
  • Checkpoint saving strategies
  • Acceleration methods
  • Mixed precision training settings
  • Specialized configurations for SFT, LoRA, DPO and FP8 training

Each parameter is documented with its type, default value and detailed description to help developers properly configure their training jobs.

1. General Configuration

1.1 Basic Configuration

| Parameter | Type | Default | Description |
|---|---|---|---|
| model_name_or_path | str | Required | Model name or local path from which the model and tokenizer are loaded |
| hidden_dropout_prob | float | 0.0 | Dropout probability for hidden layers |
| attention_probs_dropout_prob | float | 0.0 | Dropout probability for attention layers |
| dropout_warmup_steps | int | 0 | Number of dropout warmup steps. The dropout probability increases linearly during warmup, and dropout is disabled afterward. Set to 0 to disable dropout. |
| weight_quantize_algo | str | Required | Model quantization algorithm. Options: weight_only_mix (expert weights as int4, other linear layers as int8), weight_only_int8 (all linear layers as int8), or fp8_linear |
| output_dir | str | Required | Directory for model files, checkpoints, tokenizer files, and evaluation results |
| logging_steps | int | Required | Logging interval in steps. Decrease for more frequent updates. |
| logging_dir | str | Required | Log directory (defaults to output_dir if unspecified) |
| do_eval | bool | False | Enable model evaluation |
| do_train | bool | False | Enable training |
| disable_tqdm | bool | False | Disable the tqdm progress bar, which shows the estimated total training time |
| continue_training | bool | True | Load pretrained weights to continue training |
| from_hf_hub | bool | False | Download the model from HuggingFace Hub |
| from_aistudio | bool | False | Download the model from Aistudio |
| from_modelscope | bool | False | Download the model from ModelScope |

1.2 Evaluation

| Parameter | Type | Default | Description |
|---|---|---|---|
| per_device_eval_batch_size | int | Required | Evaluation batch size per device (micro batch size) |
| eval_dataset_path | str | Required | Path to the evaluation dataset (see sft-eval.jsonl) |
| eval_dataset_prob | str | 1.0 | Sampling probability of the evaluation dataset |
| eval_dataset_type | str | erniekit | Evaluation dataset type |
| eval_steps | int | Required | Evaluation interval in steps |
| evaluation_strategy | str | "steps" | Evaluation strategy. "steps" runs evaluation periodically, every eval_steps steps |
| max_evaluate_steps | int | 1 | Maximum number of steps per evaluation (applies only when positive) |
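
The dataset probability parameters are typed as strings, which suggests they can hold one value per dataset. The sketch below assumes (this is not confirmed by the table) that eval_dataset_path / train_dataset_path and the matching *_dataset_prob values are comma-separated lists of equal length, and shows how such probabilities would decide which dataset each sample is drawn from. mix_datasets is a hypothetical helper, not ERNIEKit's loader.

```python
import random


def mix_datasets(dataset_path: str, dataset_prob: str, num_samples: int, seed: int = 42):
    """Hypothetical helper: choose the source dataset for each of num_samples draws.

    ASSUMPTION: dataset_path and dataset_prob are comma-separated lists of equal
    length, e.g. "a.jsonl,b.jsonl" and "0.7,0.3". Illustration only.
    """
    paths = [p.strip() for p in dataset_path.split(",")]
    probs = [float(p) for p in dataset_prob.split(",")]
    if len(paths) != len(probs):
        raise ValueError("dataset_path and dataset_prob must have the same number of entries")
    # Seeded generator so the mixture is reproducible across runs.
    rng = random.Random(seed)
    # random.choices normalizes the weights, so they do not have to sum to 1.
    return rng.choices(paths, weights=probs, k=num_samples)
```

With the default probability of 1.0 and a single path, every sample comes from that one dataset; with "0.7,0.3", roughly 70% of draws come from the first file.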

1.3 Training Performance

| Parameter | Type | Default | Description |
|---|---|---|---|
| train_dataset_path | str | Required | Path to the training dataset (see sft-train.jsonl) |
| train_dataset_prob | str | 1.0 | Sampling probability of the training dataset |
| train_dataset_type | str | erniekit | Training dataset type |
| max_steps | int | Required | Maximum number of training steps (overrides num_train_epochs if set) |
| num_train_epochs | int | Required | Number of training epochs |
| per_device_train_batch_size | int | Required | Training batch size per device (micro batch size). Global batch size = DP degree × sharding degree × micro batch size × gradient_accumulation_steps |
| gradient_accumulation_steps | int | Required | Number of gradient accumulation steps |
| weight_decay | float | 0.0 | Weight decay for the AdamW optimizer |
| seed | int | 42 | Random seed |
| max_seq_len | int | Required | Maximum sequence length in tokens. Reduce it if OOM occurs after increasing the global batch size. |
| learning_rate | float | Required | Learning rate (suggested values: SFT 3e-5, DPO 1e-6, SFT-LoRA 3e-4, DPO-LoRA 1e-5) |
| warmup_steps | int | Required | Number of learning-rate warmup steps (typically 1%-10% of max_steps) |
| lr_scheduler_type | str | linear | Learning-rate scheduler: linear, cosine, polynomial, constant, or constant_with_warmup |
| min_lr | float | 0.0 | Minimum learning rate (cosine scheduler only) |
| layerwise_lr_decay_bound | float | 1.0 | Layerwise learning-rate decay factor, in the range (0, 1]. 1.0 means no decay. |
| random_shuffle | bool | True | Shuffle the dataset |
| num_cycles | float | 0.5 | Number of cosine waves (cosine scheduler only) |
| lr_end | float | 1e-7 | Final learning rate (polynomial scheduler only) |
| power | float | 1.0 | Exponent of the polynomial decay (polynomial scheduler only) |
| adam_beta1 | float | 0.9 | AdamW beta1 |
| adam_beta2 | float | 0.999 | AdamW beta2 |
| adam_epsilon | float | 1e-8 | AdamW epsilon |
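
To show how learning_rate, warmup_steps, max_steps, lr_scheduler_type, min_lr, num_cycles, lr_end and power interact, here is a sketch of the common linear-warmup-then-decay formulas used by Hugging Face-style trainers. It is an assumption about ERNIEKit's behavior, not its code: lr_at_step is a hypothetical helper, the way min_lr raises the floor of the cosine curve is assumed, and layerwise_lr_decay_bound is ignored.

```python
import math


def lr_at_step(step, *, learning_rate, warmup_steps, max_steps,
               lr_scheduler_type="linear", min_lr=0.0, num_cycles=0.5,
               lr_end=1e-7, power=1.0):
    """Hypothetical helper: learning rate at a given optimizer step."""
    # Linear warmup from 0 to the peak rate (every scheduler except "constant").
    if lr_scheduler_type != "constant" and step < warmup_steps:
        return learning_rate * step / max(1, warmup_steps)
    if lr_scheduler_type in ("constant", "constant_with_warmup"):
        return learning_rate
    # Fraction of the post-warmup phase already completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
    if lr_scheduler_type == "linear":
        return learning_rate * (1.0 - progress)
    if lr_scheduler_type == "cosine":
        # num_cycles=0.5 traces one half cosine from the peak down to the floor.
        cosine = 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * progress))
        # ASSUMPTION: min_lr acts as the floor of the cosine curve.
        return min_lr + (learning_rate - min_lr) * cosine
    if lr_scheduler_type == "polynomial":
        return (learning_rate - lr_end) * (1.0 - progress) ** power + lr_end
    raise ValueError(f"unknown lr_scheduler_type: {lr_scheduler_type}")
```

For example, lr_at_step(5, learning_rate=3e-5, warmup_steps=10, max_steps=110) is halfway through warmup and returns 1.5e-5; with lr_scheduler_type="cosine" and min_lr=3e-6, step 110 returns 3e-6.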

1.4 Distributed Training

| Parameter | Type | Default | Description |
|---|---|---|---|
| tensor_parallel_degree | int | Required | Tensor parallelism degree |
| tensor_parallel_config | str | Required | Tensor parallelism options. Recommended: "sync_param sync_grad sync_moment" |
| tensor_parallel_output | bool | True | Enable parallel output for the last Transformer layer to save memory |
| pipeline_parallel_degree | int | Required | Pipeline parallelism degree |
| pipeline_parallel_config | str | Required | Pipeline parallelism options. Recommended: "disable_partial_send_recv enable_clear_every_step_cache enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler" |
| pp_seg_method | str | Required | Method for segmenting layers across pipeline stages |
| virtual_pp_degree | int | 1 | Virtual pipeline degree (takes effect only when pipeline_parallel_degree > 1) |
| add_tail_layers | int | 0 | Number of EmptyLayers added after the DecodeLayers to meet virtual pipeline requirements |
| sharding_parallel_degree | int | Required | Sharding parallelism degree |
| sharding_parallel_config | str | Required | Sharding options. Recommended: "enable_stage1_overlap enable_release_grads" |
| sharding | str | Required | Sharding stage: stage1 shards optimizer states, stage2 also shards gradients, stage3 also shards parameters |
| sequence_parallel | bool | True | Enable sequence parallelism |
| moe_group | str | "dummy" | MoE communication group ("mp" for training, "dummy" for inference) |
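
The parallel degrees and batch-size parameters together determine the data-parallel (DP) degree and the global batch size. The sketch below assumes that the world size factors as DP × sharding × pipeline × tensor parallel degree and that the remaining ranks after tensor, pipeline and sharding form the DP dimension; any other parallel dimensions ERNIEKit may support are ignored. parallel_layout is a hypothetical helper.

```python
def parallel_layout(world_size, tensor_parallel_degree=1, pipeline_parallel_degree=1,
                    sharding_parallel_degree=1, per_device_train_batch_size=1,
                    gradient_accumulation_steps=1):
    """Hypothetical helper: infer the DP degree and the global batch size.

    ASSUMPTION: world_size = DP x sharding x pipeline x tensor parallel degree.
    """
    non_dp_ranks = tensor_parallel_degree * pipeline_parallel_degree * sharding_parallel_degree
    if world_size % non_dp_ranks != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by TP x PP x sharding = {non_dp_ranks}")
    dp_degree = world_size // non_dp_ranks
    # Only DP and sharding ranks consume different micro batches;
    # TP and PP ranks cooperate on the same micro batch.
    global_batch_size = (dp_degree * sharding_parallel_degree
                         * per_device_train_batch_size * gradient_accumulation_steps)
    return dp_degree, global_batch_size
```

For instance, 16 GPUs with tensor_parallel_degree=4, pipeline_parallel_degree=2 and sharding_parallel_degree=2 leave a DP degree of 1; with a micro batch size of 1 and 8 accumulation steps the global batch size is 1 × 2 × 1 × 8 = 16.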

1.5 Memory Optimization

| Parameter | Type | Default | Description |
|---|---|---|---|
| release_grads | bool | False | Release gradients after each iteration to reduce peak memory |
| use_sparse_head_and_loss_fn | bool | False | Use a sparse LM head and loss function |
| use_fused_head_and_loss_fn | bool | False | Fuse the LM head and CrossEntropyLoss to save memory |
| use_attn_mask_startend_row_indices | bool | True | Represent the attention mask sparsely using start/end row indices |
| recompute_use_reentrant | bool | False | Recompute implementation (PyLayer-based if True, hook-based if False) |
| recompute | bool | False | Enable recomputation (gradient checkpointing) |
| recompute_granularity | str | "full" | Recompute granularity: "full", "full_attn", or "core_attn" |
| offload_optim | bool | False | Offload optimizer states to CPU |
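
To see why the sharding stages listed under Distributed Training matter for memory, the sketch below uses the model-state accounting from the ZeRO paper for mixed-precision Adam: 2 bytes per parameter for 16-bit weights, 2 bytes for 16-bit gradients and 12 bytes for optimizer states (float32 master weights, momentum and variance). These byte counts are assumptions, not ERNIEKit measurements; activations, buffers and fragmentation are ignored, and zero_memory_per_device is a hypothetical helper.

```python
def zero_memory_per_device(num_params, sharding_degree, stage):
    """Hypothetical helper: rough per-device bytes for model states.

    ASSUMPTION: ZeRO-paper accounting (2 + 2 + 12 bytes per parameter).
    """
    params = 2 * num_params   # 16-bit weights
    grads = 2 * num_params    # 16-bit gradients
    optim = 12 * num_params   # fp32 master weights + Adam momentum + variance
    n = sharding_degree
    if stage == "stage1":     # only optimizer states are sharded
        return params + grads + optim / n
    if stage == "stage2":     # gradients are sharded as well
        return params + (grads + optim) / n
    if stage == "stage3":     # parameters are sharded as well
        return (params + grads + optim) / n
    raise ValueError(f"unknown sharding stage: {stage}")
```

For a 7.5B-parameter model sharded over 64 devices this gives about 31.4 GB (stage1), 16.6 GB (stage2) and 1.9 GB (stage3) of model states per device, before activations. Options such as recompute and offload_optim address the activation and optimizer-state terms respectively.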

1.6 Checkpoint

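| Parameter | Type | Default | Description |
|---|---|---|---|
| save_steps | int | Required | Checkpoint save interval in steps (used when save_strategy == "steps") |
| save_strategy | str | "no" | Checkpoint save strategy |
| unified_checkpoint | bool | True | Use the unified checkpoint format |
| unified_checkpoint_config | str | "" | See the Unified Checkpoint documentation |
| disable_ckpt_quant | bool | False | See the Unified Checkpoint documentation |
| ignore_save_lr_and_optim | bool | False | Skip saving the learning-rate scheduler and optimizer states |
| ignore_load_lr_and_optim | bool | False | Skip loading the learning-rate scheduler and optimizer states |
| save_total_limit | int | None | Maximum number of checkpoints to keep |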
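
A sketch of how save_steps and save_total_limit might combine when save_strategy is "steps". It assumes checkpoints are written every save_steps steps and that, once the limit is exceeded, the oldest ones are deleted first (the usual rotation in Hugging Face-style trainers); ERNIEKit's exact rotation rules are not described here. Both helpers are hypothetical.

```python
def saved_checkpoint_steps(max_steps, save_steps):
    """Hypothetical helper: steps at which a checkpoint is written."""
    return list(range(save_steps, max_steps + 1, save_steps))


def checkpoints_to_keep(steps, save_total_limit=None):
    """Hypothetical helper: checkpoints retained after rotation.

    ASSUMPTION: the oldest checkpoints are removed first; None keeps everything.
    """
    if save_total_limit is None:
        return sorted(steps)
    if save_total_limit < 1:
        raise ValueError("save_total_limit must be a positive integer or None")
    return sorted(steps)[-save_total_limit:]
```

With max_steps=1000, save_steps=250 and save_total_limit=2, checkpoints are written at steps 250, 500, 750 and 1000, and only those from steps 750 and 1000 remain at the end.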

1.7 Acceleration

| Parameter | Type | Default | Description |
|---|---|---|---|
| use_flash_attention | bool | True | Enable FlashAttention |
| use_sparse_flash_attn | bool | True | Enable FlashMask (requires use_attn_mask_startend_row_indices) |
| fuse_rope | bool | False | Fuse the rotary position embedding computation |
| fuse_linear | bool | False | Fuse linear operations |
| greedy_intokens | bool | True | Enable greedy token-based packing. Instead of sampling sequentially, a global buffer of samples is maintained and greedily packed into sequences to maximize token utilization and minimize padding. |
| dataloader_num_workers | int | 1 | Number of dataloader subprocesses (0 disables subprocess loading) |
| distributed_dataloader | int | 0 | Use the distributed dataloader for large datasets |
| moe_multimodal_dispatch_use_allgather | str | v2-alltoall-unpad | Optimize the MoE layer with allgather + unpad |
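
The greedy_intokens description can be illustrated with a classic greedy bin-packing heuristic. The sketch below applies first-fit decreasing to a buffer of sample lengths so that each packed sequence stays within max_seq_len tokens. It is a stand-in technique chosen for illustration; ERNIEKit's actual packing algorithm and buffer size may differ, and greedy_pack is a hypothetical helper.

```python
def greedy_pack(sample_lengths, max_seq_len):
    """Hypothetical helper: pack sample indices into sequences of <= max_seq_len tokens.

    Uses first-fit decreasing: longest samples first, each placed into the first
    sequence that still has room, otherwise into a new sequence.
    """
    order = sorted(range(len(sample_lengths)), key=lambda i: sample_lengths[i], reverse=True)
    bins = []  # each entry: [remaining_capacity, [sample indices]]
    for i in order:
        n = sample_lengths[i]
        if n > max_seq_len:
            raise ValueError(f"sample {i} has {n} tokens, more than max_seq_len={max_seq_len}")
        for b in bins:
            if b[0] >= n:
                b[0] -= n
                b[1].append(i)
                break
        else:
            bins.append([max_seq_len - n, [i]])
    return [indices for _, indices in bins]
```

For a buffer with lengths [700, 300, 500, 200, 400, 900] and max_seq_len=1024, this yields 4 packed sequences, [[5], [0, 1], [2, 4], [3]], instead of 6 padded ones.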

1.8 Mixed Precision

| Parameter | Type | Default | Description |
|---|---|---|---|
| bf16 | bool | False | Enable BF16 training |
| fp16_opt_level | str | O1 | AMP level (O2 converts parameters to float16/bfloat16) |
| scale_loss | int | 2 ** 15 | Loss scaling factor for float16 training |
| amp_custom_white_list | str | Required | AMP O2 whitelist (e.g., "lookup_table flash_attn matmul") |
| amp_custom_black_list | str | Required | AMP O2 blacklist (e.g., "reduce_sum elementwise_div") |
| amp_master_grad | bool | False | Maintain float32 gradients for AMP O2 |
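
A small float16 illustration of why scale_loss exists: gradient values below float16's smallest subnormal (about 6e-8) round to zero, but multiplying the loss, and therefore every gradient, by 2 ** 15 keeps them representable; they are divided by the same factor in higher precision before the optimizer step. The sketch emulates float16 rounding with Python's struct half-precision format ("e"); it is not ERNIEKit's AMP code and uses a fixed rather than dynamic scale. to_fp16 is a hypothetical helper.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


SCALE_LOSS = 2 ** 15  # the default scale_loss value

grad = 1e-8                          # a tiny gradient, below float16's smallest subnormal
unscaled = to_fp16(grad)             # underflows to 0.0 in float16
scaled = to_fp16(grad * SCALE_LOSS)  # about 3.28e-4, representable in float16
recovered = scaled / SCALE_LOSS      # unscale in full precision before the update
```

Without scaling the gradient is lost entirely; with scaling it survives with roughly float16's relative precision (about 3 significant digits). BF16 has the same 8-bit exponent range as float32, which is why this kind of loss scaling is normally needed only for float16 training.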

2. Specialized Configurations

2.1 SFT

| Parameter | Type | Default | Description |
|---|---|---|---|
| num_samples_each_epoch | int | 6000000 | Number of samples in a virtual epoch (keeping the default is recommended) |

2.2 LoRA

| Parameter | Type | Default | Description |
|---|---|---|---|
| lora_rank | int | 8 | LoRA rank (typically 8, 16, or 32; a higher rank can improve quality but increases memory usage) |
| lora_alpha | float | -1 | LoRA scaling factor (scaling = alpha / rank, or alpha / sqrt(rank) when rslora is enabled) |
| rslora | bool | False | Enable rsLoRA scaling (recommended for rank ≥ 64) |
| lora_plus_scale | float | 1 | LoRA+ learning-rate multiplier (recommended: 4-16) |
| rslora_plus | bool | False | Enhanced LoRA; can improve performance but may cause forgetting |
| lora | bool | False | Enable LoRA training |
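
The scaling rule in the lora_alpha row can be made concrete with the standard LoRA forward pass, y = xW + scaling · (xA)B, where W is frozen and only the low-rank matrices A (in × rank) and B (rank × out) are trained. The pure-Python sketch below assumes that formulation; lora_scaling, matmul and lora_linear are hypothetical helpers, and the meaning of the default lora_alpha = -1 is not covered by the table, so the sketch expects a positive alpha.

```python
import math


def lora_scaling(lora_alpha, lora_rank, rslora=False):
    """scaling = alpha / rank, or alpha / sqrt(rank) with rsLoRA (per the table above)."""
    return lora_alpha / (math.sqrt(lora_rank) if rslora else lora_rank)


def matmul(a, b):
    """Plain nested-list matrix product, so the sketch needs no third-party packages."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_linear(x, w, lora_a, lora_b, scaling):
    """Hypothetical LoRA layer: frozen x @ W plus the scaled low-rank update x @ A @ B."""
    base = matmul(x, w)
    delta = matmul(matmul(x, lora_a), lora_b)
    return [[b + scaling * d for b, d in zip(rb, rd)] for rb, rd in zip(base, delta)]
```

For example, lora_alpha=16 gives a scaling of 2.0 at rank 8 without rsLoRA, and also 2.0 at rank 64 with rsLoRA (16 / sqrt(64)), which is why rsLoRA keeps the update magnitude from shrinking as the rank grows. The LoRA+ technique behind lora_plus_scale trains B with a larger learning rate than A; the optimizer is not modeled here.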

2.3 DPO

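| Parameter | Type | Default | Description |
|---|---|---|---|
| beta | float | 0.1 | Temperature (beta) of the DPO loss |
| simpo_gamma | float | 0.5 | Gamma (target reward margin) of the SimPO loss |
| offset_alpha | float | 0.0 | Offset for the score-based DPO loss |
| max_prompt_len | int | 2048 | Maximum prompt length (prompts are truncated beyond max_seq_len - 10) |
| loss_type | str | sigmoid | Preference loss type: sigmoid, ipo, or kto_pair |
| pref_loss_ratio | float | 1.0 | Weight of the preference loss |
| sft_loss_ratio | float | 0.0 | Weight of the SFT loss on the chosen responses |
| label_smoothing | float | 0.0 | Label smoothing for the sigmoid loss |
| reference_free | bool | False | Train without a reference model |
| ref_model_update_steps | int | -1 | Interval in steps for updating the reference model (-1 disables updates) |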
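
A per-example sketch of how beta, loss_type, label_smoothing and reference_free enter the preference loss, following the published DPO objective (with the conservative label-smoothing variant used by libraries such as Hugging Face TRL) and the IPO objective. The inputs are assumed to be summed log-probabilities of the chosen and rejected responses. dpo_loss is a hypothetical helper; ERNIEKit's own implementation, the kto_pair loss, the SimPO and offset_alpha variants, and how pref_loss_ratio and sft_loss_ratio are combined are not shown.

```python
import math


def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def dpo_loss(policy_chosen_logp, policy_rejected_logp, ref_chosen_logp, ref_rejected_logp,
             beta=0.1, loss_type="sigmoid", label_smoothing=0.0, reference_free=False):
    """Hypothetical helper: preference loss for one chosen/rejected pair."""
    pi_logratio = policy_chosen_logp - policy_rejected_logp
    # Without a reference model the reference log-ratio is treated as zero.
    ref_logratio = 0.0 if reference_free else ref_chosen_logp - ref_rejected_logp
    h = pi_logratio - ref_logratio
    if loss_type == "sigmoid":
        # Standard DPO when label_smoothing == 0.
        return (-(1.0 - label_smoothing) * log_sigmoid(beta * h)
                - label_smoothing * log_sigmoid(-beta * h))
    if loss_type == "ipo":
        # IPO regresses the log-ratio gap toward 1 / (2 * beta).
        return (h - 1.0 / (2.0 * beta)) ** 2
    raise ValueError(f"loss_type {loss_type!r} is not covered by this sketch")
```

When the policy and reference agree (h = 0), the sigmoid loss equals log 2 ≈ 0.693; it falls below that as the policy prefers the chosen response more strongly than the reference does. A smaller beta makes the loss less sensitive to the log-ratio gap.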

2.4 FP8 Training

| Parameter | Type | Default | Description |
|---|---|---|---|
| apply_hadamard | bool | True | Apply a Hadamard transform to improve FP8 precision |
| use_lowprecision_moment | bool | False | Store optimizer moments in BF16 (recommended for FP8 training) |
| tensorwise_offload_optimizer | bool | False | Offload the optimizer to reduce memory usage |
| apply_online_actscale_step | int | 200 | Number of steps for dynamic activation quantization scaling |
| optim_shard_num | int | 1 | Number of files the optimizer states are split into when saving, to avoid OOM. Takes effect only when unified_checkpoint_config includes ignore_merge_optimizer. |
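
One reason a Hadamard transform can help FP8 precision: an orthonormal Hadamard rotation spreads a single large outlier across all coordinates, shrinking the maximum absolute value that sets a per-tensor quantization scale, and it is its own inverse, so it can be undone exactly. The sketch below is a plain fast Walsh-Hadamard transform; which tensors ERNIEKit rotates, with what block size, and where the inverse is applied are not described here, so hadamard_transform is a hypothetical standalone helper.

```python
import math


def hadamard_transform(x):
    """Orthonormal fast Walsh-Hadamard transform; len(x) must be a power of two."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    y = list(x)
    h = 1
    while h < n:
        # Butterfly step: combine pairs of entries that are h apart.
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = y[j], y[j + h]
                y[j], y[j + h] = a + b, a - b
        h *= 2
    # Normalizing by 1/sqrt(n) makes the transform orthonormal and self-inverse.
    scale = 1.0 / math.sqrt(n)
    return [v * scale for v in y]
```

Applied to a vector such as [0.1, -0.2, 0.15, 8.0, 0.05, -0.1, 0.2, 0.0], the transform preserves the vector's norm but lowers its maximum absolute value well below 8.0. For FP8 E4M3, whose largest finite value is 448, a per-tensor scale of max|x| / 448 is a common choice, so a smaller max|x| leaves finer resolution for the many small values.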