verl 框架:3. 加载数据与创建 batch

run_ppo()函数完成 Ray 集群初始化之后,创建一个TaskRunner将整个 RL 训练流程封装在一个独立的 Ray Actor 中,提交到远程执行,以支持分布式调度。

TaskRunner.run运行在远程 Ray Actor 中,流程如下(重点关注 PPO 训练器RayPPOTrainer中的 dataset 来源):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def run(self, config):
"""Build the train/val datasets and the PPO trainer, then start training.

Abridged excerpt: the ``......`` line marks code elided by the article.
Names used below but never bound in this excerpt (``tokenizer``,
``processor``, ``resource_pool_manager``, ``ray_worker_group_cls``,
``reward_fn``, ``val_reward_fn``, ``train_sampler``) are presumably created
in the elided code -- TODO confirm.
"""
# 0. Register the worker roles and initialize the resource pool manager.
......

from verl.utils.dataset.rl_dataset import collate_fn
# Create training and validation datasets.
train_dataset = create_rl_dataset(
config.data.train_files,
config.data,
tokenizer,
processor,
is_train=True,
max_samples=config.data.get("train_max_samples", -1),
)
val_dataset = create_rl_dataset(
config.data.val_files,
config.data,
tokenizer,
processor,
is_train=False,
max_samples=config.data.get("val_max_samples", -1),
)

# Initialize the PPO trainer.
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=self.role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
train_dataset=train_dataset,
val_dataset=val_dataset,
collate_fn=collate_fn,
# NOTE(review): train_sampler is not defined anywhere in this excerpt -- confirm where it is built.
train_sampler=train_sampler,
)

# Initialize the workers of the trainer.
trainer.init_workers()

# Start the training process.
trainer.fit()

创建 RLHFDataset

RayPPOTrainer传入的train_dataset, val_dataset通过调用create_rl_dataset生成:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True, max_samples=-1):
    """Create a dataset.

    The dataset class is chosen in this order:
      1. a custom class, when ``data_config.custom_cls.path`` is set;
      2. ``DynamicGenDataset``, when ``data_config.datagen.path`` is set and
         ``is_train`` is true;
      3. ``RLHFDataset`` otherwise.

    Arguments:
        data_paths: List of paths to data files.
        data_config: The data config.
        tokenizer (Tokenizer): The tokenizer.
        processor (Processor): The processor.
        is_train (bool): Only consulted for the ``datagen`` branch above.
        max_samples (int): Sample cap requested by the caller. It is forwarded
            to the dataset class as ``max_samples`` only when non-negative, so
            the default of -1 keeps working with dataset classes whose
            constructor has no such parameter.
            NOTE(review): -1 presumably means "no limit" -- confirm.

    Returns:
        dataset (Dataset): The dataset.

    Raises:
        TypeError: If the custom class does not inherit from
            ``torch.utils.data.Dataset``.
    """
    from torch.utils.data import Dataset

    from verl.utils.dataset.rl_dataset import RLHFDataset

    # 1. A custom dataset class takes precedence when one is configured.
    if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None:
        # Dynamically load the custom dataset class
        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
        # Verify that the custom dataset class inherits from torch.utils.data.Dataset
        if not issubclass(dataset_cls, Dataset):
            raise TypeError(
                f"The custom dataset class '{data_config.custom_cls.name}' from "
                f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset"
            )
    # 2. A data generation strategy is only honoured for the training split.
    elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train:
        # If a data generation strategy is specified, use the DynamicGenDataset class
        from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset

        dataset_cls = DynamicGenDataset
    else:
        # 3. Fall back to the default RLHFDataset.
        dataset_cls = RLHFDataset

    dataset_kwargs = {
        "data_files": data_paths,
        "tokenizer": tokenizer,
        "processor": processor,
        "config": data_config,
    }
    # Forward the cap only when the caller actually asked for one.
    if max_samples is not None and max_samples >= 0:
        dataset_kwargs["max_samples"] = max_samples

    # Instantiate the dataset using the determined dataset class
    return dataset_cls(**dataset_kwargs)

创建了一个RLHFDataset实例,具体实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class RLHFDataset(Dataset):
"""
Load and preprocess RLHF data from Parquet files.

- Caches files locally.
- Reads into a HuggingFace Dataset and tokenizes prompts.
- Optionally handles images/videos via a ProcessorMixin.
- Filters prompts over a max length.
- Supports resuming from checkpoints.

Args:
data_files (str or list): Path(s) to Parquet file(s).
tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs.
config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc.
processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos.

NOTE(review): this listing is abridged; ``......`` and ``...`` mark elided code,
so attributes such as ``self.data_files`` and ``self.filter_overlong_prompts``
are presumably set in the elided part of ``__init__`` -- confirm.
"""
# Cache setup: _download calls copy_to_local to copy remote files into the local cache.
def __init__(
self,
data_files: str | list[str],
tokenizer: PreTrainedTokenizer,
config: DictConfig,
processor: Optional[ProcessorMixin] = None,
):
# Cache directory; "~" is expanded, default is ~/.cache/verl/rlhf.
self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))

