A good framework solves an important problem.
Model | Large-Scale RL? | AIME 2024 | MATH 500 | GPQA Diamond | CodeForces (Percentile) |
---|---|---|---|---|---|
GPT-4o (OpenAI 2024) | ❌ | 44.6 | 60.3 | 50.6 | 11.0% |
o1 (OpenAI 2024) | ✅ | 74.4 | 94.8 | 77.3 | 89.0% |
R1 (DeepSeek-AI 2025) | ✅ | 79.8 | 97.3 | 71.5 | 96.3% |
OpenAI (2025):
Deep research independently discovers, reasons about, and consolidates insights from across the web.
To accomplish this, it was trained on real-world tasks requiring browser and Python tool use,
using the same reinforcement learning methods behind OpenAI o1, our first reasoning model.
Check OpenAI Deep Research’s demo video for more details.
A good framework solves a challenging problem.
Reinforcement Learning (RL) can be modelled as a complex dataflow graph (Schaarschmidt et al. 2019; Liang et al. 2021; Sheng et al. 2025), consisting of multiple models (actor, critic, reference, reward) and the data dependencies among their computation stages.
LLM RL workloads often involve a training loop like the following:
```python
for prompts in dataloader:
    # Stage 1: Generation
    batch = actor.generate_sequences(prompts)
    # Stage 2: Experience Preparation
    batch = reward.compute_reward(batch)
    batch = reference.compute_log_prob(batch)
    batch = critic.compute_values(batch)
    batch = compute_advantage(batch, "gae")
    # Stage 3: Training
    critic.update_critic(batch)
    actor.update_actor(batch)
```
verl is efficient for intra-operator parallelism with the “multi-controller” paradigm and features like:

- Parallelism algorithms
- Efficient kernels
- Training backends (e.g. FSDP and Megatron-LM)
- Generation backends (e.g. vLLM and SGLang)
verl is efficient for inter-operator execution with the “hybrid engine” paradigm, utilizing features like:
So far, verl has gained:
Paradigm | Pro | Con |
---|---|---|
Single-Controller | Flexible | Communication Overhead |
Multi-Controller | Efficient | Complex Programming |
🤔 Which paradigm should we choose?
🤩 Actually, we can have “both”!
💡 Hybrid-Controller = Single-Controller + N x Multi-Controller
Each call in the single-controller (e.g. `critic.compute_values`, `actor.update_actor`) is an RPC to a multi-controller worker group.
```python
for prompts in dataloader:
    # Stage 1: Generation
    batch = actor.generate_sequences(prompts)
    # Stage 2: Experience Preparation
    batch = reward.compute_reward(batch)
    batch = reference.compute_log_prob(batch)
    batch = critic.compute_values(batch)
    batch = compute_advantage(batch, "gae")
    # Stage 3: Training
    critic.update_critic(batch)
    actor.update_actor(batch)
```
```python
class CriticWorker(3DParallelWorker):
    @register(dispatch_mode=3D_PROTO)
    def compute_values(self, batch: DataProto):
        values = self.critic.forward(batch)
        batch.update(values=values)
    # ...

class ActorWorker(3DParallelWorker):
    @register(dispatch_mode=3D_PROTO)
    def update_actor(self, batch: DataProto):
        loss = self.actor(batch)
        loss.backward()
```
The `register` decorator utility manages the distributed data transfer, which also makes multi-controller programming easier.
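To make the data-transfer idea concrete, here is a rough, hypothetical sketch of the scatter/RPC/gather pattern such a decorator coordinates. This is not verl's actual implementation; all names below are made up for illustration.

```python
# Hypothetical sketch of the scatter/RPC/gather pattern (NOT verl's actual code).
def dispatch_to_worker_group(workers, method_name, batch):
    n = len(workers)
    shard_size = (len(batch) + n - 1) // n
    # Scatter: split the driver-side batch into contiguous shards, one per worker.
    shards = [batch[i * shard_size:(i + 1) * shard_size] for i in range(n)]
    # "RPC": call the registered method on each worker with its shard.
    # In verl this goes through Ray remote calls instead of local calls.
    results = [getattr(w, method_name)(shard) for w, shard in zip(workers, shards)]
    # Gather: concatenate the per-worker outputs back into one batch.
    return [item for result in results for item in result]

class ToyCriticWorker:
    def compute_values(self, shard):
        # Stand-in for a critic forward pass over the shard.
        return [len(sample) for sample in shard]

workers = [ToyCriticWorker() for _ in range(2)]
batch = ["a", "bb", "ccc", "dddd"]
print(dispatch_to_worker_group(workers, "compute_values", batch))  # [1, 2, 3, 4]
```

In verl, the single controller issues these calls through Ray, and the dispatch mode registered on each method decides how `DataProto` batches are split and collected.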
"images"
& "videos"
fields in datasetBaseTool
You are welcome to discuss and contribute to these features!
verl is working on supporting efficient RL training for huge MoE models like DeepSeek-V3-671B, based on the following features:
- `GPTModel`
For more details, please check our tracker #708.
We have also received many meaningful feature requests from the community, e.g.,
For the most timely updates of important features, please keep an eye on verl’s Roadmap.
Welcome to join the verl community to discuss / contribute!
💻 Code Repository @ https://github.com/volcengine/verl
❓ Further Questions @ tongyuxuan361@gmail.com
💼 We Are Recruiting! @ haibin.lin@bytedance.com
To enable padding removal (a.k.a. sequence packing), use `use_remove_padding`.
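As a conceptual illustration of what padding removal means (a sketch, not verl's exact implementation): valid tokens are packed into one flat tensor together with cumulative sequence boundaries, the layout used by variable-length attention kernels.

```python
import torch

# Conceptual sketch of padding removal / sequence packing (NOT verl's exact code).
def remove_padding(input_ids, attention_mask):
    seq_lens = attention_mask.sum(dim=1)          # valid tokens per sequence
    packed = input_ids[attention_mask.bool()]     # [total_valid_tokens]
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    return packed, cu_seqlens

input_ids = torch.tensor([[5, 6, 0, 0], [7, 8, 9, 0]])
attention_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
packed, cu_seqlens = remove_padding(input_ids, attention_mask)
print(packed.tolist(), cu_seqlens.tolist())  # [5, 6, 7, 8, 9] [0, 2, 5]
```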
To enable balancing the number of valid tokens across data-parallel ranks, use `balance_batch`.
However, in gradient accumulation, micro-batches with a fixed number of sequences can still contain very different numbers of valid tokens. To resolve this, verl further supports forming micro-batches by a total-token budget instead of a fixed sequence count. To enable this, use `use_dynamic_bsz`.
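As a conceptual sketch of the idea behind `use_dynamic_bsz` (not verl's actual implementation), micro-batches can be formed by a per-GPU token budget rather than a fixed sequence count:

```python
# Conceptual sketch of token-budget micro-batching (NOT verl's actual implementation).
# Sequences are packed greedily until the per-GPU token budget would be exceeded.
def make_micro_batches(seq_lengths, max_tokens_per_micro_batch):
    micro_batches, current, current_tokens = [], [], 0
    for idx, length in enumerate(seq_lengths):
        if current and current_tokens + length > max_tokens_per_micro_batch:
            micro_batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += length
    if current:
        micro_batches.append(current)
    return micro_batches

# Toy usage: 6 sequences of varying length, 4096-token budget per micro-batch.
print(make_micro_batches([512, 3800, 1024, 2048, 900, 4000], 4096))
# -> [[0], [1], [2, 3, 4], [5]]  (index groups; each group stays under the budget)
```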
A canonical RL dataset in verl has the following fields:
- `prompt`: a list of messages `{"role": "...", "content": "..."}`
- `data_source`: used to choose the reward function
- `reward_model`: a dict containing `"ground_truth"` and `"style"` (e.g. `"model"` or `"rule"`)
- `extra_info`: a dict containing extra information

For VLM RL, verl expects the fields `"images"` and/or `"videos"`.
For examples, please check `examples/data_preprocess`.

You can also customize the field names via config. Please check the `data` section in config files like `ppo_trainer.yaml` for more details.
For further customization, verl provides the `data.custom_cls` config, which points to a `.py` file and the dataset class defined in it. The custom dataset class is required to accept the initialization parameters that the trainer passes in.
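As a hedged sketch of such a class (the constructor arguments `data_files`, `tokenizer`, `config`, and `processor` are assumptions about what the trainer passes; check your verl version):

```python
import pandas as pd
from torch.utils.data import Dataset

# Sketch of a custom dataset class for `data.custom_cls`.
# The constructor signature below is an assumption, not verl's guaranteed API.
class MyRLDataset(Dataset):
    def __init__(self, data_files, tokenizer, config, processor=None):
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config
        frames = [pd.read_parquet(path) for path in data_files]
        self.samples = pd.concat(frames).to_dict("records")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Return a dict with the canonical fields described above
        # (prompt, data_source, reward_model, extra_info, ...).
        return self.samples[idx]
```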
verl also allows defining a custom reward function via the `custom_reward_function` config. The function defined in the `.py` file should accept the parameters passed from the reward manager's `__call__` method; take `NaiveRewardManager` as an example.
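As a hedged sketch (the argument names `data_source`, `solution_str`, `ground_truth`, and `extra_info` are assumptions about what the reward manager passes; verify against `NaiveRewardManager.__call__` in your verl version):

```python
# custom_reward.py -- sketch of a custom reward function (argument names assumed).
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Return a scalar reward for one sampled response."""
    if data_source == "openai/gsm8k":
        # Rule-based check: reward 1.0 if the ground-truth answer appears in the output.
        return 1.0 if ground_truth in solution_str else 0.0
    # Fall back to a neutral reward for unknown data sources.
    return 0.0
```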
For more complex features, you can also add a new reward manager like `PRIMERewardManager` or `DAPORewardManager`.
To modify the loss function, the most convenient way is to adjust the loss terms (e.g. `compute_policy_loss`, `entropy_loss`) computed before the `.backward()` call. For example, the `DataParallelPPOActor.update_policy` method defines the loss function as follows:
```python
class DataParallelPPOActor(BasePPOActor):
    def update_policy(self, data: DataProto):
        pg_loss = compute_policy_loss(
            old_log_prob=old_log_prob, log_prob=log_prob,
            advantages=advantages,  # ...
        )
        entropy_loss = agg_loss(loss_mat=entropy)
        policy_loss = pg_loss - entropy_loss * entropy_coeff
        kld = kl_penalty(
            logprob=log_prob, ref_logprob=ref_log_prob,  # ...
        )
        kl_loss = agg_loss(loss_mat=kld)
        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
        policy_loss.backward()
```
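For instance, here is a hedged sketch of adding one extra term right before the `.backward()` call; the length-penalty term and its coefficient are made up for illustration and are not verl options.

```python
import torch

# Hypothetical sketch: add a response-length penalty to the policy loss before
# calling .backward(). `response_mask` marks valid response tokens; the penalty
# and its coefficient are illustrative only.
def add_length_penalty(policy_loss, response_mask, coef=0.01):
    response_lengths = response_mask.sum(dim=-1).float()
    penalty = (response_lengths / response_mask.shape[-1]).mean()
    return policy_loss + coef * penalty

# Toy usage with a scalar loss and a [batch, seq_len] mask.
policy_loss = torch.tensor(0.5, requires_grad=True)
response_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
total_loss = add_length_penalty(policy_loss, response_mask)
total_loss.backward()
print(policy_loss.grad)  # tensor(1.)
```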
As mentioned above, the main training logic is concentrated in the `fit` function of the trainer classes like `RayPPOTrainer`.
For example, the `RayDAPOTrainer` class overrides the `fit` function to implement the “dynamic sampling” feature, which keeps generating rollout batches and filtering out uninformative prompt groups until the training batch is filled:
(See the next slide for the code ➡️)
```python
class RayDAPOTrainer(RayPPOTrainer):
    def fit(self):
        for epoch in range(self.config.trainer.total_epochs):
            batch = None
            num_prompt_in_batch = 0   # prompts kept after filtering so far
            num_gen_batches = 0       # generation batches consumed so far
            for batch_dict in self.train_dataloader:
                new_batch = DataProto.from_single_dict(batch_dict)
                num_gen_batches += 1
                # `gen_batch` is the prompt-only view popped from `new_batch` (elided)
                gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                new_batch = new_batch.union(gen_batch_output)
                if not self.config.algorithm.filter_groups.enable:
                    batch = new_batch
                else:
                    # Getting `kept_traj_idxs` and updating `num_prompt_in_batch` ...
                    new_batch = new_batch[kept_traj_idxs]
                    batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
                    prompt_bsz = self.config.data.train_batch_size
                    if num_prompt_in_batch < prompt_bsz:
                        # Not enough kept prompts yet: keep generating (up to a cap).
                        max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
                        if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
                            continue
                    else:
                        # Enough prompts: truncate to the exact trajectory batch size.
                        traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
                        batch = batch[:traj_bsz]
                # ...
```
verl: Flexible and Efficient RL for LLMs