GMPO

In the paper Geometric-Mean Policy Optimization, the authors propose a GRPO variant that maximizes the geometric mean of the token-level importance ratios instead of the arithmetic mean. Because the geometric mean is far less sensitive to outlier ratios, the policy update is more stable and tolerates a much wider clipping range. Clipping is applied per token and in log space, and it is one-sided according to the sign of the advantage (the standard PPO trust region). Crucially, it is applied before the geometric mean is taken.
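One way to write the per-completion objective described above is shown below. The notation is ours, reconstructed from the description in this section rather than copied from the paper or from TRL's source. It assumes a completion o_i of length |o_i| sampled for prompt q, with advantage Â_i, and uses ε and ε_high for the log-space bounds (the epsilon and epsilon_high config parameters):

```latex
\begin{aligned}
\rho_{i,t} &= \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})},
  \qquad s_i = \operatorname{sgn}(\hat{A}_i), \\
\ell_{i,t} &= s_i \min\!\Big( s_i \log \rho_{i,t},\; s_i \operatorname{clip}\big(\log \rho_{i,t}, -\epsilon, \epsilon_{\text{high}}\big) \Big), \\
\mathcal{J}_i &= \hat{A}_i \exp\!\Big( \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \ell_{i,t} \Big)
  = \hat{A}_i \Big( \prod_{t=1}^{|o_i|} e^{\ell_{i,t}} \Big)^{1/|o_i|}.
\end{aligned}
```

For Â_i > 0, ℓ_{i,t} reduces to min(log ρ_{i,t}, clip(log ρ_{i,t}, -ε, ε_high)), so only ratios above exp(ε_high) are clipped; for Â_i < 0 it reduces to the corresponding max, so only ratios below exp(-ε) are clipped. Without clipping, J_i is Â_i times the geometric mean of the ratios, whereas GRPO averages the terms ρ_{i,t}Â_i arithmetically. The training loss is the negative of J_i averaged over completions; how TRL aggregates it across a batch may differ.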

To train with GMPO, use the GMPOTrainer class from trl.experimental.gmpo.

Usage

from trl.experimental.gmpo import GMPOConfig, GMPOTrainer

training_args = GMPOConfig(
    epsilon=0.4,  # log-space clip range -> ratios clipped to (exp(-0.4), exp(0.4)); paper, Sec. 4
    beta=0.0,
)
trainer = GMPOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=...,
    train_dataset=...,
    args=training_args,
)
trainer.train()

In GMPO, clipping is applied to the per-token log-importance ratios (i.e. in log space) before the geometric mean is taken, so epsilon and epsilon_high are expressed in log space: the effective ratio clipping range is (exp(-epsilon), exp(epsilon_high)). The paper recommends a markedly wider range than GRPO/DAPO, (exp(-0.4), exp(0.4)), to encourage exploration.
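The arithmetic behind these ranges can be checked in a few lines. gmpo_ratio_bounds and linear_ratio_bounds are hypothetical helpers written for this illustration, not TRL functions; the linear one assumes the usual GRPO/PPO convention of clipping ratios to (1 - epsilon, 1 + epsilon_high) and is included only for comparison:

```python
from __future__ import annotations

import math


def gmpo_ratio_bounds(epsilon: float, epsilon_high: float | None = None) -> tuple[float, float]:
    """Ratio-space bounds implied by log-space clipping (hypothetical helper).

    Follows the rule stated for GMPOConfig: epsilon_high falls back to epsilon when None.
    """
    if epsilon_high is None:
        epsilon_high = epsilon
    return math.exp(-epsilon), math.exp(epsilon_high)


def linear_ratio_bounds(epsilon: float, epsilon_high: float | None = None) -> tuple[float, float]:
    """GRPO/PPO-style linear-space bounds (1 - epsilon, 1 + epsilon_high), for comparison only."""
    if epsilon_high is None:
        epsilon_high = epsilon
    return 1.0 - epsilon, 1.0 + epsilon_high


if __name__ == "__main__":
    low, high = gmpo_ratio_bounds(0.4)
    print(f"log-space, epsilon=0.4: ({low:.3f}, {high:.3f})")     # (0.670, 1.492)
    low, high = linear_ratio_bounds(0.2)
    print(f"linear-space, epsilon=0.2: ({low:.3f}, {high:.3f})")  # (0.800, 1.200)
```

A symmetric log-space range is asymmetric in ratio space: exp(-0.4) × exp(0.4) = 1, so the lower bound is the reciprocal of the upper one rather than the same distance below 1.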

GMPOTrainer

class trl.experimental.gmpo.GMPOTrainer

( model, reward_funcs, args = None, **kwargs )

Trainer for Geometric-Mean Policy Optimization (GMPO).

GMPO (https://huggingface.co/papers/2507.20673) is a GRPO variant that maximizes the geometric mean of the token-level importance ratios instead of the arithmetic mean. Because the geometric mean is far less sensitive to outlier ratios, the policy update is more stable and a much wider clipping range can be used.
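A quick way to see the outlier claim is to compare the two means on a handful of made-up per-token importance ratios where one token's probability has risen tenfold. The numbers are illustrative only, and clipping is ignored here:

```python
import math
from statistics import fmean, geometric_mean

# Illustrative per-token importance ratios for one completion: three tokens
# are essentially unchanged, one token's probability rose tenfold.
ratios = [1.0, 1.0, 1.0, 10.0]

arithmetic = fmean(ratios)          # token average, as in an unclipped GRPO-style objective
geometric = geometric_mean(ratios)  # geometric mean, as in an unclipped GMPO objective

# The geometric mean equals exp(mean(log ratio)), which is why clipping can be done in log space.
via_logs = math.exp(fmean([math.log(r) for r in ratios]))

print(f"arithmetic mean: {arithmetic:.3f}")  # 3.250
print(f"geometric mean:  {geometric:.3f}")   # 1.778
print(f"exp(mean(log)):  {via_logs:.3f}")    # 1.778
```

The single outlier more than triples the arithmetic mean but moves the geometric mean much less, because it enters as one quarter of a sum of logarithms.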

The only change with respect to GRPOTrainer is the _compute_loss method. Everything else (generation, reward computation, weight syncing, metric logging) is inherited unchanged.
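To make the changed loss concrete, here is a minimal pure-Python sketch of the per-completion GMPO loss as described on this page. The helper gmpo_sequence_loss and its arguments are hypothetical and not part of TRL's API. The real _compute_loss works on padded, batched tensors with masks so that autograd can differentiate it, and its aggregation across completions may differ; this sketch only evaluates the value for a single completion:

```python
from __future__ import annotations

import math
from collections.abc import Sequence


def gmpo_sequence_loss(
    logps: Sequence[float],
    old_logps: Sequence[float],
    advantage: float,
    epsilon: float = 0.4,
    epsilon_high: float | None = None,
) -> float:
    """Negative GMPO surrogate for one completion (hypothetical helper, not TRL's API).

    logps / old_logps: per-token log-probabilities of the sampled tokens under the
    current and the old (sampling) policy, with padding already removed.
    """
    if len(logps) != len(old_logps) or not logps:
        raise ValueError("expected two non-empty sequences of equal length")
    if epsilon_high is None:
        epsilon_high = epsilon
    sign = (advantage > 0) - (advantage < 0)  # sgn(A): +1, -1, or 0

    clipped_terms = []
    for new, old in zip(logps, old_logps):
        log_ratio = new - old  # log rho_t
        log_ratio_clipped = min(max(log_ratio, -epsilon), epsilon_high)
        # One-sided PPO trust region in log space: acts as min for A > 0, max for A < 0.
        clipped_terms.append(sign * min(sign * log_ratio, sign * log_ratio_clipped))

    # Geometric mean of the clipped ratios = exp of the mean clipped log-ratio.
    geometric_mean_ratio = math.exp(sum(clipped_terms) / len(clipped_terms))
    return -advantage * geometric_mean_ratio


if __name__ == "__main__":
    logps, old_logps = [0.0, -1.0], [-1.0, -1.0]  # per-token log-ratios: [1.0, 0.0]
    print(gmpo_sequence_loss(logps, old_logps, advantage=1.0))   # equals -exp(0.2): 1.0 is clipped to 0.4
    print(gmpo_sequence_loss(logps, old_logps, advantage=-1.0))  # equals exp(0.5): no clipping for A < 0
```

Because each log-ratio is clipped before averaging, a token that hits its bound contributes a constant to the exponent; in a tensor implementation that token receives no gradient, while the other tokens in the same completion still do.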

train

( resume_from_checkpoint: str | bool | None = None, trial: optuna.Trial | dict[str, Any] | None = None, ignore_keys_for_eval: list[str] | None = None ) -> ~trainer_utils.TrainOutput

Parameters

  • resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
  • trial (optuna.Trial or dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
  • ignore_keys_for_eval (list[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.

Returns

~trainer_utils.TrainOutput

Object containing the global step count, training loss, and metrics.

Main training entry point.

save_model

( output_dir: str | None = None, _internal_call: bool = False )

Will save the model, so you can reload it using from_pretrained().

Will only save from the main process.

push_to_hub

( commit_message: str | None = 'End of training', blocking: bool = True, token: str | None = None, revision: str | None = None, **kwargs )

Parameters

  • commit_message (str, optional, defaults to "End of training") — Message to commit while pushing.
  • blocking (bool, optional, defaults to True) — Whether the function should return only when the git push has finished.
  • token (str, optional, defaults to None) — Token with write permission to overwrite Trainer’s original args.
  • revision (str, optional) — The git revision to commit from. Defaults to the head of the “main” branch.
  • kwargs (dict[str, Any], optional) — Additional keyword arguments passed along to ~Trainer.create_model_card.

Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.

GMPOConfig

class trl.experimental.gmpo.GMPOConfig

( output_dir: str | None = None, per_device_train_batch_size: int = 8, num_train_epochs: float = 3.0, max_steps: int = -1, learning_rate: float = 1e-06, lr_scheduler_type: transformers.trainer_utils.SchedulerType | str = 'linear', lr_scheduler_kwargs: dict | str | None = None, warmup_steps: float = 0, optim: transformers.training_args.OptimizerNames | str = 'adamw_torch_fused', optim_args: str | None = None, weight_decay: float = 0.0, adam_beta1: float = 0.9, adam_beta2: float = 0.999, adam_epsilon: float = 1e-08, optim_target_modules: None | str | list[str] = None, gradient_accumulation_steps: int = 1, average_tokens_across_devices: bool = True, max_grad_norm: float = 1.0, label_smoothing_factor: float = 0.0, bf16: bool | None = None, fp16: bool = False, bf16_full_eval: bool = False, fp16_full_eval: bool = False, tf32: bool | None = None, gradient_checkpointing: bool = True, gradient_checkpointing_kwargs: dict[str, typing.Any] | str | None = None, torch_compile: bool = False, torch_compile_backend: str | None = None, torch_compile_mode: str | None = None, use_liger_kernel: bool = False, liger_kernel_config: dict[str, bool] | None = None, use_cache: bool = False, neftune_noise_alpha: float | None = None, torch_empty_cache_steps: int | None = None, auto_find_batch_size: bool = False, logging_strategy: transformers.trainer_utils.IntervalStrategy | str = 'steps', logging_steps: float = 10, logging_first_step: bool = False, log_on_each_node: bool = True, logging_nan_inf_filter: bool = True, include_num_input_tokens_seen: str | bool = 'no', log_level: str = 'passive', log_level_replica: str = 'warning', disable_tqdm: bool | None = None, report_to: None | str | list[str] = 'none', run_name: str | None = None, project: str = 'huggingface', trackio_space_id: str | None = None, trackio_bucket_id: str | None = None, trackio_static_space_id: typing.Union[str, NoneType, typing.Literal[False]] = None, eval_strategy: transformers.trainer_utils.IntervalStrategy | str = 'no', eval_steps: float | None = None, eval_delay: float = 0, per_device_eval_batch_size: int = 8, prediction_loss_only: bool = False, eval_on_start: bool = False, eval_do_concat_batches: bool = True, eval_use_gather_object: bool = False, eval_accumulation_steps: int | None = None, include_for_metrics: list = <factory>, batch_eval_metrics: bool = False, save_only_model: bool = False, save_strategy: transformers.trainer_utils.SaveStrategy | str = 'steps', save_steps: float = 500, save_on_each_node: bool = False, save_total_limit: int | None = None, enable_jit_checkpoint: bool = False, push_to_hub: bool = False, hub_token: str | None = None, hub_private_repo: bool | None = None, hub_model_id: str | None = None, hub_strategy: transformers.trainer_utils.HubStrategy | str = 'every_save', hub_always_push: bool = False, hub_revision: str | None = None, load_best_model_at_end: bool = False, metric_for_best_model: str | None = None, greater_is_better: bool | None = None, ignore_data_skip: bool = False, restore_callback_states_from_checkpoint: bool = False, full_determinism: bool = False, seed: int = 42, data_seed: int | None = None, use_cpu: bool = False, accelerator_config: dict | str | None = None, parallelism_config: accelerate.parallelism_config.ParallelismConfig | None = None, dataloader_drop_last: bool = False, dataloader_num_workers: int = 0, dataloader_pin_memory: bool = True, dataloader_persistent_workers: bool = False, dataloader_prefetch_factor: int | None = None, remove_unused_columns: bool | None = False, label_names: list[str] | None = None, train_sampling_strategy: str = 'random', length_column_name: str = 'length', ddp_find_unused_parameters: bool | None = None, ddp_bucket_cap_mb: int | None = None, ddp_broadcast_buffers: bool | None = None, ddp_static_graph: bool | None = None, ddp_backend: str | None = None, ddp_timeout: int = 1800, fsdp: str | None = None, fsdp_config: dict[str, typing.Any] | str | None = None, deepspeed: dict | str | None = None, debug: str | list[transformers.debug_utils.DebugOption] = '', skip_memory_metrics: bool = True, do_train: bool = False, do_eval: bool = False, do_predict: bool = False, resume_from_checkpoint: str | None = None, warmup_ratio: float | None = None, logging_dir: str | None = None, local_rank: int = -1, model_init_kwargs: dict[str, typing.Any] | str | None = None, trust_remote_code: bool = False, router_aux_loss_coef: float = 0.001, disable_dropout: bool = False, cast_lm_head_to_fp32: bool = False, num_generations: int | None = 8, num_generations_eval: int | None = None, max_completion_length: int | None = 256, ds3_gather_for_generation: bool = True, shuffle_dataset: bool | None = True, pad_to_multiple_of: int | None = None, generation_batch_size: int | None = None, steps_per_generation: int | None = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 0, min_p: float | None = None, generation_kwargs: dict | None = None, chat_template_kwargs: dict | None = None, repetition_penalty: float = 1.0, cache_implementation: str | None = None, use_vllm: bool = False, vllm_mode: str = 'colocate', vllm_model_impl: str = 'vllm', vllm_enable_sleep_mode: bool = False, vllm_structured_outputs_regex: str | None = None, vllm_server_base_url: str | None = None, vllm_server_host: str = '0.0.0.0', vllm_server_port: int = 8000, vllm_server_timeout: float = 240.0, vllm_group_port: int = 51216, vllm_gpu_memory_utilization: float = 0.3, vllm_max_model_length: int | None = None, vllm_tensor_parallel_size: int = 1, beta: float = 0.0, num_iterations: int = 1, epsilon: float = 0.4, delta: float | None = None, epsilon_high: float | None = None, sapo_temperature_neg: float = 1.05, sapo_temperature_pos: float = 1.0, vespo_k_pos: float = 2.0, vespo_lambda_pos: float = 3.0, vespo_k_neg: float = 3.0, vespo_lambda_neg: float = 2.0, importance_sampling_level: str = 'token', reward_weights: list[float] | None = None, multi_objective_aggregation: str = 'sum_then_normalize', scale_rewards: str = 'group', loss_type: str = 'dapo', mask_truncated_completions: bool = False, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.6, ref_model_sync_steps: int = 512, top_entropy_quantile: float = 1.0, max_tool_calling_iterations: int | None = None, vllm_importance_sampling_correction: bool = True, vllm_importance_sampling_mode: str = 'sequence_mask', vllm_importance_sampling_clip_max: float | None = 3.0, vllm_importance_sampling_clip_min: float | None = None, off_policy_mask_threshold: float | None = None, use_bias_correction_kl: bool = False, log_completions: bool = False, num_completions_to_print: int | None = None, log_unique_prompts: bool = False, log_completions_hub_repo: str | None = None, use_transformers_continuous_batching: bool = False, transformers_continuous_batching_config: dict | None = None, use_transformers_paged: bool = False, vllm_importance_sampling_cap: float | None = None )

Parameters

  • epsilon (float, optional, defaults to 0.4) — Lower-bound clipping value, expressed in log space. The lower bound of the per-token importance ratio is exp(-epsilon).
  • epsilon_high (float, optional) — Upper-bound clipping value, expressed in log space. If None, it defaults to the value of epsilon. The upper bound of the per-token importance ratio is exp(epsilon_high).

Configuration class for the GMPOTrainer.

GMPOConfig inherits every parameter from GRPOConfig; it only changes the meaning and default of the clipping range. In GMPO, clipping is applied to the per-token log-importance ratios (i.e. in log space) before the geometric mean is taken, so epsilon and epsilon_high are expressed in log space: the effective ratio clipping range is (exp(-epsilon), exp(epsilon_high)). The GMPO paper recommends a markedly wider range than GRPO/DAPO, (exp(-0.4), exp(0.4)), to encourage exploration.
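To use different widths for the two sides of the range, set epsilon_high explicitly. This is a minimal sketch; the value 0.6 is purely illustrative and is not taken from the paper:

```python
from trl.experimental.gmpo import GMPOConfig

# Log-space bounds: per-token ratios are clipped to (exp(-0.4), exp(0.6)),
# i.e. roughly (0.670, 1.822).
training_args = GMPOConfig(
    epsilon=0.4,       # lower ratio bound: exp(-0.4)
    epsilon_high=0.6,  # upper ratio bound: exp(0.6); if omitted, it defaults to epsilon
    beta=0.0,
)
```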