......
self._download() # download remote files into the local cache
self._read_files_and_tokenize() # read the Parquet files and preprocess them

def _read_files_and_tokenize(self):
dataframes = []
for parquet_file in self.data_files:
# Load the Parquet file with Hugging Face datasets and take its "train" split.
dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
dataframes.append(dataframe)
# Concatenate all loaded datasets.
self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)
# Drop prompts that are too long (only when enabled, see below).
self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe)

def maybe_filter_out_long_prompts(self, dataframe): # compute each sample's token length and filter
if self.filter_overlong_prompts:
# Length function used by the filter.
# NOTE(review): processor, tokenizer, image_key and prompt_key are bare names that are
# not bound in this excerpt -- presumably locals taken from self.* in elided code; confirm.
def doc2len(doc) -> int:
if processor is not None:
# Multimodal path: length of the processor's input_ids for text plus images.
messages = self._build_messages(doc)
raw_prompt = self.processor.apply_chat_template(...)
images = [process_image(image) for image in doc[image_key]]
return len(processor(text=[raw_prompt], images=images)["input_ids"][0])
else:
# Text-only path: length of the chat-templated prompt.
return len(tokenizer.apply_chat_template(doc[prompt_key], ...))

# Filter with num_proc worker processes.
dataframe = dataframe.filter(
lambda doc: doc2len(doc) <= self.max_prompt_length,
num_proc=self.num_workers,
desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
)

return dataframe

包含以下功能:

  1. 支持从远程存储下载 Parquet 文件到本地缓存,支持共享内存加速文件访问;
  2. 支持多进程并行过滤过长的 prompts(通过doc2len可配置过滤策略);
  3. 支持纯文本、图像和视频的多模态输入,解析 <image><video> 标签,将多模态内容转换为结构化格式;
  4. 添加 chat template 格式化对话,将文本转换为 token IDs,生成 attn mask 和 position ids;
  5. padding 到指定长度,支持多种截断策略(left, right, middle, error),生成位置编码。

第3, 4, 5条是何时实现的呢?遍历单个样本调用RLHFDataset.__getitem__() 时:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def __getitem__(self, item):
"""Return the row at index ``item`` with model-input fields added.

Abridged excerpt: the multimodal branch is elided (``......``). In the
text-only branch the prompt is rendered with the chat template, tokenized,
passed through ``verl_F.postprocess_data`` with ``max_length=self.max_prompt_length``
and ``left_pad=True``, and the first row of ``input_ids``, ``attention_mask``
and ``position_ids`` is stored in the returned dict.
"""
row_dict: dict = self.dataframe[item]
messages = self._build_messages(row_dict) # build the chat-message list
model_inputs = {}

if self.processor is not None:
......
else: # text-only path
# 1. Validate and apply the chat template.
if self.apply_chat_template_kwargs.get("chat_template") is None:
# No template passed via apply_chat_template_kwargs, so the tokenizer must provide one.
assert hasattr(self.tokenizer, "chat_template"), (
"chat_template should be provided in apply_chat_template_kwargs or tokenizer config, "
"models like GLM can copy chat_template.jinja from instruct models"
)
raw_prompt = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs
)
# 2. Tokenization: turn the text into token IDs and an attention mask.
model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
""" 输出:
model_inputs = {
"input_ids": tensor([[1, 2, 3, 4, 5, ...]]), # shape: [1, seq_len]
"attention_mask": tensor([[1, 1, 1, 1, 1, ...]]) # shape: [1, seq_len]
}
"""
# 3. Pull out input_ids and attention_mask.
input_ids = model_inputs.pop("input_ids")
attention_mask = model_inputs.pop("attention_mask")
# 4. Post-process: pad on the left / truncate to max_prompt_length.
input_ids, attention_mask = verl_F.postprocess_data(
input_ids=input_ids,
attention_mask=attention_mask,
max_length=self.max_prompt_length,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=True,
truncation=self.truncation,
)
# 5. Compute position ids from the attention mask.
position_ids = compute_position_id_with_mask(attention_mask)
# 6. Fill row_dict; [0] drops the leading batch dimension of size 1.
row_dict["input_ids"] = input_ids[0]
row_dict["attention_mask"] = attention_mask[0]
row_dict["position_ids"] = position_ids[0]
# ... other fields

return row_dict

