2025/06/30
A good framework solves an important problem.
| Model | Large-Scale RL? | AIME 2024 | MATH 500 | GPQA Diamond | Codeforces (percentile) |
|---|---|---|---|---|---|
| GPT-4o (OpenAI 2024) | ❌ | 44.6 | 60.3 | 50.6 | > 11.0% |
| o1 (OpenAI 2024) | ✅ | 74.4 | 94.8 | 77.3 | > 89.0% |
| R1 (DeepSeek-AI 2025) | ✅ | 79.8 | 97.3 | 71.5 | > 96.3% |
OpenAI (2025): "Deep research independently discovers, reasons about, and consolidates insights from across the web. To accomplish this, it was trained on real-world tasks requiring browser and Python tool use, using the same reinforcement learning methods behind OpenAI o1, our first reasoning model."

Check OpenAI Deep Research’s demo video for more details.
A good framework solves a challenging problem.
Reinforcement Learning (RL) can be modelled as a complex dataflow graph (Schaarschmidt et al. 2019; Liang et al. 2021; Sheng et al. 2025), consisting of multiple models (actor, critic, reference, reward) and multiple stages (generation, experience preparation, training).
An LLM RL workload typically looks like the loop below:
```python
for prompts in dataloader:
    # Stage 1: Generation
    batch = actor.generate_sequences(prompts)
    # Stage 2: Experience Preparation
    batch = reward.compute_reward(batch)
    batch = reference.compute_log_prob(batch)
    batch = critic.compute_values(batch)
    batch = compute_advantage(batch, "gae")
    # Stage 3: Training
    critic.update_critic(batch)
    actor.update_actor(batch)
```
verl is efficient for intra-operator parallelism with the “multi-controller” paradigm and features like:
- Parallelism Algorithms: e.g., tensor, pipeline, and sequence parallelism
- Efficient Kernels: e.g., FlashAttention
- Training Backends: e.g., FSDP and Megatron-LM
- Generation Backends: e.g., vLLM and SGLang
verl is efficient for inter-operator orchestration with the “hybrid engine” paradigm, utilizing features like colocating the actor’s training and rollout engines on the same GPUs and resharding weights between them.
So far, verl has gained:
Many popular projects are built on top of verl, including:
| Paradigm | Pros | Cons |
|---|---|---|
| Single-Controller | Flexible | Communication overhead |
| Multi-Controller | Efficient | Complex programming |
🤔 Which paradigm should we choose?
🤩 Actually, we can have “both”!
💡 Hybrid-Controller = Single-Controller + N x Multi-Controller
Each call in the single-controller (e.g., `critic.compute_values`, `actor.update_actor`) is an RPC to a multi-controller worker group.
```python
for prompts in dataloader:
    # Stage 1: Generation
    batch = actor.generate_sequences(prompts)
    # Stage 2: Experience Preparation
    batch = reward.compute_reward(batch)
    batch = reference.compute_log_prob(batch)
    batch = critic.compute_values(batch)
    batch = compute_advantage(batch, "gae")
    # Stage 3: Training
    critic.update_critic(batch)
    actor.update_actor(batch)
```
```python
class CriticWorker(3DParallelWorker):
    @register(dispatch_mode=3D_PROTO)
    def compute_values(self, batch: DataProto):
        values = self.critic.forward(batch)
        batch.update(values=values)
        # ...


class ActorWorker(3DParallelWorker):
    @register(dispatch_mode=3D_PROTO)
    def update_actor(self, batch: DataProto):
        loss = self.actor(batch)
        loss.backward()
```
The `register` decorator utility manages the distributed data transfer, which also makes multi-controller programming easier.
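To make the dispatch/collect idea concrete, here is a minimal, self-contained toy sketch of the pattern; it is not verl's implementation, and `toy_register`, `ToyWorker`, and `ToyWorkerGroup` are invented names. A decorated worker method declares how its inputs should be split and its outputs merged, and the worker-group proxy that the single controller sees performs the fan-out and fan-in:

```python
# Toy illustration only -- not verl's API. The decorator attaches dispatch/collect
# metadata to a worker method; the worker-group proxy consumes that metadata.
from typing import Callable, List


def toy_register(fn: Callable) -> Callable:
    """Tag a worker method with dispatch/collect behavior (metadata only)."""
    fn._dispatch = "split_batch"   # split the input batch across workers
    fn._collect = "concat"         # concatenate per-worker outputs
    return fn


class ToyWorker:
    @toy_register
    def compute_values(self, shard: List[int]) -> List[int]:
        # stand-in for a model forward pass on this worker's data shard
        return [x * 2 for x in shard]


class ToyWorkerGroup:
    """Single-controller-facing proxy: one call fans out to every worker."""

    def __init__(self, workers: List[ToyWorker]):
        self.workers = workers

    def __getattr__(self, name: str):
        method = getattr(type(self.workers[0]), name)

        def call(batch: List[int]) -> List[int]:
            n = len(self.workers)
            if getattr(method, "_dispatch", None) == "split_batch":
                shards = [batch[i::n] for i in range(n)]          # dispatch
            else:
                shards = [batch] * n                              # broadcast
            outputs = [getattr(w, name)(s) for w, s in zip(self.workers, shards)]
            return [v for out in outputs for v in out]            # collect (concat)

        return call


if __name__ == "__main__":
    group = ToyWorkerGroup([ToyWorker(), ToyWorker()])
    print(group.compute_values([1, 2, 3, 4]))  # one call, sharded execution -> [2, 6, 4, 8]
```

In verl, the `register` metadata is consumed by the worker-group abstraction in a similar spirit, but over real distributed workers rather than in-process lists.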
You are welcome to discuss and contribute to:
verl supports efficient RL training for huge MoE models such as DeepSeek-V3-671B, based on features including:
- Megatron-core’s `GPTModel`
For the most timely updates of important features, please keep an eye on verl’s Roadmap.
You are welcome to join the verl community to discuss and contribute!
💻 Code Repository @ https://github.com/volcengine/verl
❓ Further Questions @ tongyuxuan361@bytedance.com
💼 We Are Recruiting! @ haibin.lin@bytedance.com
To enable sequence packing (removing padding tokens), set `use_remove_padding`.
To balance the number of valid tokens across data-parallel ranks, set `balance_batch`.
However, with gradient accumulation, micro-batches containing a fixed number of samples can still hold very different numbers of tokens. To resolve this, verl further supports forming micro-batches by a token budget instead of a fixed sample count. To enable this, set `use_dynamic_bsz`.
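As a rough illustration of what a token-budget (“dynamic”) batch size means, the hypothetical helper below (not verl code; `split_by_token_budget` is an invented name) packs variable-length sequences into micro-batches by total token count rather than by a fixed number of samples:

```python
# Illustrative sketch of dynamic batch sizes: group samples into micro-batches
# whose total token count stays under a budget. Not verl's implementation.
from typing import List


def split_by_token_budget(seq_lens: List[int], max_tokens_per_micro_batch: int) -> List[List[int]]:
    """Group sample indices into micro-batches whose token totals stay under budget."""
    micro_batches: List[List[int]] = []
    current: List[int] = []
    current_tokens = 0
    for idx, length in enumerate(seq_lens):
        if current and current_tokens + length > max_tokens_per_micro_batch:
            micro_batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += length
    if current:
        micro_batches.append(current)
    return micro_batches


if __name__ == "__main__":
    # Mixed-length sequences end up in micro-batches of roughly balanced token counts.
    print(split_by_token_budget([512, 4096, 1024, 2048, 256], max_tokens_per_micro_batch=4096))
```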
A canonical RL dataset in verl has the following fields:
- `prompt`: a list of messages `{"role": "...", "content": "..."}`
- `data_source`: used to choose the reward function
- `reward_model`: a dict containing `"ground_truth"` and `"style"` (like `"model"` or `"rule"`)
- `extra_info`: a dict containing extra information

For VLM RL, verl expects the fields `"images"` and/or `"videos"`.
For examples, please check the `examples/data_preprocess` directory.
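For concreteness, here is one illustrative row following the fields above (the values, including the `data_source` string, are made up for this example):

```python
# One illustrative dataset row; field values here are examples only.
example_row = {
    "prompt": [
        {"role": "user", "content": "What is 3 + 5?"},
    ],
    "data_source": "openai/gsm8k",   # selects which reward function to apply
    "reward_model": {
        "style": "rule",             # e.g. "rule" or "model"
        "ground_truth": "8",
    },
    "extra_info": {"split": "train", "index": 0},
}
```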
You could also customize the field names via config. Please check the `data` section in config files like `ppo_trainer.yaml` for more details.
For further customization, verl provides the `data.custom_cls` config, which points to a user-defined dataset class. The custom dataset class defined in the `.py` file must accept a specific set of initialization parameters (see verl’s documentation for the exact signature).
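As a purely hypothetical skeleton (the parameter names `data_files`, `tokenizer`, `config`, and `processor` are assumptions for illustration, not verl's documented signature), such a class might look like:

```python
# Hypothetical custom dataset skeleton. The __init__ parameters below are an
# assumption for illustration; check verl's docs/source for the authoritative list.
from torch.utils.data import Dataset


class MyCustomDataset(Dataset):
    def __init__(self, data_files, tokenizer, config, processor=None):
        self.data_files = data_files
        self.tokenizer = tokenizer
        self.config = config
        self.processor = processor
        self.rows = []  # load and preprocess data_files here

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        # return a dict with the canonical fields described earlier
        return self.rows[idx]
```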
verl also allows defining a custom reward function via the `custom_reward_function` config.
The function defined in the `.py` file should accept the parameters passed from the reward manager’s `__call__` method (e.g., `NaiveRewardManager`).
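Since the manager’s `__call__` snippet is not reproduced here, the sketch below is a hedged example of such a function; the parameter names (`data_source`, `solution_str`, `ground_truth`, `extra_info`) are assumptions mirroring the keyword arguments a simple rule-based reward manager typically passes:

```python
# my_reward.py -- a hedged sketch of a custom reward function. The parameter
# names are assumptions; match them to the reward manager you actually use.
def compute_score(data_source: str, solution_str: str, ground_truth: str, extra_info=None) -> float:
    """Return 1.0 if the model's final answer matches the ground truth, else 0.0."""
    answer = solution_str.strip().split()[-1] if solution_str.strip() else ""
    return float(answer == str(ground_truth).strip())
```

The `custom_reward_function` config then points at this file and the function’s name (check the config file for the exact keys).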
For more complex features, you can also add a new reward manager like `PRIMERewardManager` or `DAPORewardManager`.
To modify the loss function, the most convenient way is to edit the loss terms (e.g., `compute_policy_loss`, `entropy_loss`) computed before the `.backward()` call.
For example, the `DataParallelPPOActor.update_policy` method defines the loss function as follows:
```python
class DataParallelPPOActor(BasePPOActor):
    def update_policy(self, data: DataProto):
        # ... (log-probs, advantages, entropy, etc. are computed above)
        pg_loss = compute_policy_loss(
            old_log_prob=old_log_prob, log_prob=log_prob,
            advantages=advantages,  # ...
        )
        entropy_loss = agg_loss(loss_mat=entropy)
        policy_loss = pg_loss - entropy_loss * entropy_coeff

        kld = kl_penalty(
            logprob=log_prob, ref_logprob=ref_log_prob,  # ...
        )
        kl_loss = agg_loss(loss_mat=kld)
        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef

        loss = policy_loss  # (gradient-accumulation scaling omitted)
        loss.backward()
```
As mentioned above, the main training logic is concentrated in the `fit` function of trainer classes like `RayPPOTrainer`.
For example, the `RayDAPOTrainer` class overrides the `fit` function to implement the “dynamic sampling” feature:
(See the next slide for the code ➡️)
```python
class RayDAPOTrainer(RayPPOTrainer):
    def fit(self):
        for epoch in range(self.config.trainer.total_epochs):
            batch = None
            # (`num_prompt_in_batch` / `num_gen_batches` counters and `gen_batch`
            #  preparation are elided here)
            for batch_dict in self.train_dataloader:
                new_batch = DataProto.from_single_dict(batch_dict)
                num_gen_batches += 1
                gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                new_batch = new_batch.union(gen_batch_output)
                if not self.config.algorithm.filter_groups.enable:
                    batch = new_batch
                else:
                    # Getting `kept_traj_idxs` ...
                    new_batch = new_batch[kept_traj_idxs]
                    batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
                    prompt_bsz = self.config.data.train_batch_size
                    if num_prompt_in_batch < prompt_bsz:
                        max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
                        if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
                            # not enough kept prompts yet: keep generating
                            continue
                    else:
                        traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
                        batch = batch[:traj_bsz]
                # ...
```
verl: Flexible and Efficient RL for LLMs