verl: Flexible and Efficient RL for LLMs

Yuxuan Tong (童雨轩)

ByteDance Seed & Tsinghua University

1 Motivation: Why is Large-Scale RL Important?

A good framework solves an important problem.

1.1 Learning to Reason with Large-Scale RL

Table 1: Learning to reason with large-scale RL significantly boosts the performance of LLMs.
Model                   Large-Scale RL?   AIME 2024   MATH 500   GPQA Diamond   CodeForces
GPT-4o (OpenAI 2024)    No                44.6        60.3       50.6           >11.0%
o1 (OpenAI 2024)        Yes               74.4        94.8       77.3           >89.0%
R1 (DeepSeek-AI 2025)   Yes               79.8        97.3       71.5           >96.3%

1.2 Learning as Agent with Large-Scale RL

OpenAI (2025):

Deep research independently discovers, reasons about, and consolidates insights from across the web.

To accomplish this, it was trained on real-world tasks requiring browser and Python tool use,

using the same reinforcement learning methods behind OpenAI o1, our first reasoning model.

Check OpenAI Deep Research’s demo video for more details.

2 Challenge: Why is Large-Scale RL Challenging?

A good framework solves a challenging problem.

2.1 RL is a Complex Dataflow

Reinforcement Learning (RL) can be modeled as a complex dataflow graph (Schaarschmidt et al. 2019; Liang et al. 2021; Sheng et al. 2025), consisting of:

  1. multiple models: actor, critic, reference, reward model, etc.
  2. multiple stages: generating, preparing experiences, training
  3. multiple workloads: generation, inference, training

2.2 LLM Workloads Are Distributed

LLM workloads often involve:

  • many GPUs
  • complex parallelism strategies

2.3 RL with LLMs is Large-Scale Distributed Dataflow

2.4 Constraints: Data Dependencies & Resource Limitations

3 Why verl for RL with LLMs?

3.1 Flexibility: “Single-Controller”

Listing 1: PPO core code in a few lines in verl.
for prompts in dataloader:
    # Stage 1: Generation
    batch = actor.generate_sequences(prompts)
    # Stage 2: Experience Preparation
    batch = reward.compute_reward(batch)
    batch = reference.compute_log_prob(batch)
    batch = critic.compute_values(batch)
    batch = compute_advantage(batch, "gae")
    # Stage 3: Training
    critic.update_critic(batch)
    actor.update_actor(batch)
  • Programming interface based on the “single-controller” paradigm
  • RL algorithm core logic in a few lines of code!
  • Diverse RL algorithms supported: PPO, GRPO, RLOO, ReMax, PRIME, DAPO, etc. (see the sketch below)
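
For instance, switching from PPO to a critic-free algorithm such as GRPO mostly changes which workers are invoked and which advantage estimator is used. The following is a hedged sketch reusing the illustrative worker handles from Listing 1, not verl's verbatim trainer code:

# Sketch: a critic-free, GRPO-style loop with the same illustrative
# worker handles (actor, reward, reference) as in Listing 1.
for prompts in dataloader:
    # Stage 1: Generation -- sample a group of responses per prompt
    batch = actor.generate_sequences(prompts)
    # Stage 2: Experience Preparation -- no critic values needed
    batch = reward.compute_reward(batch)
    batch = reference.compute_log_prob(batch)
    batch = compute_advantage(batch, "grpo")  # group-relative advantages
    # Stage 3: Training -- only the actor is updated
    actor.update_actor(batch)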

3.2 Efficiency: “Multi-Controller”

verl achieves intra-operator efficiency with the “multi-controller” paradigm and features like:

Parallelism Algorithms:

  • Data Parallelism
  • Tensor Parallelism
  • Pipeline Parallelism
  • Context / Sequence Parallelism

Efficient Kernels:

  • Flash Attention
  • Torch Compile
  • Liger Kernel

Training Backends:

  • FSDP
  • FSDP2
  • Megatron

Generation Backends:

  • vLLM
  • SGLang

3.3 Efficiency: “Hybrid Engine”

verl achieves inter-operator efficiency with the “hybrid engine” paradigm, utilizing features like:

  • offloading & reloading, which enables full utilization of GPU memory
  • resharding, which enables switching to the optimal parallelism strategy for each workload
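
The idea can be sketched as follows (the method names here are hypothetical, not verl's API): before rollout, training-only states are offloaded and the actor weights are resharded into the inference engine's layout; afterwards the process is reversed.

# Illustrative pseudocode of the hybrid-engine idea; all names are
# hypothetical, not verl's actual API.
def rollout_then_train(actor, rollout_engine, prompts):
    # Free training-only GPU memory (optimizer states, gradients) before generation.
    actor.offload_optimizer_and_grads(to="cpu")
    # Reshard actor weights from the training layout (e.g. FSDP shards)
    # into the inference engine's layout (e.g. tensor-parallel shards).
    rollout_engine.load_weights(actor.state_dict_for_inference())
    batch = rollout_engine.generate(prompts)
    # Reload training states and continue with experience preparation / training.
    actor.reload_optimizer_and_grads(to="cuda")
    return batch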

3.4 Open-Source Community

3.4.1 Extensive Participation

So far, verl has gained:

  • 8.4k+ stars
  • 1k+ forks
  • ~900 PRs
  • ~200 contributors
  • … Waiting for your participation!

3.4.2 Easy to Extend

4 Paradigm behind verl: HybridFlow (Sheng et al. 2025)

4.1 Background: Single-Controller vs. Multi-Controller

Figure 7: Single-Controller (Multi-Program-Multi-Data) vs. Multi-Controller (Single-Program-Multi-Data) (Barham et al. 2022)
  • Single-Controller (MPMD): A centralized controller manages all the workers, running different programs.
  • Multi-Controller (SPMD): Each worker has its own controller, running the same program with different data.

4.2 Trade-off: Single-Controller or Multi-Controller?

Table 2: Trade-off between single-controller and multi-controller.
Paradigm            Pro         Con
Single-Controller   Flexible    Communication Overhead
Multi-Controller    Efficient   Complex Programming

🤔 Which paradigm should we choose?

🤩 Actually, we can have “both”!

4.3 New Paradigm: Hybrid-Controller!

💡 Hybrid-Controller = Single-Controller + N x Multi-Controller

4.4 Implementation in verl

Each call in the single-controller (e.g. critic.compute_values, actor.update_actor) is an RPC to a multi-controller worker group.

Listing 2: PPO core code in single-controller.
for prompts in dataloader:
    # Stage 1: Generation
    batch = actor.generate_sequences(prompts)
    # Stage 2: Experience Preparation
    batch = reward.compute_reward(batch)
    batch = reference.compute_log_prob(batch)
    batch = critic.compute_values(batch)
    batch = compute_advantage(batch, "gae")
    # Stage 3: Training
    critic.update_critic(batch)
    actor.update_actor(batch)
Listing 3: Example distributed code in multi-controller.
class CriticWorker(3DParallelWorker):
  @register(dispatch_mode=3D_PROTO)
  def compute_values(self, batch: DataProto):
      values = self.critic.forward(batch)
      batch.update(values=values)
# ...
class ActorWorker(3DParallelWorker):
  @register(dispatch_mode=3D_PROTO)
  def update_actor(self, batch: DataProto):
      loss = self.actor(batch)
      loss.backward()

The register decorator utility manages the distributed data transfer, which also makes multi-controller programming easier.
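
As a rough illustration of the mechanism (not verl's actual implementation), a decorator can attach dispatch metadata to a method, so that the single-controller side knows how to split the input batch across the worker group and how to collect the outputs:

# Minimal sketch of a dispatch-registering decorator; names and behavior
# are illustrative only, not verl's implementation.
def register(dispatch_mode):
    def wrapper(method):
        method._dispatch_mode = dispatch_mode  # read later by the worker group
        return method
    return wrapper

class EvenSplit:
    """Toy dispatch mode: split a list-like batch evenly, concatenate outputs."""
    @staticmethod
    def split(batch, n):
        k = (len(batch) + n - 1) // n
        return [batch[i * k:(i + 1) * k] for i in range(n)]

    @staticmethod
    def collect(outputs):
        return [item for out in outputs for item in out]

class WorkerGroup:
    """Single-controller-side proxy over N multi-controller workers."""
    def __init__(self, workers):
        self.workers = workers

    def call(self, method_name, batch):
        mode = getattr(self.workers[0], method_name)._dispatch_mode
        shards = mode.split(batch, len(self.workers))                    # dispatch
        outputs = [getattr(w, method_name)(s) for w, s in zip(self.workers, shards)]
        return mode.collect(outputs)                                     # collect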

5 Approaching More Scalable Agentic RL

5.1 Async Engine for Multi-Turn Rollout

  • Synchronous Engine: returns all the outputs in the batch at the same time
  • Asynchronous Engine: returns each output as soon as it is ready
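
The difference can be sketched with asyncio (illustrative only, assuming a per-request generation coroutine `generate`): a synchronous engine gathers the whole batch before returning, while an asynchronous engine yields each completed rollout immediately, which keeps GPUs busy during multi-turn, variable-length interactions.

import asyncio

# Illustrative only: `generate(prompt)` stands for a per-request generation
# coroutine exposed by an asynchronous inference engine.

async def sync_style(prompts, generate):
    # Returns only after the slowest request in the batch has finished.
    return await asyncio.gather(*(generate(p) for p in prompts))

async def async_style(prompts, generate, on_ready):
    # Hands each output to `on_ready` (e.g. the next turn of an agent loop)
    # as soon as it completes.
    for task in asyncio.as_completed([generate(p) for p in prompts]):
        on_ready(await task)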

5.2 Basic Capability Support

  1. Multi-Modal: "images" & "videos" fields in dataset
  2. Tool: Extensible interface BaseTool
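
As a purely illustrative example of the concept (this does not follow verl's actual BaseTool interface; the class and method names are made up), a tool wraps an external capability behind a small interface that the rollout loop can invoke between turns:

# Hypothetical tool sketch; not verl's BaseTool API.
class CalculatorTool:
    name = "calculator"

    def execute(self, expression: str) -> str:
        """Evaluate a simple arithmetic expression and return the result as text."""
        try:
            return str(eval(expression, {"__builtins__": {}}, {}))
        except Exception as err:
            return f"error: {err}"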

5.3 Diverse Environments & Tools (Ongoing)

You are welcome to discuss or contribute to:

  1. Our ongoing RFC #1172
  2. Integrating protocols like MCP
  3. Integrating existing environments & tools, e.g., Atropos (Mahan, Teknium, and Jin 2025) and KORGym (Shi et al. 2025)

6 Recent Updates & Roadmap

6.1 Efficient RL with Huge MoE like DeepSeek-V3-671B (ETA: Late May’25)

verl is working on supporting efficient RL training for huge MoE models like DeepSeek-V3-671B, based on the following features:

  1. Training: MoE model classes based on Megatron GPTModel
  2. Inference: Multi-node inference
  3. Hybrid: Parameter sharding manager for Megatron-Core v0.12 + the latest versions of the inference engines

For more details, please check our tracker #708.

6.2 Other Plans

We have also received many meaningful feature requests from the community, e.g.,

  1. Partial Rollout (Kimi Team 2025)
  2. Multi-Token-Prediction (MTP) (Gloeckle et al. 2024)

For the most timely updates on important features, please keep an eye on verl’s Roadmap.

Thanks for Listening!

Welcome to join the verl community to discuss / contribute!

💻 Code Repository @ https://github.com/volcengine/verl

❓ Further Questions @ tongyuxuan361@gmail.com

💼 We Are Recruiting! @ haibin.lin@bytedance.com

References

Barham, Paul, Aakanksha Chowdhery, Jeff Dean, Sanjay Ghemawat, Steven Hand, Daniel Hurt, Michael Isard, et al. 2022. “Pathways: Asynchronous Distributed Dataflow for ML.” Proceedings of Machine Learning and Systems 4 (April): 430–49. https://proceedings.mlsys.org/paper_files/paper/2022/hash/37385144cac01dff38247ab11c119e3c-Abstract.html.
Dai, Josef, Xuehai Pan, Ruiyang Sun, Jiaming Ji, Xinbo Xu, Mickel Liu, Yizhou Wang, and Yaodong Yang. 2023. “Safe RLHF: Safe Reinforcement Learning from Human Feedback.” In. https://openreview.net/forum?id=TyFrPOKYXw.
Mahan, Dakota, Teknium, and Roger Jin. 2025. “Atropos: An Async-First Environment Rollout Controller.” https://www.github.com/NousResearch/Atropos.
DeepSeek-AI. 2025. “DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning.” https://arxiv.org/abs/2501.12948.
Gloeckle, Fabian, Badr Youbi Idrissi, Baptiste Roziere, David Lopez-Paz, and Gabriel Synnaeve. 2024. “Better & Faster Large Language Models via Multi-Token Prediction.” In Forty-First International Conference on Machine Learning. https://openreview.net/forum?id=pEWAcejiU2.
Kimi Team. 2025. “Kimi K1.5: Scaling Reinforcement Learning with LLMs.” https://arxiv.org/abs/2501.12599.
Li, Ziniu, Tian Xu, Yushun Zhang, Zhihang Lin, Yang Yu, Ruoyu Sun, and Zhi-Quan Luo. 2024. “ReMax: A Simple, Effective, and Efficient Reinforcement Learning Method for Aligning Large Language Models.” In. https://openreview.net/forum?id=Stn8hXkpe6.
Liang, Eric, Zhanghao Wu, Michael Luo, Sven Mika, Joseph E Gonzalez, and Ion Stoica. 2021. “Rllib Flow: Distributed Reinforcement Learning Is a Dataflow Problem.” Advances in Neural Information Processing Systems 34: 5506–17.
OpenAI. 2024. “Learning to Reason with LLMs.” OpenAI Blog. https://openai.com/index/learning-to-reason-with-llms/.
———. 2025. “Introducing Deep Research.” OpenAI Blog. https://openai.com/index/introducing-deep-research/.
Schaarschmidt, Michael, Sven Mika, Kai Fricke, and Eiko Yoneki. 2019. “Rlgraph: Modular Computation Graphs for Deep Reinforcement Learning.” Proceedings of Machine Learning and Systems 1: 65–80.
Schulman, John, Filip Wolski, Prafulla Dhariwal, Alec Radford, and Oleg Klimov. 2017. “Proximal Policy Optimization Algorithms.” https://arxiv.org/abs/1707.06347.
Sheng, Guangming, Chi Zhang, Zilingfeng Ye, Xibin Wu, Wang Zhang, Ru Zhang, Yanghua Peng, Haibin Lin, and Chuan Wu. 2025. “HybridFlow: A Flexible and Efficient RLHF Framework.” In Proceedings of the 20th European Conference on Computer Systems. EuroSys ’25. Rotterdam, The Netherlands: ACM.
Shi, Jiajun, Jian Yang, Jiaheng Liu, Xingyuan Bu, Jiangjie Chen, Junting Zhou, Kaijing Ma, et al. 2025. “KORGym: A Dynamic Game Platform for LLM Reasoning Evaluation.” https://arxiv.org/abs/2505.14552.

Appendix

7 Important Features of verl

7.1 Sequence Packing

  1. Removes padding tokens and packs multiple data sequences into a single row
  2. Tweaks the attention mask & position IDs to avoid cross-contamination between sequences

To enable this, use use_remove_padding.
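
The idea can be sketched as follows (an illustrative sketch, not verl's implementation): valid tokens from several sequences are concatenated into one row, and the cumulative sequence lengths (as consumed by FlashAttention's varlen kernels) plus restarted position IDs keep attention from crossing sequence boundaries.

import torch

def pack_sequences(seqs: list[torch.Tensor]):
    """Pack variable-length 1-D token tensors into a single row.

    Returns the packed tokens, position IDs that restart at 0 for each
    sequence, and cumulative sequence lengths (cu_seqlens) that varlen
    attention kernels use to block cross-sequence attention.
    Illustrative sketch only, not verl's implementation.
    """
    packed = torch.cat(seqs)
    position_ids = torch.cat([torch.arange(len(s)) for s in seqs])
    lens = torch.tensor([0] + [len(s) for s in seqs])
    cu_seqlens = torch.cumsum(lens, dim=0, dtype=torch.int32)
    return packed, position_ids, cu_seqlens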

7.2 DP Balancing

7.2.1 Load Imbalance in DP

  • Parallelism usually needs synchronization between different ranks.
  • Data Parallelism (DP) like ZeRO is the most commonly used parallelism strategy.
  • However, DP performance can be degraded by load imbalance, which is especially severe in long-context training.

7.2.2 Balancing across DP Ranks

  • balance the valid tokens dispatched to each rank
  • by reordering the samples in each batch

To enable this, use balance_batch.
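
A hedged sketch of the reordering idea (not verl's exact algorithm): sort samples by length, deal them out to the ranks in a serpentine order so that each rank's contiguous chunk ends up with a roughly equal number of valid tokens, then flatten the per-rank groups back into a reordered batch.

def balance_across_ranks(sample_lengths: list[int], world_size: int) -> list[int]:
    """Return a permutation of sample indices such that, when the reordered
    batch is split into `world_size` equal contiguous chunks, each chunk
    carries roughly the same number of valid tokens.
    Serpentine (card-dealing) assignment; a sketch, not verl's implementation."""
    assert len(sample_lengths) % world_size == 0
    order = sorted(range(len(sample_lengths)), key=lambda i: -sample_lengths[i])
    groups = [[] for _ in range(world_size)]
    for start in range(0, len(order), world_size):
        chunk = order[start:start + world_size]
        if (start // world_size) % 2 == 1:
            chunk = chunk[::-1]  # reverse every other round to even out totals
        for rank, idx in enumerate(chunk):
            groups[rank].append(idx)
    return [idx for group in groups for idx in group]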

7.2.3 Balancing across Micro Batches

However, in gradient accumulation,

  • it’s not enough to only balance valid tokens in a batch,
  • since DP syncs in the unit of micro batch.

To resolve this, verl further supports

  • balancing the valid tokens across micro batches
  • by evenly dividing the data sequences in the batch before packing them into micro batches

To enable this, use use_dynamic_bsz.
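
A simplified sketch of the idea (greedy token-budget grouping; not verl's exact algorithm): sequences are placed into micro batches so that the total number of valid tokens per micro batch stays under a budget, which keeps the work per micro batch, and hence per DP sync, roughly even.

def group_into_micro_batches(sample_lengths: list[int], max_tokens: int) -> list[list[int]]:
    """Greedily split samples into micro batches whose total valid-token
    counts stay under `max_tokens`. Illustrative sketch only."""
    micro_batches, current, current_tokens = [], [], 0
    for idx, length in enumerate(sample_lengths):
        if current and current_tokens + length > max_tokens:
            micro_batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += length
    if current:
        micro_batches.append(current)
    return micro_batches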

7.3 Other Features

  1. Full support for RL on AMD (ROCm kernel) hardware
  2. Optimizations: Gradient Checkpointing, Torch Compile, Liger Kernel, etc.

8 Programming Guide

8.1 Customizing the Dataset

A canonical RL dataset in verl has the following fields:

  • prompt: a list of messages {"role": "...", "content": "..."}
  • data_source: used to choose the reward function
  • reward_model: a dict containing
    • "ground_truth"
    • "style" like "model" or "rule"
  • (Optional) extra_info: a dict containing extra information

For VLM RL, verl expects the fields "images" and/or "videos".

For examples, please check the scripts under examples/data_preprocess.
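
For concreteness, a single record might look roughly like the following (the field values are made up for illustration):

# A hypothetical example record following the schema above.
example_row = {
    "prompt": [
        {"role": "user", "content": "What is 3 + 5? Put the answer in \\boxed{}."},
    ],
    "data_source": "openai/gsm8k",   # selects the reward function
    "reward_model": {
        "style": "rule",             # rule-based verification
        "ground_truth": "8",
    },
    "extra_info": {"split": "train", "index": 0},
}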

You could also customize the field names via config. Please check the data section in config files like ppo_trainer.yaml for more details.

For further customization, verl provides the data.custom_cls config:

Listing 4: Config for custom dataset class.
data:
  custom_cls:
    path: null # path to the `.py` file containing the `class` definition
    name: null # the `class` name

An example CLI config could be:

Listing 5: Example config for custom dataset class.
--data.custom_cls.path=./examples/dataset/custom_dataset.py \
--data.custom_cls.name=CustomDataset

The custom dataset class defined in the .py file is required to accept the following initialization parameters:

Listing 6: Custom dataset class initialization.
from typing import List, Optional, Union

from omegaconf import DictConfig
from transformers import PreTrainedTokenizer, ProcessorMixin

class CustomDataset: # You could also inherit from `RLHFDataset`
  def __init__(
      self,
      data_files: Union[str, List[str]],
      tokenizer: PreTrainedTokenizer,
      config: DictConfig,
      processor: Optional[ProcessorMixin] = None,
  ):
      ...

8.2 Customizing the Reward

verl allows defining a custom reward function via the custom_reward_function config:

Listing 7: Config for custom reward function.
custom_reward_function:
  path: null # path to the `.py` file containing the function definition
  name: compute_score # the function name after `def`
reward_model:
  reward_manager: naive

An example CLI config could be:

Listing 8: Example config for custom reward function.
--custom_reward_function.path=./examples/reward_fn/custom_reward_fn.py \
--custom_reward_function.name=compute_score \
--reward_model.reward_manager=naive

The function defined in the .py file should accept the parameters passed by the reward manager's __call__ method. Taking NaiveRewardManager as an example:

Listing 9: How a reward function is called in NaiveRewardManager.
class NaiveRewardManager:
    def __call__(self, data: DataProto, return_dict: bool=False):
        # Preprocessing for the input data
        score = self.compute_score(
            data_source=data_source,
            solution_str=solution_str,
            ground_truth=ground_truth,
            extra_info=extra_info,
        )
        # Other processing for the final `reward`
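
A minimal sketch of such a function, matching the call above (the scoring rule itself is made up for illustration):

# Hypothetical reward function compatible with the call shown above.
def compute_score(data_source: str, solution_str: str, ground_truth: str,
                  extra_info=None) -> float:
    """Return 1.0 if the ground-truth answer appears in the model's solution,
    otherwise 0.0. A purely illustrative scoring rule."""
    return 1.0 if ground_truth.strip() in solution_str else 0.0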

For more complex features, you can also add a new reward manager like PRIMERewardManager or DAPORewardManager.

8.3 Customizing the Loss Function

To modify the loss function, the most convenient way is to

  1. search for the .backward() call
  2. modify functions like compute_policy_loss
  3. or add loss terms like entropy_loss

For example, the DataParallelPPOActor.update_policy method defines the loss function as follows:

Listing 10: Simplified loss function definition in DataParallelPPOActor.
class DataParallelPPOActor(BasePPOActor):
    def update_policy(self, data: DataProto):
        pg_loss = compute_policy_loss(
            old_log_prob=old_log_prob, log_prob=log_prob,
            advantages=advantages, # ...
        )
        entropy_loss = agg_loss(loss_mat=entropy)
        policy_loss = pg_loss - entropy_loss * entropy_coeff
        kld = kl_penalty(
            logprob=log_prob, ref_logprob=ref_log_prob, # ...
        )
        kl_loss = agg_loss(loss_mat=kld)
        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
        policy_loss.backward()
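
For instance, a hypothetical auxiliary term could be computed in a small helper and added to policy_loss before the backward call (the helper and its coefficient are illustrative, not an existing verl option):

import torch

def add_length_penalty(policy_loss: torch.Tensor,
                       response_mask: torch.Tensor,
                       coeff: float = 0.01) -> torch.Tensor:
    """Hypothetical auxiliary term: add the mean response length (in valid
    tokens) to the policy loss to discourage overly long responses.
    Illustrative only; `coeff` is not an existing verl config field."""
    mean_length = response_mask.sum(dim=-1).float().mean()
    return policy_loss + coeff * mean_length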

8.4 Customizing the Training Logic

As mentioned above, the main training logic is concentrated in the fit function of the trainer classes like RayPPOTrainer.

For example, the RayDAPOTrainer class overrides the fit function to implement the “dynamic sampling” feature:

(See the listing below ➡️)

Listing 11: Simplified fit function in DAPORayTrainer, with dynamic sampling highlighted.
class RayDAPOTrainer(RayPPOTrainer):
  def fit(self):
    for epoch in range(self.config.trainer.total_epochs):
      batch = None
      num_prompt_in_batch, num_gen_batches = 0, 0
      for batch_dict in self.train_dataloader:
        new_batch = DataProto.from_single_dict(batch_dict)
        num_gen_batches += 1
        gen_batch = new_batch.pop(...)  # keep only the prompt fields for generation
        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
        new_batch = new_batch.union(gen_batch_output)
        if not self.config.algorithm.filter_groups.enable:
          batch = new_batch
        else:
          # Dynamic sampling: getting `kept_traj_idxs` (trajectories of prompt
          # groups whose rewards are not all identical) ...
          new_batch = new_batch[kept_traj_idxs]
          batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
          # ... updating `num_prompt_in_batch` ...
          prompt_bsz = self.config.data.train_batch_size
          if num_prompt_in_batch < prompt_bsz:
            # Not enough kept prompts yet: generate more (up to a configured limit).
            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
                continue
          else:
            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
            batch = batch[:traj_bsz]
        # ...

9 About

9.1 Presenter Contact