对比 row_dict 在 RLHFDataset.__getitem__() 处理前后的差异:

  1. 输入:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    row_dict = {
    "data_source": "hiyouga/geometry3k",
    "prompt": [{"role": "user", "content": "问题内容"}],
    "images": [图像数据],
    "ability": "math",
    "reward_model": {"style": "rule", "ground_truth": "答案"},
    "extra_info": {
    "split": "train",
    "index": 0,
    "answer": "答案",
    "question": "问题",
    }
    }
  2. 输出:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    row_dict = {
    # 原始字段
    "data_source": "hiyouga/geometry3k",
    "ability": "math",
    "reward_model": {"style": "rule", "ground_truth": "答案"},
    "extra_info": {...},

    # 新增的 tensor 字段
    "input_ids": torch.Tensor, # shape: [seq_len]
    "attention_mask": torch.Tensor, # shape: [seq_len]
    "position_ids": torch.Tensor, # shape: 纯文本模型[seq_len]; 多模态模型[4, seq_len] (Qwen2VL/GLM4V)
    "raw_prompt_ids": List[int], # 原始 prompt 的 token IDs

    # 多模态字段(如果有)
    "multi_modal_data": {
    "image": [处理后的图像数据],
    "video": [处理后的视频数据]
    },
    "multi_modal_inputs": {...}, # 多模态输入

    # 其他字段
    "index": int, # 样本索引
    "tools_kwargs": dict, # 工具参数
    "interaction_kwargs": dict, # 交互参数
    }

输出包含 shape 均为 [seq_len] 的 input_ids、attention_mask、position_ids 三个 Tensor 字段(针对纯文本;多模态时 position_ids 的格式不同),以及类型为 List[int] 的 raw_prompt_ids。

pad 和 truncate 在针对 input ids 和 attn mask 的后处理计算postprocess_data函数中完成,这里引入的是 verl.utils.torch_functional 中相关函数(通过 import ... as verl_F 引入),用以处理 tokenizer 的输出,并 pad/truncate 到恒定长度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def postprocess_data(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    max_length: int,
    pad_token_id: int,
    left_pad=True,
    truncation="error",
):
    """Process tokenizer outputs to consistent shapes via padding/truncation.

    Args:
        input_ids: Token indices [batch_size, seq_len]
        attention_mask: Mask [batch_size, seq_len]
        max_length: Target sequence length
        pad_token_id: Padding token ID
        left_pad: Pad left if True
        truncation: "left", "right", "middle" or "error"

    Returns:
        (input_ids, attention_mask) padded/truncated to max_length

    Raises:
        AssertionError: If ``truncation`` is not one of the four modes or
            ``input_ids`` is not 2-D.
        NotImplementedError: If the sequence is longer than ``max_length`` and
            ``truncation == "error"``.
    """
    assert truncation in ["left", "right", "middle", "error"]
    assert input_ids.ndim == 2

    sequence_length = input_ids.shape[-1]
    if sequence_length < max_length:
        # Too short: pad both tensors; the mask is padded with 0.
        input_ids = pad_sequence_to_length(input_ids, max_length, pad_token_id, left_pad)
        attention_mask = pad_sequence_to_length(attention_mask, max_length, 0, left_pad)

    elif sequence_length > max_length:
        # Too long: truncate. Tail slices use an explicit start index because a
        # negative-offset slice such as x[:, -k:] returns the WHOLE sequence when
        # k == 0, which would leave the output longer than max_length.
        if truncation == "left":
            tail_start = sequence_length - max_length
            input_ids = input_ids[:, tail_start:]
            attention_mask = attention_mask[:, tail_start:]
        elif truncation == "right":
            input_ids = input_ids[:, :max_length]
            attention_mask = attention_mask[:, :max_length]
        elif truncation == "middle":
            # Keep the head and the tail, drop the middle.
            left_half = max_length // 2
            right_half = max_length - left_half
            tail_start = sequence_length - right_half
            input_ids = torch.cat([input_ids[:, :left_half], input_ids[:, tail_start:]], dim=-1)
            attention_mask = torch.cat([attention_mask[:, :left_half], attention_mask[:, tail_start:]], dim=-1)
        elif truncation == "error":
            raise NotImplementedError(f"Sequence length {sequence_length} > {max_length}")

    return input_ids, attention_mask

# 1. 填充处理:
def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):
    """Pad the last dimension of ``tensors`` with ``pad_token_id`` up to ``max_seq_len``.

    Args:
        tensors: Tensor whose last dimension is the sequence dimension.
        max_seq_len: Target size of the last dimension.
        pad_token_id: Constant fill value.
        left_pad: Pad on the left if True, otherwise on the right.

    Returns:
        The padded tensor. An input whose last dimension is already
        ``>= max_seq_len`` is returned unchanged.
    """
    pad_len = max_seq_len - tensors.shape[-1]
    if pad_len <= 0:
        # F.pad treats a negative width as cropping; a padding helper must
        # never shorten its input.
        return tensors

    pad_tuple = (pad_len, 0) if left_pad else (0, pad_len)
    return F.pad(tensors, pad_tuple, "constant", pad_token_id)

