GMPO
In the paper Geometric-Mean Policy Optimization, the authors propose a GRPO variant that maximizes the geometric mean of the token-level importance ratios instead of their arithmetic mean. Because the geometric mean is far less sensitive to outlier ratios, the policy update is more stable and tolerates a much wider clipping range. Clipping is applied per token, in log space, and is one-sided depending on the sign of the advantage (the standard PPO trust region). Crucially, it is applied before the geometric mean is taken.
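The outlier argument can be made concrete with a small pure-Python comparison of the two means. The ratios below are hypothetical token-level importance ratios for a single completion, not numbers from the paper. A single token whose probability grew tenfold pulls the arithmetic mean up to 3.25, while the geometric mean only reaches 10^(1/4), about 1.78.

```python
import math


def arithmetic_mean(xs):
    return sum(xs) / len(xs)


def geometric_mean(xs):
    # exp of the mean log, which is numerically safer than the n-th root of a product
    return math.exp(sum(math.log(x) for x in xs) / len(xs))


# Hypothetical ratios: three tokens are on-policy (ratio 1.0), one is an outlier (10x)
ratios = [1.0, 1.0, 1.0, 10.0]
print(arithmetic_mean(ratios))  # 3.25
print(geometric_mean(ratios))   # ≈ 1.778, i.e. 10 ** 0.25
```

Working through the mean of the logs is also how the geometric mean is usually computed in practice. Summing log-ratios avoids the overflow or underflow that a long product of per-token ratios would cause on long completions.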
GMPO is available through the GMPOTrainer class in trl.experimental.gmpo.
Usage
```python
from trl.experimental.gmpo import GMPOConfig, GMPOTrainer

training_args = GMPOConfig(
    epsilon=0.4,  # log-space clip range -> ratios clipped to (exp(-0.4), exp(0.4)); paper, Sec. 4
    beta=0.0,  # KL coefficient; 0.0 disables the KL term
)
trainer = GMPOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=...,  # your reward function(s)
    train_dataset=...,  # your prompt dataset
    args=training_args,
)
trainer.train()
```

In GMPO, clipping is applied to the per-token log-importance ratios (that is, in log space) before the geometric mean is taken. As a result, epsilon and epsilon_high are expressed in log space, and the effective ratio clipping range is (exp(-epsilon), exp(epsilon_high)). The paper recommends a markedly wider range than GRPO/DAPO, (exp(-0.4), exp(0.4)), to encourage exploration.
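To quantify "markedly wider", the sketch below converts a log-space epsilon into ratio bounds and compares them with the linear clip used by GRPO. It assumes GRPO's default epsilon=0.2, with ratios clipped to (1 - epsilon, 1 + epsilon_high). Both helper names are hypothetical and are not part of TRL.

```python
import math


def gmpo_ratio_range(epsilon, epsilon_high=None):
    """Hypothetical helper: ratio bounds implied by log-space clipping."""
    if epsilon_high is None:
        epsilon_high = epsilon  # same fallback as documented for GMPOConfig
    return math.exp(-epsilon), math.exp(epsilon_high)


def linear_ratio_range(epsilon, epsilon_high=None):
    """Hypothetical helper: ratio bounds of a linear PPO/GRPO-style clip."""
    if epsilon_high is None:
        epsilon_high = epsilon
    return 1.0 - epsilon, 1.0 + epsilon_high


lo, hi = gmpo_ratio_range(0.4)
print(f"GMPO, epsilon=0.4 (log space): ({lo:.3f}, {hi:.3f})")  # (0.670, 1.492)
lo, hi = linear_ratio_range(0.2)
print(f"GRPO, epsilon=0.2 (linear):    ({lo:.3f}, {hi:.3f})")  # (0.800, 1.200)
```

A symmetric range in log space is symmetric multiplicatively around 1, since exp(-0.4) * exp(0.4) = 1. The linear GRPO range is symmetric additively instead.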
GMPOTrainer
Trainer for Geometric-Mean Policy Optimization (GMPO).
GMPO (https://huggingface.co/papers/2507.20673) is a GRPO variant that maximizes the geometric mean of the token-level importance ratios instead of the arithmetic mean. Because the geometric mean is far less sensitive to outlier ratios, the policy update is more stable and a much wider clipping range can be used.
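The per-completion surrogate implied by this description can be written as follows. The notation is our own: ℓ_t is the log-importance ratio of token t, s is the sign of the advantage Â, and the clip bounds are GMPOConfig's epsilon and epsilon_high. This is a reading of the description on this page, not a formula copied from the paper. Placing the one-sided min inside the exponential is what makes clipping happen before the geometric mean.

```latex
\begin{aligned}
\ell_t &= \log \pi_\theta(o_t \mid q, o_{<t}) - \log \pi_{\theta_{\text{old}}}(o_t \mid q, o_{<t}), \qquad s = \operatorname{sgn}(\hat{A}) \\
\tilde{\ell}_t &= s \cdot \min\!\big(s\,\ell_t,\ s\,\operatorname{clip}(\ell_t, -\epsilon, \epsilon_{\text{high}})\big) \\
\mathcal{J}(\theta) &= \hat{A} \Big(\prod_{t=1}^{|o|} e^{\tilde{\ell}_t}\Big)^{1/|o|} = \hat{A} \exp\!\Big(\frac{1}{|o|}\sum_{t=1}^{|o|} \tilde{\ell}_t\Big)
\end{aligned}
```

When Â > 0, this only caps ℓ_t at ε_high. When Â < 0, it only floors ℓ_t at -ε. This mirrors PPO's one-sided trust region.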
The only change relative to GRPOTrainer is _compute_loss. Everything else (generation, reward computation, weight syncing, metric logging) is inherited unchanged.
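Because _compute_loss is the only overridden piece, the following is a hedged pure-Python sketch of a GMPO loss over a padded batch, following the description on this page. It computes per-token log-ratios, applies one-sided clipping in log space, takes a masked mean inside the exponential, and negates the advantage-weighted result. Several details are assumptions: the function name gmpo_batch_loss, the list-of-lists inputs, and the plain average over completions. GMPOTrainer itself operates on tensors and may aggregate differently.

```python
import math


def gmpo_batch_loss(logp_new, logp_old, mask, advantages, epsilon=0.4, epsilon_high=None):
    """Hypothetical GMPO loss for a padded batch of completions (lower is better).

    logp_new, logp_old, mask: one row per completion, all rows padded to equal length.
    mask[i][t] is 1 for a real completion token and 0 for padding.
    advantages: one scalar advantage per completion.
    """
    if epsilon_high is None:
        epsilon_high = epsilon  # mirrors the documented GMPOConfig fallback
    losses = []
    for new_row, old_row, mask_row, adv in zip(logp_new, logp_old, mask, advantages):
        sign = 1.0 if adv >= 0 else -1.0
        total, n_tokens = 0.0, 0
        for new, old, m in zip(new_row, old_row, mask_row):
            if not m:
                continue  # padding never enters the geometric mean
            log_ratio = new - old
            # Clip in log space, i.e. ratio in [exp(-epsilon), exp(epsilon_high)] ...
            clipped = max(-epsilon, min(log_ratio, epsilon_high))
            # ... but keep only the bound that matters for this advantage sign:
            # adv >= 0 -> min(log_ratio, clipped); adv < 0 -> max(log_ratio, clipped)
            total += sign * min(sign * log_ratio, sign * clipped)
            n_tokens += 1
        # Geometric mean of the clipped ratios = exp(mean of the clipped log-ratios)
        geo_mean_ratio = math.exp(total / max(n_tokens, 1))
        losses.append(-adv * geo_mean_ratio)  # negate: optimizers minimize
    return sum(losses) / len(losses)  # assumption: plain mean over completions


# Two completions; the second has one padding position whose values are ignored.
loss = gmpo_batch_loss(
    logp_new=[[-0.5, -1.0, -0.2], [-0.3, -0.9, 0.0]],
    logp_old=[[-0.5, -2.0, -0.2], [-0.3, -0.9, -9.9]],
    mask=[[1, 1, 1], [1, 1, 0]],
    advantages=[1.0, -0.5],
)
print(loss)
```

In the first completion, the middle token's log-ratio of 1.0 is capped at 0.4 because the advantage is positive. The padded position in the second completion would have a log-ratio near 9.9 but is skipped entirely.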
train
train(resume_from_checkpoint: str | bool | None = None, trial: optuna.Trial | dict[str, Any] | None = None, ignore_keys_for_eval: list[str] | None = None) → ~trainer_utils.TrainOutput
Parameters
- resume_from_checkpoint (str or bool, optional): If a str, the local path to a checkpoint saved by a previous instance of Trainer. If a bool equal to True, the last checkpoint in args.output_dir saved by a previous instance of Trainer is loaded. If present, training resumes from the model/optimizer/scheduler states loaded here.
- trial (optuna.Trial or dict[str, Any], optional): The trial run or the hyperparameter dictionary for hyperparameter search.
- ignore_keys_for_eval (list[str], optional): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during training.
Returns
~trainer_utils.TrainOutput
Object containing the global step count, training loss, and metrics.
Main training entry point.
save_model
Will save the model, so you can reload it using from_pretrained(). Will only save from the main process.
push_to_hub
push_to_hub(commit_message: str | None = 'End of training', blocking: bool = True, token: str | None = None, revision: str | None = None, **kwargs)
Parameters
- commit_message (str, optional, defaults to "End of training"): Message to commit while pushing.
- blocking (bool, optional, defaults to True): Whether the function should return only when the git push has finished.
- token (str, optional, defaults to None): Token with write permission to overwrite Trainer’s original args.
- revision (str, optional): The git revision to commit from. Defaults to the head of the “main” branch.
- kwargs (dict[str, Any], optional): Additional keyword arguments passed along to ~Trainer.create_model_card.
Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.
GMPOConfig
class trl.experimental.gmpo.GMPOConfig(output_dir: str | None = None, per_device_train_batch_size: int = 8, num_train_epochs: float = 3.0, max_steps: int = -1, learning_rate: float = 1e-06, lr_scheduler_type: transformers.trainer_utils.SchedulerType | str = 'linear', lr_scheduler_kwargs: dict | str | None = None, warmup_steps: float = 0, optim: transformers.training_args.OptimizerNames | str = 'adamw_torch_fused', optim_args: str | None = None, weight_decay: float = 0.0, adam_beta1: float = 0.9, adam_beta2: float = 0.999, adam_epsilon: float = 1e-08, optim_target_modules: None | str | list[str] = None, gradient_accumulation_steps: int = 1, average_tokens_across_devices: bool = True, max_grad_norm: float = 1.0, label_smoothing_factor: float = 0.0, bf16: bool | None = None, fp16: bool = False, bf16_full_eval: bool = False, fp16_full_eval: bool = False, tf32: bool | None = None, gradient_checkpointing: bool = True, gradient_checkpointing_kwargs: dict[str, typing.Any] | str | None = None, torch_compile: bool = False, torch_compile_backend: str | None = None, torch_compile_mode: str | None = None, use_liger_kernel: bool = False, liger_kernel_config: dict[str, bool] | None = None, use_cache: bool = False, neftune_noise_alpha: float | None = None, torch_empty_cache_steps: int | None = None, auto_find_batch_size: bool = False, logging_strategy: transformers.trainer_utils.IntervalStrategy | str = 'steps', logging_steps: float = 10, logging_first_step: bool = False, log_on_each_node: bool = True, logging_nan_inf_filter: bool = True, include_num_input_tokens_seen: str | bool = 'no', log_level: str = 'passive', log_level_replica: str = 'warning', disable_tqdm: bool | None = None, report_to: None | str | list[str] = 'none', run_name: str | None = None, project: str = 'huggingface', trackio_space_id: str | None = None, trackio_bucket_id: str | None = None, trackio_static_space_id: typing.Union[str, NoneType, typing.Literal[False]] = None, eval_strategy: transformers.trainer_utils.IntervalStrategy | str = 'no', eval_steps: float | None = None, eval_delay: float = 0, per_device_eval_batch_size: int = 8, prediction_loss_only: bool = False, eval_on_start: bool = False, eval_do_concat_batches: bool = True, eval_use_gather_object: bool = False, eval_accumulation_steps: int | None = None, include_for_metrics: list = <factory>, batch_eval_metrics: bool = False, save_only_model: bool = False, save_strategy: transformers.trainer_utils.SaveStrategy | str = 'steps', save_steps: float = 500, save_on_each_node: bool = False, save_total_limit: int | None = None, enable_jit_checkpoint: bool = False, push_to_hub: bool = False, hub_token: str | None = None, hub_private_repo: bool | None = None, hub_model_id: str | None = None, hub_strategy: transformers.trainer_utils.HubStrategy | str = 'every_save', hub_always_push: bool = False, hub_revision: str | None = None, load_best_model_at_end: bool = False, metric_for_best_model: str | None = None, greater_is_better: bool | None = None, ignore_data_skip: bool = False, restore_callback_states_from_checkpoint: bool = False, full_determinism: bool = False, seed: int = 42, data_seed: int | None = None, use_cpu: bool = False, accelerator_config: dict | str | None = None, parallelism_config: accelerate.parallelism_config.ParallelismConfig | None = None, dataloader_drop_last: bool = False, dataloader_num_workers: int = 0, dataloader_pin_memory: bool = True, dataloader_persistent_workers: bool = False, dataloader_prefetch_factor: int | None = None, remove_unused_columns: bool | None = False, label_names: list[str] | None = None, train_sampling_strategy: str = 'random', length_column_name: str = 'length', ddp_find_unused_parameters: bool | None = None, ddp_bucket_cap_mb: int | None = None, ddp_broadcast_buffers: bool | None = None, ddp_static_graph: bool | None = None, ddp_backend: str | None = None, ddp_timeout: int = 1800, fsdp: str | None = None, fsdp_config: dict[str, typing.Any] | str | None = None, deepspeed: dict | str | None = None, debug: str | list[transformers.debug_utils.DebugOption] = '', skip_memory_metrics: bool = True, do_train: bool = False, do_eval: bool = False, do_predict: bool = False, resume_from_checkpoint: str | None = None, warmup_ratio: float | None = None, logging_dir: str | None = None, local_rank: int = -1, model_init_kwargs: dict[str, typing.Any] | str | None = None, trust_remote_code: bool = False, router_aux_loss_coef: float = 0.001, disable_dropout: bool = False, cast_lm_head_to_fp32: bool = False, num_generations: int | None = 8, num_generations_eval: int | None = None, max_completion_length: int | None = 256, ds3_gather_for_generation: bool = True, shuffle_dataset: bool | None = True, pad_to_multiple_of: int | None = None, generation_batch_size: int | None = None, steps_per_generation: int | None = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 0, min_p: float | None = None, generation_kwargs: dict | None = None, chat_template_kwargs: dict | None = None, repetition_penalty: float = 1.0, cache_implementation: str | None = None, use_vllm: bool = False, vllm_mode: str = 'colocate', vllm_model_impl: str = 'vllm', vllm_enable_sleep_mode: bool = False, vllm_structured_outputs_regex: str | None = None, vllm_server_base_url: str | None = None, vllm_server_host: str = '0.0.0.0', vllm_server_port: int = 8000, vllm_server_timeout: float = 240.0, vllm_group_port: int = 51216, vllm_gpu_memory_utilization: float = 0.3, vllm_max_model_length: int | None = None, vllm_tensor_parallel_size: int = 1, beta: float = 0.0, num_iterations: int = 1, epsilon: float = 0.4, delta: float | None = None, epsilon_high: float | None = None, sapo_temperature_neg: float = 1.05, sapo_temperature_pos: float = 1.0, vespo_k_pos: float = 2.0, vespo_lambda_pos: float = 3.0, vespo_k_neg: float = 3.0, vespo_lambda_neg: float = 2.0, importance_sampling_level: str = 'token', reward_weights: list[float] | None = None, multi_objective_aggregation: str = 'sum_then_normalize', scale_rewards: str = 'group', loss_type: str = 'dapo', mask_truncated_completions: bool = False, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.6, ref_model_sync_steps: int = 512, top_entropy_quantile: float = 1.0, max_tool_calling_iterations: int | None = None, vllm_importance_sampling_correction: bool = True, vllm_importance_sampling_mode: str = 'sequence_mask', vllm_importance_sampling_clip_max: float | None = 3.0, vllm_importance_sampling_clip_min: float | None = None, off_policy_mask_threshold: float | None = None, use_bias_correction_kl: bool = False, log_completions: bool = False, num_completions_to_print: int | None = None, log_unique_prompts: bool = False, log_completions_hub_repo: str | None = None, use_transformers_continuous_batching: bool = False, transformers_continuous_batching_config: dict | None = None, use_transformers_paged: bool = False, vllm_importance_sampling_cap: float | None = None)
Parameters
- epsilon (float, optional, defaults to 0.4): Lower-bound clipping value, expressed in log space. The lower bound of the per-token importance ratio is exp(-epsilon).
- epsilon_high (float, optional): Upper-bound clipping value, expressed in log space. If None, it defaults to the value of epsilon. The upper bound of the per-token importance ratio is exp(epsilon_high).
Configuration class for the GMPOTrainer.
GMPOConfig inherits every parameter from GRPOConfig; it only changes the meaning and default of the clipping range. In GMPO, clipping is applied to the per-token log-importance ratios (that is, in log space) before the geometric mean is taken. As a result, epsilon and epsilon_high are expressed in log space, and the effective ratio clipping range is (exp(-epsilon), exp(epsilon_high)). The GMPO paper recommends a markedly wider range than GRPO/DAPO, (exp(-0.4), exp(0.4)), to encourage exploration.