注意 RLHFDataset.__getitem__() 中位置编码的计算:提供了标准文本模型、Qwen2VL、GLM4V 三类模型的位置编码(后两者在 verl/models/transformers 中提供)。

  1. 标准文本模型:通过 compute_position_id_with_mask 函数基于 attn mask 计算位置编码:有效 token 从 0 开始分配连续的位置 ID;左侧填充位置(cumsum - 1 为 -1)被裁剪为 0,右侧填充位置则沿用最后一个有效 token 的位置 ID

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def compute_position_id_with_mask(mask):
    """Derive position ids from an attention mask along the last dimension.

    Each position gets (number of mask entries set so far) - 1, floored at 0.
    Leading padding therefore maps to 0, and trailing padding repeats the
    position of the last valid token.
    """
    running_count = mask.cumsum(dim=-1)
    return (running_count - 1).clamp(min=0)
# Single-sample walkthrough of compute_position_id_with_mask.
# Input mask: the first 4 entries are valid tokens, the last 3 are padding.
mask = torch.tensor([1, 1, 1, 1, 0, 0, 0])
# Step 1: cumulative sum
cumsum_result = torch.cumsum(mask, dim=-1)
print("累积求和:", cumsum_result)
# Output: tensor([1, 2, 3, 4, 4, 4, 4])
# Step 2: subtract 1
minus_one = cumsum_result - 1
print("减1后:", minus_one)
# Output: tensor([0, 1, 2, 3, 3, 3, 3])
# Step 3: clip at 0 (no value here is negative, so nothing changes)
position_ids = torch.clip(minus_one, min=0, max=None)
print("最终位置编码:", position_ids)
# Output: tensor([0, 1, 2, 3, 3, 3, 3])
# Trailing padding keeps the last valid position (3); it is NOT reset to 0.
# The clip only matters for leading padding: mask [0, 0, 0, 1, 1, 1, 1] gives
# cumsum - 1 = [-1, -1, -1, 0, 1, 2, 3], which is clipped to [0, 0, 0, 0, 1, 2, 3].

  2. Qwen2VL 使用 3D RoPE 位置编码实现多模态数据的统一处理。官方实现:https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1405

官方支持批量导入,同时返回位置编码和增量;输出为:

  • position_ids: (3, batch_size, sequence_length) - 3D 位置编码
  • mrope_position_deltas: (batch_size, 1) - 位置编码增量:计算位置编码的最大值与实际序列长度的差值,用于后续的位置编码调整。

打包为 batch

RayPPOTrainer传入的 collate_fn 通过from verl.utils.dataset.rl_dataset import collate_fn引入,该函数的作用是将若干数据聚集成一个 batch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# verl/utils/dataset/rl_dataset.py
def collate_fn(data_list: list[dict]) -> dict:
    """
    Merge per-sample dicts into a single batch dict.

    Args:
        data_list: One dict per sample, mapping feature names to a
            torch.Tensor or to any other Python value.

    Returns:
        A dict in which every tensor feature is stacked along a new leading
        dimension (shape ``(batch_size, *dims)``) and every other feature is an
        ``np.ndarray`` of dtype object with shape ``(batch_size,)``. Tensor
        features come first in key order.
    """
    tensor_columns: dict = {}
    object_columns: dict = {}

    # Route each value into the tensor or the non-tensor column of its key.
    for sample in data_list:
        for name, value in sample.items():
            column = tensor_columns if isinstance(value, torch.Tensor) else object_columns
            column.setdefault(name, []).append(value)

    # Tensors: [seq_len, ...] per sample -> [batch_size, seq_len, ...].
    batch = {name: torch.stack(items, dim=0) for name, items in tensor_columns.items()}

    # Everything else: a 1-D object array with one entry per sample.
    for name, items in object_columns.items():
        batch[name] = np.fromiter(items, dtype=object, count=len(items))

    return batch

对比collate_fn打包前后的输入和输出:

  1. 输入:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    # 单个样本示例
    sample = {
    # === 张量字段 (torch.Tensor) ===
    "input_ids": torch.tensor([1, 2, 3, 0, 0]), # shape: (seq_len,)
    "attention_mask": torch.tensor([1, 1, 1, 0, 0]), # shape: (seq_len,)
    "position_ids": torch.tensor([0, 1, 2, 0, 0]), # shape: (seq_len,) 或 (4, seq_len) 多模态
    "raw_prompt_ids": [1, 2, 3], # List[int]

    # === 多模态字段 (可选) ===
    "multi_modal_data": { # dict
    "image": [processed_image_tensor], # List[torch.Tensor]
    "video": [processed_video_tensor] # List[torch.Tensor]
    },
    "multi_modal_inputs": { # dict (可选)
    "image_grid_thw": torch.tensor([[1, 2, 2]]), # shape: (1, 3)
    "video_grid_thw": torch.tensor([[2, 2, 2]]), # shape: (1, 3)
    "second_per_grid_ts": torch.tensor([1.0]) # shape: (1,)
    },

    # === 非张量字段 (各种类型) ===
    "data_source": "hiyouga/geometry3k", # str
    "ability": "math", # str
    "reward_model": { # dict
    "style": "rule",
    "ground_truth": "42"
    },
    "extra_info": { # dict
    "split": "train",
    "index": 0,
    "answer": "42",
    "question": "What is 6*7?"
    },
    "index": 0, # int
    "tools_kwargs": {}, # dict
    "interaction_kwargs": {}, # dict

    # === 可选字段 ===
    "raw_prompt": [{"role": "user", "content": "..."}], # List[dict] (可选)
    "full_prompts": "<|im_start|>user\n...", # str (可选)
    }
  2. 输出:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    batch_dict = {
    # === 张量字段 ===
    "input_ids": torch.tensor([
    [1, 2, 3, 0, 0], # 样本1
    [4, 5, 6, 7, 0] # 样本2
    ]), # shape: (batch_size, seq_len)

    "attention_mask": torch.tensor([
    [1, 1, 1, 0, 0], # 样本1
    [1, 1, 1, 1, 0] # 样本2
    ]), # shape: (batch_size, seq_len)

    "position_ids": torch.tensor([
    [0, 1, 2, 0, 0], # 样本1
    [0, 1, 2, 3, 0] # 样本2
    ]), # shape: (batch_size, seq_len)

    # === 非张量字段 ===
    "raw_prompt_ids": np.array([
    [1, 2, 3], # 样本1
    [4, 5, 6, 7] # 样本2
    ], dtype=object), # shape: (batch_size,)

    "data_source": np.array([
    "geo3k", # 样本1
    "geo3k" # 样本2
    ], dtype=object), # shape: (batch_size,)

    "ability": np.array([
    "math", # 样本1
    "math" # 样本2
    ], dtype=object), # shape: (batch_size,)

    "reward_model": np.array([
    {"style": "rule", "ground_truth": "42"}, # 样本1
    {"style": "rule", "ground_truth": "56"} # 样本2
    ], dtype=object), # shape: (batch_size,)

    "extra_info": np.array([
    {"split": "train", "index": 0}, # 样本1
    {"split": "train", "index": 1} # 样本2
    ], dtype=object), # shape: (batch_size,)

    "index": np.array([0, 1], dtype=object), # shape: (batch_size,)
    "tools_kwargs": np.array([{}, {}], dtype=object), # shape: (batch_size,)
    "interaction_kwargs": np.array([{}, {}], dtype=object) # shape: (batch_size,)
    }

RayPPOTrainer.fit()

回到RayPPOTrainer.fit().

初始化:

  1. 创建 Tracking 日志记录器:使用 Tracking 类初始化,配置项目名、实验名、日志后端(如 wandb、tensorboard),将完整配置(OmegaConf.to_container)记录到日志系统;
  2. 初始化全局步数:self.global_steps = 0
  3. 加载检查点以恢复模型状态、优化器状态和训练进度;
  4. 训练前验证(可选);
  5. 设置 Rollout Skip(可选):如果 skip_rollout=True,使用 RolloutSkip 包装 generate_sequences,跳过实际生成;
  6. 使用 tqdm 创建进度条,显示训练进度,并设置初始步数;
  7. 初始化性能分析状态;

主训练外层循环(Epoch 迭代)+内层循环(Batch 迭代):

1
2
3
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
...

对于内层循环中的每个 batch:

  1. 调用 _start_profiling() 启动所有 WorkerGroup 的性能分析;

  2. 准备数据:将 batch_dict 转换为 DataProto,同时为每个样本生成唯一的 UUID;

  3. 提取用于 rollout 的数据 gen_batch(_get_gen_batch 将 input_ids, attention_mask, position_ids 等从 batch 中取出交给 rollout;batch 中只保留 data_source, reward_model, extra_info, uid 等字段);

  4. 根据 async_rollout_mode 选择同步/异步的 rollout 方式并生成序列,记录生成时间;

  5. 处理 REMAX 基线(可选):生成确定性基线序列,计算基线奖励,用于 REMAX 优势估计器;

  6. 为每个样本分配唯一 ID,重复数据以对齐多次采样,计算响应掩码response_mask(用于区分实际生成部分和 padding),并可选地进行批次平衡;

  7. batch 平衡(可选):重新排序数据使每个 dp rank 的总 token 数均衡(不影响基于 uid 的 advantage 计算,可能影响 mini-batch 的 loss 计算);

  8. 根据配置使用奖励模型或自定义奖励函数计算 token 级别的奖励分数,支持同步和异步计算;

  9. 调用 actor_rollout_wg.compute_log_prob,使用本步参数更新之前的 actor policy 重新计算 old log probabilities(并不限定 Megatron 后端),用于重要性采样,同时计算熵值;

  10. 使用 reference policy 计算 log probs,用于 KL 散度计算;

  11. 使用 Critic 网络计算状态价值,用于优势函数估计;

  12. 根据配置的优势估计器(GAE、GRPO、REMAX 等)计算优势函数,支持 KL 惩罚;

  13. 使用计算出的优势函数更新 Critic 网络参数;

  14. 在 Critic 预热完成后,使用 PPO 损失函数更新 Actor 网络参数;

  15. 将生成的序列、输入、输出和分数保存到指定目录;

  16. 根据配置的频率执行验证,计算验证指标并记录;

  17. 根据配置的频率保存模型检查点;

  18. 收集训练指标、时序指标和吞吐量指标,并记录到日志系统;

  19. 更新进度条,递增全局步数,并在达到总训练步数时结束训练;

  20. 根据配置在特定步数启用/禁用性能分析,用于调试和优化。

  21. 使用 tqdm 创建进度条,显示训练进度,并设置初始步数;

  22. 遍历配置的 total_epochs 数和 train_dataloader,每个 train_batch 完成多步更新;

  23. 从 batch 中分离出用于 rollout 的数据(input_ids, attention_mask, position_ids 等),其余数据保留在 batch 中用于后续处理;

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the worker group through RPC
to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.

Setup: create the Tracking logger, reset ``self.global_steps``, load the checkpoint,
optionally validate before training (returning early when ``trainer.val_only`` is set),
optionally wrap generation with RolloutSkip, and create the progress bar.

Per batch, in order: convert ``batch_dict`` to a DataProto and attach a ``uid`` per sample,
generate sequences, optionally build the REMAX baseline, compute rewards, recompute
old log probs, compute reference log probs and critic values when enabled, compute
advantages, update critic and actor, then validate / checkpoint / log according to the
configured frequencies. The method returns once ``is_last_step`` is reached.
"""
from omegaconf import OmegaConf

from verl.utils.tracking import Tracking

logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)

self.global_steps = 0

# load checkpoint before doing anything
self._load_checkpoint()

# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return

if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
rollout_skip.wrap_generate_sequences()

# add tqdm
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

# we start from step 1
self.global_steps += 1
last_val_metrics = None
self.max_steps_duration = 0

prev_step_profile = False
curr_step_profile = (
self.global_steps in self.config.global_profiler.steps
if self.config.global_profiler.steps is not None
else False
)
next_step_profile = False

for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
timing_raw = {}

with marked_timer("start_profile", timing_raw):
self._start_profiling(
not prev_step_profile and curr_step_profile
if self.config.global_profiler.profile_continuous_steps
else curr_step_profile
)
batch: DataProto = DataProto.from_single_dict(batch_dict)

# add uid to batch
batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
)

gen_batch = self._get_gen_batch(batch)

# pass global_steps to trace
gen_batch.meta_info["global_steps"] = self.global_steps
gen_batch_output = gen_batch.repeat(
repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
)

is_last_step = self.global_steps >= self.total_training_steps
with marked_timer("step", timing_raw):
# generate a batch
with marked_timer("gen", timing_raw, color="red"):
if not self.async_rollout_mode:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
else:
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)

timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)

if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
if self.reward_fn is None:
raise ValueError("A reward_fn is required for REMAX advantage estimation.")

with marked_timer("gen_max", timing_raw, color="purple"):
# baseline generation uses the un-repeated gen_batch with sampling turned off
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
if not self.async_rollout_mode:
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
else:
gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
# compute reward model score on batch
rm_scores = None
if self.use_rm and "rm_scores" not in batch.batch.keys():
rm_scores = self.rm_wg.compute_rm_score(batch)
batch = batch.union(rm_scores)
reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

# remove the baseline generation (and its rm scores) from batch again
keys_to_pop = set(gen_baseline_output.batch.keys())
if rm_scores is not None:
keys_to_pop.update(rm_scores.batch.keys())
batch.pop(batch_keys=list(keys_to_pop))

batch.batch["reward_baselines"] = reward_baseline_tensor

del rm_scores, gen_baseline_batch, gen_baseline_output
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)

if "response_mask" not in batch.batch.keys():
batch.batch["response_mask"] = compute_response_mask(batch)
# Balance the number of valid tokens across DP ranks.
# NOTE: This usually changes the order of data in the `batch`,
# which won't affect the advantage calculation (since it's based on uid),
# but might affect the loss calculation (due to the change of mini-batching).
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)

# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

with marked_timer("reward", timing_raw, color="yellow"):
# compute reward model score
if self.use_rm and "rm_scores" not in batch.batch.keys():
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

if self.config.reward_model.launch_reward_fn_async:
future_reward = compute_reward_async.remote(
data=batch, config=self.config, tokenizer=self.tokenizer
)
else:
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

# recompute old_log_probs
with marked_timer("old_log_prob", timing_raw, color="blue"):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
entropys = old_log_prob.batch["entropys"]
response_masks = batch.batch["response_mask"]
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)

if "rollout_log_probs" in batch.batch.keys():
# TODO: we may want to add diff of probs too.
from verl.utils.debug.metrics import calculate_debug_metrics

metrics.update(calculate_debug_metrics(batch))

if self.use_reference_policy:
# compute reference log_prob
with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
if not self.ref_in_actor:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
else:
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)

# compute values
if self.use_critic:
with marked_timer("values", timing_raw, color="cyan"):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)

with marked_timer("adv", timing_raw, color="brown"):
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
if self.config.reward_model.launch_reward_fn_async:
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
batch.batch["token_level_scores"] = reward_tensor

if reward_extra_infos_dict:
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})

# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

# Compute rollout importance sampling weights centrally (once per batch)
# This corrects for mismatch between rollout policy and training policy
# Also computes mismatch metrics (KL, PPL, etc.)
batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
# IS and mismatch metrics already have mismatch/ prefix
metrics.update(is_metrics)

# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(
"norm_adv_by_std_in_grpo", True
) # GRPO adv normalization factor

batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
config=self.config.algorithm,
)

# update critic
if self.use_critic:
with marked_timer("update_critic", timing_raw, color="pink"):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)

# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with marked_timer("update_actor", timing_raw, color="red"):
batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)

# Log rollout generations if enabled
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir:
self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)

# validate
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
with marked_timer("testing", timing_raw, color="green"):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)

# Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
esi_close_to_expiration = should_save_ckpt_esi(
max_steps_duration=self.max_steps_duration,
redundant_time=self.config.trainer.esi_redundant_time,
)
# Check if the conditions for saving a checkpoint are met.
# The conditions include a mandatory condition (1) and
# one of the following optional conditions (2/3/4):
# 1. The save frequency is set to a positive value.
# 2. It's the last training step.
# 3. The current step number is a multiple of the save frequency.
# 4. The ESI(Elastic Server Instance)/training plan is close to expiration.
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
):
if esi_close_to_expiration:
print("Force saving checkpoint: ESI instance expiration approaching.")
with marked_timer("save_checkpoint", timing_raw, color="green"):
self._save_checkpoint()

with marked_timer("stop_profile", timing_raw):
next_step_profile = (
self.global_steps + 1 in self.config.global_profiler.steps
if self.config.global_profiler.steps is not None
else False
)
self._stop_profiling(
curr_step_profile and not next_step_profile
if self.config.global_profiler.profile_continuous_steps
else curr_step_profile
)
prev_step_profile = curr_step_profile
curr_step_profile = next_step_profile

# track the longest "step" duration seen so far; it feeds should_save_ckpt_esi above
steps_duration = timing_raw["step"]
self.max_steps_duration = max(self.max_steps_duration, steps_duration)

# training metrics
metrics.update(
{
"training/global_step": self.global_steps,
"training/epoch": epoch,
}
)
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: implement actual tflpo and theoretical tflpo
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
# Note: mismatch metrics (KL, PPL, etc.) were already merged into `metrics` earlier in this
# iteration, right after compute_rollout_importance_weights_and_add_to_batch and before
# compute_advantage.

# this is experimental and may be changed/removed in the future in favor of a general-purpose one
if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
self.train_dataloader.sampler.update(batch=batch)

# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)

progress_bar.update(1)
self.global_steps += 1

if (
hasattr(self.config.actor_rollout_ref.actor, "profiler")
and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory"
):
self.actor_rollout_wg.dump_memory_snapshot(
tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}"
)

if is_last_step:
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return

# this is experimental and may be changed/removed in the future
# in favor of a general-purpose data buffer pool
if hasattr(self.train_dataset, "on_batch_end"):
# The dataset may be changed after each training batch
self.train_dataset.on_batch_end(batch=batch)

创建 DataLoader:RayPPOTrainer._create_dataloader

self.train_dataloader 的来源是什么呢?由 RayPPOTrainer.__init__ 调用 RayPPOTrainer._create_dataloader 创建:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# verl/trainer/ppo/ray_trainer.py
class RayPPOTrainer:
def __init__(
......
collate_fn=None, # 将 data samples 聚合为 batch
train_sampler: Optional[Sampler] = None,
):
self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset, # RLHFDataset 实例
batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
num_workers=num_workers, # 并行加载进程数
drop_last=True, # 丢弃最后一个不完整的 batch
collate_fn=collate_fn, # 批处理函数
sampler=train_sampler, # 采样器
)

RayPPOTrainer.fit() 的训练循环中:for batch_dict in self.train_dataloader

1
2
3
def fit(self):
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:

以上 batch_dict 来自哪儿呢?

RayPPOTrainer.__init__ -> RayPPOTrainer._create_dataloader -> StatefulDataLoader.__iter__() -> DataLoader.__iter__()(PyTorch) -> 按 sampler 产生的索引遍历数据集(顺序采样时相当于 for i in range(len(dataset)):) -> 获取单个样本 dataset[i] -> RLHFDataset.__getitem__(i) -> 打包成 batch collate_fn([sample1, sample2, ..., sampleN])

得到原始的 batch:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
batch_dict = {
# === 张量字段 ===
"input_ids": torch.tensor([
[1, 2, 3, 0, 0], # 样本1
[4, 5, 6, 7, 0] # 样本2
]), # shape: (batch_size, seq_len)

"attention_mask": torch.tensor([
[1, 1, 1, 0, 0], # 样本1
[1, 1, 1, 1, 0] # 样本2
]), # shape: (batch_size, seq_len)

"position_ids": torch.tensor([
[0, 1, 2, 0, 0], # 样本1
[0, 1, 2, 3, 0] # 样本2
]), # shape: (batch_size, seq_len)

# === 非张量字段 ===
......

"data_source": np.array([
"geo3k", # 样本1
"geo3k" # 样本2
], dtype=object), # shape: (batch_size,)
"reward_model": np.array([
{"style": "rule", "ground_truth": "42"}, # 样本1
{"style": "rule", "ground_truth": "56"} # 样本2
], dtype=object), # shape: (batch_size,)

"extra_info": np.array([
{"split": "train", "index": 0}, # 样本1
{"split": "train", "index": 1} # 样本2
], dtype=object), # shape: (batch_size,)
}

DataFlow

RayPPOTrainer.fit()的数据流具体如下:

  1. parquet 文件

    1
    data_files = "~/data/rlhf/gsm8k/train.parquet"
  2. RLHFDataset

    1
    2
    3
    4
    5
    6
    dataset = RLHFDataset(
    data_files=data_paths,
    tokenizer=tokenizer,
    processor=processor,
    config=data_config,
    )
  3. DataLoader

    1
    2
    3
    4
    5
    6
    7
    8
    self.train_dataloader = StatefulDataLoader(
    dataset=self.train_dataset, # RLHFDataset 实例
    batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
    num_workers=num_workers, # 并行加载进程数
    drop_last=True, # 丢弃最后一个不完整的 batch
    collate_fn=collate_fn, # 批处理函数
    sampler=train_sampler, # 采样器
    )
  4. DataProto

    1
    2
    3
    for batch_dict in self.train_dataloader:
    ......
    batch: DataProto = DataProto.from_single_dict(batch_dict)
  5. pop 提取用于 rollout 的数据:

    1
    gen_batch = self._get_gen_batch(batch)
  6. 通过 rollout 生成序列:

    1
    2
    3
    4
    gen_batch_output = gen_batch.repeat(
    repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
    )
    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
  7. 将生成的数据 gen_batch_output 合并到 batch:

    1
    batch = batch.union(gen_batch_output)
  8. 计算 reward:

    1
    2
    reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
    batch.batch["token_level_scores"] = reward_tensor
  9. 计算 advantage(注意:在实际代码中,这一步位于步骤 10–12 之后,因为 GAE 等估计器需要用到 values):

    1
    2
    3
    4
    5
    batch = compute_advantage(
    batch,
    adv_estimator=self.config.algorithm.adv_estimator,
    ......
    )
  10. 重新计算 log_probs(即 old_log_prob):

    1
    2
    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
    batch = batch.union(old_log_prob)
  11. 计算 reference model 的 log_probs:

    1
    2
    3
    if self.use_reference_policy:
    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
    batch = batch.union(ref_log_prob)
  12. 计算 value function:

    1
    2
    3
    if self.use_critic:
    values = self.critic_wg.compute_values(batch)
    batch = batch.union(values)
  13. 更新 critic:

    1
    2
    3
    4
    if self.use_critic:
    critic_output = self.critic_wg.update_critic(batch)
    critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
    metrics.update(critic_output_metrics)
  14. 更新 actor:

    1
    actor_output = self.actor_rollout_wg.update_actor(batch)
  15. 返回训练指标:

    1
    2
    3
    actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
    metrics.update(actor_output_metrics)
    logger.log(data=metrics, step=self.global_steps)

除了最初的三步(parquet 文件、RLHFDataset、DataLoader),后续步骤均通过 DataProto 完成数据交换。

参考

深入浅出理解 verl 源码 part 2