Renaming and Organization of RL algorithms in preparation for Development #83


Merged
37 commits merged on Jun 16, 2025

Commits (37)
2eadf4b
reorg
jdchang1 Jun 5, 2025
8e0b986
fix scripts import
jdchang1 Jun 5, 2025
6b0b949
fix imports in tests
jdchang1 Jun 5, 2025
796b890
fix reward manager imports
jdchang1 Jun 5, 2025
4b22915
import fix
jdchang1 Jun 5, 2025
3ee2549
backwards compat in pyproject
jdchang1 Jun 5, 2025
7a5b596
style lint
jdchang1 Jun 5, 2025
77d018c
fix registry test
jdchang1 Jun 5, 2025
d157289
update test names to be aligned with new names
jdchang1 Jun 5, 2025
129c2ee
update online rl loss types to Enum
jdchang1 Jun 5, 2025
9a18867
fix model init test for ppo
jdchang1 Jun 5, 2025
ecd7119
for backwards compatibility
jdchang1 Jun 5, 2025
96d2eb4
Merge branch 'main' into reorg
jdchang1 Jun 5, 2025
9b3a963
style fix
jdchang1 Jun 5, 2025
094435d
unified naming scheme across alg types
jdchang1 Jun 5, 2025
19007de
fix model name updates
jdchang1 Jun 5, 2025
b1d3973
naming fix
jdchang1 Jun 5, 2025
bd3ea7e
Merge branch 'main' into reorg
jdchang1 Jun 13, 2025
a9231c5
unified to enums for loss type
jdchang1 Jun 13, 2025
11d38ae
fix test
jdchang1 Jun 13, 2025
39c6761
fix test
jdchang1 Jun 13, 2025
2002b6f
style
jdchang1 Jun 13, 2025
13d87d1
organize PG loss
jdchang1 Jun 13, 2025
c342130
more org
jdchang1 Jun 13, 2025
23e1727
lint
jdchang1 Jun 13, 2025
78bb625
backwards compatibility with orl_eval
jdchang1 Jun 13, 2025
a86d5af
orl backwards compat
jdchang1 Jun 13, 2025
af98a95
enum compatibility
jdchang1 Jun 13, 2025
e4d60b2
isort
jdchang1 Jun 13, 2025
80cbded
resolve circular import
jdchang1 Jun 13, 2025
7ab6f27
fixed my dumb mistake
jdchang1 Jun 13, 2025
d727f55
loss fix
jdchang1 Jun 13, 2025
a2873e1
generation reorg
jdchang1 Jun 16, 2025
13b7013
pre comit
jdchang1 Jun 16, 2025
2de97d2
add enum to generation utils
jdchang1 Jun 16, 2025
4c4595c
remove init
jdchang1 Jun 16, 2025
422fbfc
update yaml
jdchang1 Jun 16, 2025
7 changes: 2 additions & 5 deletions compose_rl/__init__.py
Original file line number Diff line number Diff line change
@@ -10,14 +10,11 @@
'When installing plugins, please use one of the extras depending on which version of llmfoundry you are using.',
)

import compose_rl.dpo as dpo
import compose_rl.reward_learning as reward_learning
from compose_rl import data, metrics, utils
from compose_rl import algorithms, data, metrics, utils

__all__ = [
'algorithms',
'utils',
'data',
'dpo',
'reward_learning',
'metrics',
]
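
For orientation, here is a minimal sketch of the import surface after the reorganization (illustrative only, not part of the diff; it assumes compose_rl is installed from this branch):

import compose_rl
# The former top-level dpo / ppo / reward_learning modules are now grouped
# under a single `algorithms` namespace.
from compose_rl.algorithms import offline, online, reward_modeling

assert 'algorithms' in compose_rl.__all__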
12 changes: 12 additions & 0 deletions compose_rl/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

import compose_rl.algorithms.offline as offline
import compose_rl.algorithms.online as online
import compose_rl.algorithms.reward_modeling as reward_modeling

__all__ = [
'offline',
'online',
'reward_modeling',
]
14 changes: 14 additions & 0 deletions compose_rl/algorithms/offline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

from compose_rl.algorithms.offline.callback import ReferencePolicyCallback
from compose_rl.algorithms.offline.model import (
ComposerHFPairwiseOfflinePolicyLM,
ComposerMPTPairwiseOfflinePolicyLM,
)

__all__ = [
'ComposerMPTPairwiseOfflinePolicyLM',
'ComposerHFPairwiseOfflinePolicyLM',
'ReferencePolicyCallback',
]
Original file line number Diff line number Diff line change
@@ -19,8 +19,8 @@
from llmfoundry.utils.config_utils import process_init_device # type: ignore


class DPOCallback(CallbackWithConfig):
"""Callback to run DPO in an offline RL setting.
class ReferencePolicyCallback(CallbackWithConfig):
"""Callback to run reference policy in offline RL.
Args:
train_config (dict): Training config passed to callback via foundry train.py as
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

"""DPO Composer Implementation."""
"""Pairwise Offline RL Composer Implementation."""

from __future__ import annotations

@@ -13,14 +13,18 @@
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.modeling_outputs import CausalLMOutputWithPast

from compose_rl.dpo.model_methods import DPOEnum, dpo_forward, dpo_loss
from compose_rl.algorithms.offline.model_methods import (
PairwiseOfflineEnum,
pairwise_offline_forward,
pairwise_offline_loss,
)

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

log = logging.getLogger(__name__)


class ComposerDPOLM(ComposerMPTCausalLM):
class ComposerMPTPairwiseOfflinePolicyLM(ComposerMPTCausalLM):
"""MPT model wrapper for DPO model."""

def __init__(
@@ -32,7 +36,7 @@ def __init__(
average_log_prob: bool = False,
**kwargs: Any,
):
self.loss_type = DPOEnum(loss_type)
self.loss_type = PairwiseOfflineEnum(loss_type)
self.beta = beta
self.label_smoothing = label_smoothing
self.sft_alpha = sft_alpha
@@ -43,7 +47,7 @@ def __init__(

def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]:
assert self.tokenizer is not None
return dpo_forward(
return pairwise_offline_forward(
model=self.model,
tokenizer=self.tokenizer,
batch=batch,
@@ -62,7 +66,7 @@ def eval_forward(

def loss(self, outputs: CausalLMOutputWithPast,
batch: Mapping) -> dict[str, torch.Tensor]:
return dpo_loss(
return pairwise_offline_loss(
outputs,
batch,
self.loss_type,
@@ -72,7 +76,7 @@ def loss(self, outputs: CausalLMOutputWithPast,
)


class ComposerHFDPOLM(ComposerHFCausalLM):
class ComposerHFPairwiseOfflinePolicyLM(ComposerHFCausalLM):
"""HF class wrapper for DPO model."""

def __init__(
@@ -84,7 +88,7 @@ def __init__(
average_log_prob: bool = False,
**kwargs: Any,
):
self.loss_type = DPOEnum(loss_type)
self.loss_type = PairwiseOfflineEnum(loss_type)
self.beta = beta
self.label_smoothing = label_smoothing
self.sft_alpha = sft_alpha
@@ -95,7 +99,7 @@ def __init__(

def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]:
assert self.tokenizer is not None
return dpo_forward(
return pairwise_offline_forward(
model=self.model,
tokenizer=self.tokenizer,
batch=batch,
@@ -111,7 +115,7 @@ def eval_forward(

def loss(self, outputs: CausalLMOutputWithPast,
batch: Mapping) -> dict[str, torch.Tensor]:
return dpo_loss(
return pairwise_offline_loss(
outputs,
batch,
self.loss_type,
318 changes: 318 additions & 0 deletions compose_rl/algorithms/offline/model_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

"""DPO Utils."""

from enum import Enum
from typing import Mapping, MutableMapping, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
PretrainedConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

from compose_rl.utils import (
clear_mb_load_balancing_loss,
extract_packed_chosen_rejected,
get_batch_logp,
get_mb_load_balancing_loss,
)


class PairwiseOfflineEnum(Enum):
DPO = 'dpo'
RPO = 'rpo'
RCDPO = 'rcdpo'
REBEL = 'rebel'
IPO = 'ipo'
KTO = 'kto'


def pairwise_offline_forward(
model: nn.Module,
tokenizer: Tokenizer,
batch: MutableMapping,
average_log_prob: bool = False,
policy_model_config: Optional[PretrainedConfig] = None,
use_attention_sequence_id: bool = False,
) -> dict[str, torch.Tensor]:
"""Forwards the model for dpo and get the chosen and rejected log probs.
Args:
model (nn.Module): Model we are forwarding.
tokenizer (Tokenizer): Tokenizer for the model.
batch (Dict[str, torch.LongTensor]): Batch over which we should forward the model.
Note: this batch has chosen and rejected concated along the sequence dimension.
average_log_prob (bool): Whether should we average the log probabilities.
policy_model_config: Policy model config.
use_attention_sequence_id (bool): Whether we should use the attention sequence id.
"""
if policy_model_config is not None and hasattr(model, 'transformer'):
clear_mb_load_balancing_loss(
policy_model_config,
model.transformer, # type: ignore
)

batch_size, concat_seq_len = batch['input_ids'].shape
pad_token_id = tokenizer.pad_token_id
if pad_token_id is None:
raise ValueError('Tokenizer must have a PAD token.')

# If we can use attention sequence ID, we use this logic branch.
# This is determined by a value set in `train_dpo.py`
if use_attention_sequence_id:
output_logits = model(
batch['input_ids'],
attention_mask=batch['attention_mask'],
sequence_id=batch['sequence_id'],
).logits

chosen_logits, rejected_logits = extract_packed_chosen_rejected(
output_logits,
batch['chosen_len'],
batch['rejected_len'],
concat_seq_len,
pad_token_id=pad_token_id, # type: ignore
)

else:
# If we can't use attn_seq_id then we need to unpack each batch and
# Pack along the batch dimension instead.

chosen_inputs, rejected_inputs = extract_packed_chosen_rejected(
batch['input_ids'],
batch['chosen_len'],
batch['rejected_len'],
concat_seq_len,
pad_token_id=pad_token_id, # type: ignore
)

chosen_attention_mask, rejected_attention_mask = extract_packed_chosen_rejected(
batch['attention_mask'],
batch['chosen_len'],
batch['rejected_len'],
concat_seq_len,
pad_token_id=0,
)

batch_cat_inputs = torch.cat([chosen_inputs, rejected_inputs], dim=0)
batch_attn_mask = torch.cat(
[
chosen_attention_mask,
rejected_attention_mask,
],
dim=0,
)

output_logits = model(
batch_cat_inputs,
attention_mask=batch_attn_mask,
).logits

# Extract out the chosen and rejected logits along the batch dimension
chosen_logits = output_logits[:batch_size]
rejected_logits = output_logits[batch_size:]

chosen_labels, rejected_labels = extract_packed_chosen_rejected(
batch['input_ids'],
batch['chosen_len'],
batch['rejected_len'],
concat_seq_len,
pad_token_id=0,
)

chosen_logps = get_batch_logp(
chosen_labels,
chosen_logits,
batch['prompt_len'],
batch['chosen_len'],
average_log_prob,
)

rejected_logps = get_batch_logp(
rejected_labels,
rejected_logits,
batch['prompt_len'],
batch['rejected_len'],
average_log_prob,
)

outputs: dict[str, torch.Tensor] = {
'policy_chosen_logp': chosen_logps,
'policy_rejected_logp': rejected_logps,
'chosen_len': batch['chosen_len'],
}

if 'chosen_reward' in batch:
outputs['chosen_reward'] = batch['chosen_reward']
outputs['rejected_reward'] = batch['rejected_reward']

if policy_model_config is not None and hasattr(model, 'transformer'):
lbl = get_mb_load_balancing_loss(
policy_model_config,
model.transformer, # type: ignore
)
if lbl is not None:
outputs['lbl'] = lbl

return outputs


def pairwise_offline_loss(
outputs: CausalLMOutputWithPast,
batch: Mapping,
loss_type: PairwiseOfflineEnum,
beta: float,
label_smoothing: float,
sft_alpha: float,
) -> dict[str, torch.Tensor]:
"""Computes pairwise offline RL losses.
Given precomputed values, the batch, and RL specific values, this will compute the specified loss.
Args:
outputs (CausalLMOutputWithPast): Outputs from forwarding the model over the batch.
batch (Mapping): Input batch of data.
loss_type (PairwiseOfflineEnum): Loss type that we should compute (e.g. dpo, ipo, or kto).
beta (float): How much to regularize the policy model. We regularize the policy toward
the reference model less as beta -> 0.
label_smoothing (float): Represents conservativeness for the DPO loss. This assumes that
preferences are noisy (preferences are flipped with probability label_smoothing).
sft_alpha (float): Regularization weight for supervised finetuning loss (SFT) to
be added to DPO type loss.
"""
policy_chosen_logp = outputs['policy_chosen_logp']
policy_rejected_logp = outputs['policy_rejected_logp']
ref_chosen_logp = batch.get(
'ref_chosen',
torch.zeros_like(policy_chosen_logp),
)
ref_rejected_logp = batch.get(
'ref_rejected',
torch.zeros_like(policy_rejected_logp),
)

pi_logratios = policy_chosen_logp - policy_rejected_logp
ref_logratios = ref_chosen_logp - ref_rejected_logp

logits = pi_logratios - ref_logratios # Also known as h_{\pi_\theta}^{y_w,y_l}

losses = torch.zeros_like(logits)
if loss_type == PairwiseOfflineEnum.DPO:
losses = (
-F.logsigmoid(beta * logits) * (1 - label_smoothing) -
F.logsigmoid(-beta * logits) * label_smoothing
)
elif loss_type == PairwiseOfflineEnum.RCDPO:
# Adding reward-difference based label_smoothing = 1 - reward_bt_prob
chosen_reward = outputs['chosen_reward']
rejected_reward = outputs['rejected_reward']
reward_diff = chosen_reward - rejected_reward
reward_bt_prob = torch.sigmoid(reward_diff)
rcdpo_losses = -F.logsigmoid(
beta * logits,
) * reward_bt_prob - F.logsigmoid(
-beta * logits,
) * (1 - reward_bt_prob)
losses = rcdpo_losses
elif loss_type == PairwiseOfflineEnum.RPO:
# Reproducing the RPO loss from NVIDIA's paper: https://arxiv.org/pdf/2406.11704v1 page 13
# Code: https://github.com/NVIDIA/NeMo-Aligner/blob/c92a3bf9c2d6312581982a8d1db30591855394c5/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L261-L273
eta = 1 # NOTE: Hardcoding this to be 1 as per the paper's recommendation
chosen_reward = outputs['chosen_reward']
rejected_reward = outputs['rejected_reward']
reward_diff = chosen_reward - rejected_reward

logsigmoid_a = F.logsigmoid(beta * logits)
logsigmoid_b = F.logsigmoid(eta * reward_diff)
logsigmoid_not_a = F.logsigmoid(-beta * logits)
logsigmoid_not_b = F.logsigmoid(-eta * reward_diff)

losses = torch.exp(logsigmoid_b) * (
logsigmoid_b - logsigmoid_a
) + torch.exp(logsigmoid_not_b) * (logsigmoid_not_b - logsigmoid_not_a)
elif loss_type == PairwiseOfflineEnum.REBEL:
# Reproducing the REBEL loss from paper: https://arxiv.org/pdf/2404.16767 page 4
# Code: https://github.com/ZhaolinGao/REBEL/blob/e0a6a190108a45c70b4920b58a4ccac8a09ab22b/src/tldr/rebel.py#L761-L777
pi_logratios = policy_chosen_logp - policy_rejected_logp
ref_logratios = ref_chosen_logp - ref_rejected_logp

logits = pi_logratios - ref_logratios # Also known as h_{\pi_\theta}^{y_w,y_l}

chosen_reward = outputs['chosen_reward']
rejected_reward = outputs['rejected_reward']
reward_diff = chosen_reward - rejected_reward
losses = (beta * logits - reward_diff)**2
# beta represents 1/eta hparam from the paper
elif loss_type == PairwiseOfflineEnum.IPO:
losses = (logits - 1 / (2 * beta))**2
elif loss_type == PairwiseOfflineEnum.KTO:
chosen_KL = (policy_chosen_logp - ref_chosen_logp).mean().clamp(min=0)
rejected_KL = (policy_rejected_logp -
ref_rejected_logp).mean().clamp(min=0)

chosen_logratios = policy_chosen_logp - ref_chosen_logp
rejected_logratios = policy_rejected_logp - ref_rejected_logp
losses = torch.cat(
(
1 - F.sigmoid(beta * (chosen_logratios - rejected_KL)),
1 - F.sigmoid(beta * (chosen_KL - rejected_logratios)),
),
0,
)
else:
raise ValueError(f'Loss type: {loss_type} is not supported.')

if sft_alpha > 0:
sft_losses = -1 * sft_alpha * policy_chosen_logp
sft_losses_normalized = sft_losses / outputs['chosen_len']
losses_before_sft = losses.clone().detach()
losses += sft_losses_normalized

losses = losses.mean()

chosen_rewards = beta * (policy_chosen_logp - ref_chosen_logp).detach()
rejected_rewards = beta * (policy_rejected_logp -
ref_rejected_logp).detach()

# Logging KL margins for comparing different methods
chosen_KL = (policy_chosen_logp - ref_chosen_logp).detach()
rejected_KL = (policy_rejected_logp - ref_rejected_logp).detach()
margin_KL = (chosen_KL - rejected_KL).detach()
loss_dict = {
'chosen_rewards': chosen_rewards,
'rejected_rewards': rejected_rewards,
'margin': chosen_rewards - rejected_rewards,
'chosen_KL': chosen_KL,
'rejected_KL': rejected_KL,
'margin_KL': margin_KL,
'accuracy': (chosen_rewards > rejected_rewards).to(torch.float32),
}
if loss_type in [
PairwiseOfflineEnum.RPO,
PairwiseOfflineEnum.RCDPO,
PairwiseOfflineEnum.REBEL,
]:
# reward_diff is always defined if loss_type is RPO, RCDPO, or REBEL
loss_dict['reward_diff'] = reward_diff.detach() # type: ignore
if sft_alpha > 0:
# sft_losses_normalized is always defined if sft_alpha>0
snl = sft_losses_normalized.detach() # type: ignore
loss_dict['sft_regularization_loss'] = snl
# losses_before_sft is always defined if sft_alpha>0
loss_dict[f'{loss_type.value}_loss'] = losses_before_sft # type: ignore

if 'lbl' in outputs:
losses += outputs['lbl']
loss_dict['lbl'] = outputs['lbl']

loss_dict['total'] = losses

return loss_dict
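
The unified loss entry point can be exercised in isolation with dummy tensors. A minimal sketch, not part of this PR: the values are made up, and a plain dict stands in for CausalLMOutputWithPast since the function only indexes it by key.

import torch

from compose_rl.algorithms.offline.model_methods import (
    PairwiseOfflineEnum,
    pairwise_offline_loss,
)

# Hypothetical log probabilities for a batch of two preference pairs.
outputs = {
    'policy_chosen_logp': torch.tensor([-5.0, -6.0]),
    'policy_rejected_logp': torch.tensor([-7.0, -6.5]),
    'chosen_len': torch.tensor([12, 9]),
}
batch = {
    'ref_chosen': torch.tensor([-5.5, -6.2]),
    'ref_rejected': torch.tensor([-6.8, -6.4]),
}

loss_dict = pairwise_offline_loss(
    outputs=outputs,  # type: ignore  (dict in place of CausalLMOutputWithPast)
    batch=batch,
    loss_type=PairwiseOfflineEnum.DPO,
    beta=0.1,
    label_smoothing=0.0,
    sft_alpha=0.0,
)
print(loss_dict['total'], loss_dict['accuracy'])

Swapping loss_type for PairwiseOfflineEnum.IPO or PairwiseOfflineEnum.KTO reuses the same call; the reward-aware variants (RPO, RCDPO, REBEL) additionally require 'chosen_reward' and 'rejected_reward' in outputs.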
37 changes: 37 additions & 0 deletions compose_rl/algorithms/online/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

from compose_rl.algorithms.online.callback import OnPolicyCallback
from compose_rl.algorithms.online.kl_controller import (
AdaptiveKLController,
BallKLController,
FixedKLController,
KLPIDController,
)
from compose_rl.algorithms.online.model import (
ComposerHFCriticFreePolicyLM,
ComposerHFPolicyLM,
ComposerMPTPolicyLM,
)
from compose_rl.algorithms.online.model_methods import \
CausalLMOutputWithPastAndValues
from compose_rl.algorithms.online.policy_configuration import (
HFPolicyConfig,
MPTPolicyConfig,
)
from compose_rl.registry import kl_controllers

kl_controllers.register('adaptive', func=AdaptiveKLController)
kl_controllers.register('fixed', func=FixedKLController)
kl_controllers.register('pid', func=KLPIDController)
kl_controllers.register('ball', func=BallKLController)

__all__ = [
'OnPolicyCallback',
'ComposerMPTPolicyLM',
'ComposerHFPolicyLM',
'ComposerHFCriticFreePolicyLM',
'HFPolicyConfig',
'MPTPolicyConfig',
'CausalLMOutputWithPastAndValues',
]
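
Because these registrations run at import time, downstream code can look the controllers up by name. A sketch, assuming compose_rl.registry follows the llm-foundry/catalogue registry API with a `get` method:

import compose_rl.algorithms.online  # noqa: F401  # side effect: registers the KL controllers
from compose_rl.registry import kl_controllers

# 'adaptive', 'fixed', 'pid', and 'ball' were registered above.
fixed_cls = kl_controllers.get('fixed')
print(fixed_cls.__name__)  # expected: 'FixedKLController'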
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

"""PPO callback."""
"""Online On-Policy RL callback."""

from __future__ import annotations

@@ -29,35 +29,45 @@
from llmfoundry.interfaces import CallbackWithConfig
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

import compose_rl.utils as utils
from compose_rl.ppo.buffer import MinibatchRolloutBuffer
from compose_rl.ppo.generation_utils import hf_generate, vllm_generate
from compose_rl.ppo.model import ComposerHFPolicyModel, ComposerMosaicPolicy
from compose_rl.ppo.reward_manager import (
from compose_rl.algorithms.online.generation_utils import (
broadcast_to_vllm,
create_vllm_engines,
hf_generate,
init_process_group,
vllm_generate,
)
from compose_rl.algorithms.online.model import (
ComposerHFPolicyLM,
ComposerMPTPolicyLM,
)
from compose_rl.algorithms.online.model_methods import (
OnPolicyEnum,
)
from compose_rl.algorithms.online.reward_manager import (
ReferenceOutput,
RewardManager,
RewardOutput,
)
from compose_rl.data.buffer import MinibatchRolloutBuffer
from compose_rl.registry_builders import build_kl_controller
from compose_rl.utils import (
add_right_padding,
broadcast_to_vllm,
compute_advantages,
create_vllm_engines,
dist_compute_masked_mean_and_var,
flatten,
get_decoded_sequence,
get_entropies,
get_log_probs,
init_process_group,
mask_eos,
masked_mean,
masked_sum,
switch_left_to_right_padding,
)

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
Policy = Union[ComposerHFPolicyModel, ComposerMosaicPolicy]
Policy = Union[ComposerHFPolicyLM, ComposerMPTPolicyLM]

__all__ = ['PPOCallback', 'env_reward']
__all__ = ['OnPolicyCallback', 'env_reward']

log = logging.getLogger(__name__)

@@ -303,8 +313,8 @@ def env_reward(
)


class PPOCallback(CallbackWithConfig):
"""Callback for managing PPO training in an RLHF loop.
class OnPolicyCallback(CallbackWithConfig):
"""Callback for managing on-policy training in an RLHF loop.
Args:
train_config (dict): Training config passed to callback via foundry train.py as
@@ -489,7 +499,7 @@ def init(self, state: State, logger: Logger):
self.pad_token_idx = state.model.tokenizer.pad_token_id # type: ignore
self.actor_critic = state.model

if self.actor_critic.loss_type == 'grpo':
if self.actor_critic.loss_type == OnPolicyEnum.GRPO:
assert self.generations_per_prompt > 1, \
'GRPO requires multiple generations per prompt. ' + \
f'Current generations_per_prompt is: {self.generations_per_prompt}.'
@@ -667,7 +677,7 @@ def _get_next_iter_prompts(self):
ret_batch[key] = torch.cat(curr_values)
else:
if key == 'verified_answer':
ret_batch[key] = list(utils.flatten(curr_values))
ret_batch[key] = list(flatten(curr_values))
else:
# this is an edge case that we will not hit currently, but just handling it as needed
ret_batch[key] = curr_values
@@ -862,20 +872,20 @@ def _resolve_outputs(
)

# Now that rewards are resolved, we can compute advantages
if self.actor_critic.loss_type == 'ppo':
if self.actor_critic.loss_type == OnPolicyEnum.PPO:
env_outs['advantages'] = compute_advantages(
rewards=env_outs['rewards'],
values=env_outs['values'],
gamma=self.gamma,
lambda_gae=self.lambda_gae,
)
elif self.actor_critic.loss_type == 'grpo':
elif self.actor_critic.loss_type == OnPolicyEnum.GRPO:
# compute GRPO advantages
prompt_id = env_outs['prompt_id']
rewards = env_outs['rewards']

# Flatten the rewards by summing on sequence length/action_mask
flat_rewards = utils.masked_sum(
flat_rewards = masked_sum(
rewards,
env_outs['action_mask'],
dim=-1,
@@ -1107,6 +1117,7 @@ def _update_inference_model(self, batch: dict[str, torch.Tensor]):
self.vllm_engines,
self.model_update_group,
batch,
#loss_type=self.actor_critic.loss_type.value, # type: ignore
loss_type=self.actor_critic.loss_type, # type: ignore
)
log.info('Finished broadcasting to vLLM')
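
For the PPO branch above, advantages come from compute_advantages(rewards, values, gamma, lambda_gae). The snippet below is an independent, self-contained sketch of generalized advantage estimation written purely for illustration; it is not the compose_rl implementation, and the [batch, seq] shapes are assumptions:

import torch


def gae_sketch(
    rewards: torch.Tensor,   # [batch, seq] per-token rewards
    values: torch.Tensor,    # [batch, seq] critic value estimates
    gamma: float = 1.0,
    lambda_gae: float = 0.95,
) -> torch.Tensor:
    """Generalized advantage estimation over the sequence dimension."""
    batch, seq = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros(batch)
    for t in reversed(range(seq)):
        next_value = values[:, t + 1] if t + 1 < seq else torch.zeros(batch)
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        last_gae = delta + gamma * lambda_gae * last_gae
        advantages[:, t] = last_gae
    return advantages


rewards = torch.zeros(2, 4)
rewards[:, -1] = 1.0  # terminal-only reward, common in RLHF-style setups
values = torch.zeros(2, 4)
print(gae_sketch(rewards, values))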
20 changes: 20 additions & 0 deletions compose_rl/algorithms/online/generation_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

from compose_rl.algorithms.online.generation_utils.generation_utils import (
hf_generate,
vllm_generate,
)
from compose_rl.algorithms.online.generation_utils.vllm_utils import (
broadcast_to_vllm,
create_vllm_engines,
init_process_group,
)

__all__ = [
'broadcast_to_vllm',
'create_vllm_engines',
'init_process_group',
'hf_generate',
'vllm_generate',
]
Original file line number Diff line number Diff line change
@@ -13,14 +13,17 @@
from composer.utils import dist
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from compose_rl.ppo.model import ComposerHFPolicyModel, ComposerMosaicPolicy
from compose_rl.algorithms.online.model import (
ComposerHFPolicyLM,
ComposerMPTPolicyLM,
)
from compose_rl.utils import (
flip_pad_token_usage_for_generate,
flip_pad_token_usage_in_ffn,
)

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
Policy = Union[ComposerHFPolicyModel, ComposerMosaicPolicy]
Policy = Union[ComposerHFPolicyLM, ComposerMPTPolicyLM]

log = logging.getLogger(__name__)

File renamed without changes.
Original file line number Diff line number Diff line change
@@ -44,7 +44,8 @@
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from compose_rl.utils.vllm_actor import LLMRayActor
from compose_rl.algorithms.online.generation_utils.vllm_actor import LLMRayActor
from compose_rl.algorithms.online.model_methods import OnPolicyEnum

log = logging.getLogger(__name__)

@@ -343,7 +344,7 @@ def should_update_torch_module(
parsed_module_name: str,
full_param_name: str,
module: nn.Module,
loss_type: str,
loss_type: OnPolicyEnum,
valid_non_leaf_module_names: list[str],
):
"""Check if the module should be updated.
@@ -361,10 +362,10 @@ def should_update_torch_module(
if parsed_module_name not in valid_non_leaf_module_names:
return False

if loss_type == 'grpo':
if loss_type == OnPolicyEnum.GRPO:
return True

if loss_type == 'ppo' and 'lm_backbone' in full_param_name:
if loss_type == OnPolicyEnum.PPO and 'lm_backbone' in full_param_name:
return True

return False
@@ -375,7 +376,7 @@ def broadcast_to_vllm(
vllm_engines: list,
model_update_group: Optional[torch.distributed.ProcessGroup],
batch: dict[str, torch.Tensor],
loss_type: str = 'ppo',
loss_type: OnPolicyEnum = OnPolicyEnum.PPO,
):
"""Broadcast model weights to all vllm engines.
@@ -388,12 +389,12 @@ def broadcast_to_vllm(
"""
# avoid OOM
torch.cuda.empty_cache()
if loss_type == 'ppo':
if loss_type == OnPolicyEnum.PPO:
# Extract the lm_backbone params from the model
count, num_params = 0, len(
list(model.model.lm_backbone.named_parameters()), # type: ignore
)
elif loss_type == 'grpo':
elif loss_type == OnPolicyEnum.GRPO:
# Directly use the model params
count, num_params = 0, len(
list(model.model.named_parameters()), # type: ignore
@@ -442,7 +443,7 @@ def broadcast_to_vllm(
if isinstance(module, FSDP):
# This is needed otherwise FSDP will materialize parameters of size 0.
# So just for the joint actor critic models we have to actually skip this module.
if module_name == 'model' and loss_type == 'ppo':
if module_name == 'model' and loss_type == OnPolicyEnum.PPO:
continue

# Only update if we haven't updated this module before
Original file line number Diff line number Diff line change
@@ -17,11 +17,11 @@
PreTrainedTokenizerFast,
)

from compose_rl.ppo.modeling_utils import (
from compose_rl.algorithms.online.model_methods import (
CausalLMOutputWithPastAndValues,
prepare_critic_values_for_training,
)
from compose_rl.ppo.policy_configuration import HFPolicyConfig
from compose_rl.algorithms.online.policy_configuration import HFPolicyConfig
from compose_rl.utils.consts import _MASTER_WEIGHTS_PRECISION

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
File renamed without changes.
41 changes: 23 additions & 18 deletions compose_rl/ppo/model.py → compose_rl/algorithms/online/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

"""PPO Composer Policy implementations."""
"""On-Policy Online RL Composer Model implementations."""

import collections
import logging
@@ -16,13 +16,14 @@
PreTrainedTokenizerFast,
)

from compose_rl.ppo.modeling_hf import ComposerHFPolicy
from compose_rl.ppo.modeling_mpt import MPTForPolicy
from compose_rl.ppo.modeling_utils import (
from compose_rl.algorithms.online.model_methods import (
OnPolicyEnum,
composer_online_rl_forward,
online_rl_loss,
)
from compose_rl.ppo.policy_configuration import MPTPolicyConfig
from compose_rl.algorithms.online.modeling_hf import ComposerHFPolicy
from compose_rl.algorithms.online.modeling_mpt import MPTForPolicy
from compose_rl.algorithms.online.policy_configuration import MPTPolicyConfig
from compose_rl.utils import (
clear_mb_load_balancing_loss,
get_mb_load_balancing_loss,
@@ -33,7 +34,7 @@
log = logging.getLogger(__name__)


class ComposerMosaicPolicy(HuggingFaceModel):
class ComposerMPTPolicyLM(HuggingFaceModel):

def __init__(
self,
@@ -75,7 +76,11 @@ def forward(self, batch: MutableMapping):
self.model.transformer, # type: ignore
)

ret_val = composer_online_rl_forward(batch, self.model)
ret_val = composer_online_rl_forward(
batch,
self.model,
OnPolicyEnum.PPO,
)

lbl = get_mb_load_balancing_loss(
self.config,
@@ -92,10 +97,10 @@ def eval_forward(self, batch: MutableMapping, outputs: MutableMapping):
)

def loss(self, outputs: MutableMapping, batch: MutableMapping):
return_dict, kl_loss = online_rl_loss(
return_dict = online_rl_loss(
outputs=outputs,
batch=batch,
loss_type='ppo',
loss_type=OnPolicyEnum.PPO,
value_clip_range=self.config.value_clip_range,
value_loss_weight=self.config.value_loss_weight,
policy_clip_ratio=self.config.policy_clip_ratio,
@@ -104,7 +109,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping):
kl_clip_range=self.config.kl_clip_range,
)

self.policy_kl.append(kl_loss)
self.policy_kl.append(return_dict['kl/policy_kl'])

return return_dict

@@ -125,7 +130,7 @@ def set_batch_stats(self, batch_stats: dict[str, Any]):
self.batch_stats = batch_stats # pyright: ignore


class ComposerHFPolicyModel(ComposerHFPolicy):
class ComposerHFPolicyLM(ComposerHFPolicy):

def __init__(
self,
@@ -154,7 +159,7 @@ def __init__(
self.compute_kl_loss = config_overrides.get('compute_kl_loss')
self.target_kl = config_overrides.get('target_kl')

self.loss_type = loss_type
self.loss_type = OnPolicyEnum(loss_type)

# Validating the input types
assert isinstance(self.compute_kl_loss, bool)
@@ -205,7 +210,7 @@ def eval_forward(self, batch: MutableMapping, outputs: MutableMapping):
)

def loss(self, outputs: MutableMapping, batch: MutableMapping):
return_dict, kl_loss = online_rl_loss(
return_dict = online_rl_loss(
outputs=outputs,
batch=batch,
loss_type=self.loss_type, # pyright: ignore
@@ -217,7 +222,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping):
kl_clip_range=self.config.kl_clip_range,
)

self.policy_kl.append(kl_loss)
self.policy_kl.append(return_dict['kl/policy_kl'])

return return_dict

@@ -238,7 +243,7 @@ def set_batch_stats(self, batch_stats: dict[str, Any]):
self.batch_stats = batch_stats


class ComposerHFCriticFreePolicyModel(ComposerHFCausalLM):
class ComposerHFCriticFreePolicyLM(ComposerHFCausalLM):
"""HF class wrapper for Critic Free Policy model."""
default_train_metrics: tuple = ()
default_eval_metrics: tuple = ()
@@ -271,7 +276,7 @@ def __init__(
"""
super().__init__(**kwargs)
self.policy_kl = []
self.loss_type = loss_type
self.loss_type = OnPolicyEnum(loss_type)
self.normalize_advantage = normalize_advantage
self.length_normalize_policy_loss = length_normalize_policy_loss
self.policy_clip_ratio = policy_clip_ratio
@@ -295,7 +300,7 @@ def eval_forward(self, batch: MutableMapping, outputs: MutableMapping):
)

def loss(self, outputs: MutableMapping, batch: MutableMapping):
return_dict, kl_loss = online_rl_loss(
return_dict = online_rl_loss(
outputs=outputs,
batch=batch,
loss_type=self.loss_type,
@@ -307,7 +312,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping):
kl_clip_range=self.kl_clip_range,
)

self.policy_kl.append(kl_loss)
self.policy_kl.append(return_dict['kl/policy_kl'])
return return_dict

def determine_early_stop(self):
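
The practical effect of moving loss types to an enum: configs keep passing plain strings, and the wrappers coerce them via OnPolicyEnum(loss_type). A sketch assuming the enum values mirror the old strings, as the comparisons above suggest:

from compose_rl.algorithms.online.model_methods import OnPolicyEnum

loss_type = OnPolicyEnum('ppo')   # what ComposerHFPolicyLM now does internally
assert loss_type is OnPolicyEnum.PPO
assert OnPolicyEnum('grpo') is OnPolicyEnum.GRPO
assert loss_type.value == 'ppo'   # .value recovers the original string when needed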

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -19,8 +19,8 @@
)
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from compose_rl.ppo.hf_utils import AutoModelForCausalLMAsPolicy
from compose_rl.ppo.policy_configuration import HFPolicyConfig
from compose_rl.algorithms.online.hf_utils import AutoModelForCausalLMAsPolicy
from compose_rl.algorithms.online.policy_configuration import HFPolicyConfig

if TYPE_CHECKING:
from peft import PeftModel
@@ -29,7 +29,7 @@


class ComposerHFPolicy(BaseHuggingFaceModel):
"""Configures a :class:`.ComposerMosaicPolicy` as a Policy for PPO.
"""Configures a :class:`.ComposerMosaicPolicy` as a Policy for online RL.
See base class for argument documentation.
"""
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

"""MPT definition of a PPO Policy."""
"""MPT definition of an online RL Policy."""

import logging
from typing import Any, Optional
@@ -11,11 +11,11 @@
from llmfoundry.models import MPTForCausalLM
from transformers import PreTrainedModel

from compose_rl.ppo.modeling_utils import (
from compose_rl.algorithms.online.model_methods import (
CausalLMOutputWithPastAndValues,
prepare_critic_values_for_training,
)
from compose_rl.ppo.policy_configuration import MPTPolicyConfig
from compose_rl.algorithms.online.policy_configuration import MPTPolicyConfig

log = logging.getLogger(__name__)

File renamed without changes.
Original file line number Diff line number Diff line change
@@ -20,16 +20,16 @@
from omegaconf import DictConfig
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from compose_rl.interfaces.base_kl_controller import BaseKLController
from compose_rl.registry import rewards as rewards_registry
from compose_rl.registry_builders import build_reward
from compose_rl.reward_learning import (
from compose_rl.algorithms.reward_modeling import (
BadGenerationEndReward,
BaseReward,
InferenceRewardModel,
Reward,
RewardModel,
)
from compose_rl.interfaces.base_kl_controller import BaseKLController
from compose_rl.registry import rewards as rewards_registry
from compose_rl.registry_builders import build_reward
from compose_rl.utils import (
approx_kl,
batch_process_fine_granularities,
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

from compose_rl.reward_learning.base_reward import (
from compose_rl.algorithms.reward_modeling.base_reward import (
BaseReward,
Reward,
RewardModel,
)
from compose_rl.reward_learning.functional import (
from compose_rl.algorithms.reward_modeling.functional import (
BadGenerationEndReward,
GSM8KFormatVeriferReward,
GSM8KVeriferReward,
@@ -16,12 +16,13 @@
OutputLengthReward,
ShortResponseReward,
)
from compose_rl.reward_learning.hf_utils import (
from compose_rl.algorithms.reward_modeling.hf_utils import (
AutoModelForCausalLMWithRM,
RewardModelConfig,
)
from compose_rl.reward_learning.inference_model import InferenceRewardModel
from compose_rl.reward_learning.model import (
from compose_rl.algorithms.reward_modeling.inference_model import \
InferenceRewardModel
from compose_rl.algorithms.reward_modeling.model import (
ComposerHFClassifierRewardModel,
ComposerHFPairwiseRewardModel,
ComposerMPTPairwiseRewardModel,
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -12,13 +12,13 @@

log = logging.getLogger(__name__)

from compose_rl.data.rlvr_utils import (
from compose_rl.algorithms.reward_modeling.base_reward import Reward, Tokenizer
from compose_rl.utils.rlvr_utils import (
is_equiv,
last_boxed_only_string,
normalize_final_answer,
remove_boxed,
)
from compose_rl.reward_learning.base_reward import Reward, Tokenizer


class IncreasingNumbersReward(Reward):
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -12,7 +12,10 @@
from composer.utils import dist
from mcli import config as mcli_config

from compose_rl.reward_learning.base_reward import RewardModel, Tokenizer
from compose_rl.algorithms.reward_modeling.base_reward import (
RewardModel,
Tokenizer,
)
from compose_rl.utils import get_remote_name

log = logging.getLogger(__name__)
Original file line number Diff line number Diff line change
@@ -9,19 +9,24 @@
import torch
from llmfoundry.models import ComposerMPTCausalLM

from compose_rl.reward_learning.base_reward import RewardModel, Tokenizer
from compose_rl.reward_learning.hf_utils import SequenceClassifierOutput
from compose_rl.reward_learning.model_methods import (
from compose_rl.algorithms.reward_modeling.base_reward import (
RewardModel,
Tokenizer,
)
from compose_rl.algorithms.reward_modeling.hf_utils import \
SequenceClassifierOutput
from compose_rl.algorithms.reward_modeling.model_methods import (
ClassifierRewardEnum,
PairwiseRewardEnum,
classifier_forward,
classifier_loss,
pairwise_forward,
pairwise_loss,
)
from compose_rl.reward_learning.modeling_hf import \
from compose_rl.algorithms.reward_modeling.modeling_hf import \
ComposerHFSequenceClassification
from compose_rl.reward_learning.modeling_mpt import MPTForSequenceClassification
from compose_rl.algorithms.reward_modeling.modeling_mpt import \
MPTForSequenceClassification

log = logging.getLogger(__name__)

Original file line number Diff line number Diff line change
@@ -15,7 +15,8 @@
PreTrainedTokenizerFast,
)

from compose_rl.reward_learning.hf_utils import SequenceClassifierOutput
from compose_rl.algorithms.reward_modeling.hf_utils import \
SequenceClassifierOutput
from compose_rl.utils import (
clear_mb_load_balancing_loss,
extract_packed_chosen_rejected,
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@
)
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from compose_rl.reward_learning.hf_utils import (
from compose_rl.algorithms.reward_modeling.hf_utils import (
AutoModelForCausalLMWithRM,
RewardModelConfig,
)
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
from llmfoundry.models.mpt.configuration_mpt import MPTConfig
from llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM

from compose_rl.reward_learning.hf_utils import (
from compose_rl.algorithms.reward_modeling.hf_utils import (
SequenceClassifierOutput,
ValueHead,
)
24 changes: 6 additions & 18 deletions compose_rl/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

from compose_rl.data.buffer import (
DummyDataset,
MinibatchRolloutBuffer,
)
from compose_rl.data.dataloader import (
build_finegrained_preference_dataloader,
build_pairwise_preference_dataloader,
@@ -11,30 +15,14 @@
pairwise_preference_dataset_collate_fn,
)
from compose_rl.data.prompt_data import prompt_dataset_collate_fn
from compose_rl.data.rlvr_utils import (
extract_gsm8k_answer,
extract_math_answer,
is_equiv,
last_boxed_only_string,
normalize_final_answer,
prepare_gsm8k_prompt,
prepare_math_prompt,
remove_boxed,
)

__all__ = [
'build_pairwise_preference_dataloader',
'build_finegrained_preference_dataloader',
'build_prompt_dataloader',
'extract_gsm8k_answer',
'DummyDataset',
'finegrained_preference_dataset_collate_fn',
'MinibatchRolloutBuffer',
'pairwise_preference_dataset_collate_fn',
'prepare_gsm8k_prompt',
'prompt_dataset_collate_fn',
'extract_math_answer',
'prepare_math_prompt',
'last_boxed_only_string',
'remove_boxed',
'is_equiv',
'normalize_final_answer',
]
File renamed without changes.
9 changes: 0 additions & 9 deletions compose_rl/dpo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,2 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

from compose_rl.dpo.callback import DPOCallback
from compose_rl.dpo.model import ComposerDPOLM, ComposerHFDPOLM

__all__ = [
'ComposerDPOLM',
'ComposerHFDPOLM',
'DPOCallback',
]
314 changes: 4 additions & 310 deletions compose_rl/dpo/model_methods.py
Original file line number Diff line number Diff line change
@@ -1,313 +1,7 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

"""DPO Utils."""

from enum import Enum
from typing import Mapping, MutableMapping, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
PretrainedConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

from compose_rl.utils import (
clear_mb_load_balancing_loss,
extract_packed_chosen_rejected,
get_batch_logp,
get_mb_load_balancing_loss,
)


class DPOEnum(Enum):
DPO = 'dpo'
RPO = 'rpo'
RCDPO = 'rcdpo'
REBEL = 'rebel'
IPO = 'ipo'
KTO = 'kto'


def dpo_forward(
model: nn.Module,
tokenizer: Tokenizer,
batch: MutableMapping,
average_log_prob: bool = False,
policy_model_config: Optional[PretrainedConfig] = None,
use_attention_sequence_id: bool = False,
) -> dict[str, torch.Tensor]:
"""Forwards the model for dpo and get the chosen and rejected log probs.
Args:
model (nn.Module): Model we are forwarding.
tokenizer (Tokenizer): Tokenizer for the model.
batch (Dict[str, torch.LongTensor]): Batch over which we should forward the model.
Note: this batch has chosen and rejected concated along the sequence dimension.
average_log_prob (bool): Whether should we average the log probabilities.
policy_model_config: Policy model config.
use_attention_sequence_id (bool): Whether we should use the attention sequence id.
"""
if policy_model_config is not None and hasattr(model, 'transformer'):
clear_mb_load_balancing_loss(
policy_model_config,
model.transformer, # type: ignore
)

batch_size, concat_seq_len = batch['input_ids'].shape
pad_token_id = tokenizer.pad_token_id
if pad_token_id is None:
raise ValueError('Tokenizer must have a PAD token.')

# If we can use attention sequence ID, we use this logic branch.
# This is determined by a value set in `train_dpo.py`
if use_attention_sequence_id:
output_logits = model(
batch['input_ids'],
attention_mask=batch['attention_mask'],
sequence_id=batch['sequence_id'],
).logits

chosen_logits, rejected_logits = extract_packed_chosen_rejected(
output_logits,
batch['chosen_len'],
batch['rejected_len'],
concat_seq_len,
pad_token_id=pad_token_id, # type: ignore
)

else:
# If we can't use attn_seq_id then we need to unpack each batch and
# Pack along the batch dimension instead.

chosen_inputs, rejected_inputs = extract_packed_chosen_rejected(
batch['input_ids'],
batch['chosen_len'],
batch['rejected_len'],
concat_seq_len,
pad_token_id=pad_token_id, # type: ignore
)

chosen_attention_mask, rejected_attention_mask = extract_packed_chosen_rejected(
batch['attention_mask'],
batch['chosen_len'],
batch['rejected_len'],
concat_seq_len,
pad_token_id=0,
)

batch_cat_inputs = torch.cat([chosen_inputs, rejected_inputs], dim=0)
batch_attn_mask = torch.cat(
[
chosen_attention_mask,
rejected_attention_mask,
],
dim=0,
)

output_logits = model(
batch_cat_inputs,
attention_mask=batch_attn_mask,
).logits

# Extract out the chosen and rejected logits along the batch dimension
chosen_logits = output_logits[:batch_size]
rejected_logits = output_logits[batch_size:]

chosen_labels, rejected_labels = extract_packed_chosen_rejected(
batch['input_ids'],
batch['chosen_len'],
batch['rejected_len'],
concat_seq_len,
pad_token_id=0,
)

chosen_logps = get_batch_logp(
chosen_labels,
chosen_logits,
batch['prompt_len'],
batch['chosen_len'],
average_log_prob,
)

rejected_logps = get_batch_logp(
rejected_labels,
rejected_logits,
batch['prompt_len'],
batch['rejected_len'],
average_log_prob,
)

outputs: dict[str, torch.Tensor] = {
'policy_chosen_logp': chosen_logps,
'policy_rejected_logp': rejected_logps,
'chosen_len': batch['chosen_len'],
}

if 'chosen_reward' in batch:
outputs['chosen_reward'] = batch['chosen_reward']
outputs['rejected_reward'] = batch['rejected_reward']

if policy_model_config is not None and hasattr(model, 'transformer'):
lbl = get_mb_load_balancing_loss(
policy_model_config,
model.transformer, # type: ignore
)
if lbl is not None:
outputs['lbl'] = lbl

return outputs


def dpo_loss(
outputs: CausalLMOutputWithPast,
batch: Mapping,
loss_type: DPOEnum,
beta: float,
label_smoothing: float,
sft_alpha: float,
) -> dict[str, torch.Tensor]:
"""Computes DPO loss.
Given precomputed values, the batch, and dpo-oriented specific values, this will compute the dpo_loss.
Args:
outputs (CausalLMOutputWithPast): Outputs from forwarding the model over the batch.
batch (Mapping): Input batch of data.
loss_type (str): Loss type that we should compute (e.g. dpo, ipo, or kto),
beta (float): How much to regularize the policy model. We regularizethe policy less with
the reference model as beta -> 0.
label_smoothing: Represents conservativeness for the DPO loss. This assumes that
preferences as noisy (preferences are flipped with probability label_smoothing).
sft_alpha (float): Regularization weight for supervised finetuning loss (SFT) to
be added to DPO type loss.
"""
policy_chosen_logp = outputs['policy_chosen_logp']
policy_rejected_logp = outputs['policy_rejected_logp']
ref_chosen_logp = batch.get(
'ref_chosen',
torch.zeros_like(policy_chosen_logp),
)
ref_rejected_logp = batch.get(
'ref_rejected',
torch.zeros_like(policy_rejected_logp),
)

pi_logratios = policy_chosen_logp - policy_rejected_logp
ref_logratios = ref_chosen_logp - ref_rejected_logp

logits = pi_logratios - ref_logratios # Also known as h_{\pi_\theta}^{y_w,y_l}

losses = torch.zeros_like(logits)
if loss_type == DPOEnum.DPO:
losses = (
-F.logsigmoid(beta * logits) * (1 - label_smoothing) -
F.logsigmoid(-beta * logits) * label_smoothing
)
elif loss_type == DPOEnum.RCDPO:
# Adding reward-difference based label_smoothing = 1 - reward_bt_prob
chosen_reward = outputs['chosen_reward']
rejected_reward = outputs['rejected_reward']
reward_diff = chosen_reward - rejected_reward
reward_bt_prob = torch.sigmoid(reward_diff)
rcdpo_losses = -F.logsigmoid(
beta * logits,
) * reward_bt_prob - F.logsigmoid(
-beta * logits,
) * (1 - reward_bt_prob)
losses = rcdpo_losses
elif loss_type == DPOEnum.RPO:
# Reproducing the RPO loss from NVIDIA's paper: https://arxiv.org/pdf/2406.11704v1 page 13
# Code: https://github.com/NVIDIA/NeMo-Aligner/blob/c92a3bf9c2d6312581982a8d1db30591855394c5/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L261-L273
eta = 1 # NOTE: Hardcoding this to be 1 as per the paper's recommendation
chosen_reward = outputs['chosen_reward']
rejected_reward = outputs['rejected_reward']
reward_diff = chosen_reward - rejected_reward

logsigmoid_a = F.logsigmoid(beta * logits)
logsigmoid_b = F.logsigmoid(eta * reward_diff)
logsigmoid_not_a = F.logsigmoid(-beta * logits)
logsigmoid_not_b = F.logsigmoid(-eta * reward_diff)

losses = torch.exp(logsigmoid_b) * (
logsigmoid_b - logsigmoid_a
) + torch.exp(logsigmoid_not_b) * (logsigmoid_not_b - logsigmoid_not_a)
elif loss_type == DPOEnum.REBEL:
# Reproducing the REBEL loss from paper: https://arxiv.org/pdf/2404.16767 page 4
# Code: https://github.com/ZhaolinGao/REBEL/blob/e0a6a190108a45c70b4920b58a4ccac8a09ab22b/src/tldr/rebel.py#L761-L777
pi_logratios = policy_chosen_logp - policy_rejected_logp
ref_logratios = ref_chosen_logp - ref_rejected_logp

logits = pi_logratios - ref_logratios # Also known as h_{\pi_\theta}^{y_w,y_l}

chosen_reward = outputs['chosen_reward']
rejected_reward = outputs['rejected_reward']
reward_diff = chosen_reward - rejected_reward
losses = (beta * logits - reward_diff)**2
# beta represents 1/eta hparam from the paper
elif loss_type == DPOEnum.IPO:
losses = (logits - 1 / (2 * beta))**2
elif loss_type == DPOEnum.KTO:
chosen_KL = (policy_chosen_logp - ref_chosen_logp).mean().clamp(min=0)
rejected_KL = (policy_rejected_logp -
ref_rejected_logp).mean().clamp(min=0)

chosen_logratios = policy_chosen_logp - ref_chosen_logp
rejected_logratios = policy_rejected_logp - ref_rejected_logp
losses = torch.cat(
(
1 - F.sigmoid(beta * (chosen_logratios - rejected_KL)),
1 - F.sigmoid(beta * (chosen_KL - rejected_logratios)),
),
0,
)
else:
raise ValueError(f'Loss type: {loss_type} is not supported.')
if sft_alpha > 0:
sft_losses = -1 * sft_alpha * policy_chosen_logp
sft_losses_normalized = sft_losses / outputs['chosen_len']
losses_before_sft = losses.clone().detach()
losses += sft_losses_normalized

losses = losses.mean()

chosen_rewards = beta * (policy_chosen_logp - ref_chosen_logp).detach()
rejected_rewards = beta * (policy_rejected_logp -
ref_rejected_logp).detach()

# Logging KL margins for comparing different methods
chosen_KL = (policy_chosen_logp - ref_chosen_logp).detach()
rejected_KL = (policy_rejected_logp - ref_rejected_logp).detach()
margin_KL = (chosen_KL - rejected_KL).detach()
loss_dict = {
'chosen_rewards': chosen_rewards,
'rejected_rewards': rejected_rewards,
'margin': chosen_rewards - rejected_rewards,
'chosen_KL': chosen_KL,
'rejected_KL': rejected_KL,
'margin_KL': margin_KL,
'accuracy': (chosen_rewards > rejected_rewards).to(torch.float32),
}
if loss_type in [DPOEnum.RPO, DPOEnum.RCDPO, DPOEnum.REBEL]:
# reward_diff is always defined if loss_type is RPO, RCDPO, or REBEL
loss_dict['reward_diff'] = reward_diff.detach() # type: ignore
if sft_alpha > 0:
# sft_losses_normalized is always defined if sft_alpha>0
snl = sft_losses_normalized.detach() # type: ignore
loss_dict['sft_regularization_loss'] = snl
# losses_before_sft is always defined if sft_alpha>0
loss_dict[f'{loss_type.value}_loss'] = losses_before_sft # type: ignore

if 'lbl' in outputs:
losses += outputs['lbl']
loss_dict['lbl'] = outputs['lbl']

loss_dict['total'] = losses

return loss_dict
from compose_rl.algorithms.offline.model_methods import \
PairwiseOfflineEnum as DPOEnum # pyright: ignore
from compose_rl.algorithms.offline.model_methods import \
pairwise_offline_loss as dpo_loss # pyright: ignore
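
Since the old module now just re-exports the new symbols, legacy imports resolve to the same objects; a quick check (assumes the package is installed with both paths present):

from compose_rl.algorithms.offline.model_methods import (
    PairwiseOfflineEnum,
    pairwise_offline_loss,
)
from compose_rl.dpo.model_methods import DPOEnum, dpo_loss

assert DPOEnum is PairwiseOfflineEnum
assert dpo_loss is pairwise_offline_loss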
34 changes: 2 additions & 32 deletions compose_rl/ppo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,5 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

from compose_rl.ppo.callback import PPOCallback
from compose_rl.ppo.kl_controller import (
AdaptiveKLController,
BallKLController,
FixedKLController,
KLPIDController,
)
from compose_rl.ppo.load_planner import PPOModelLoadPlanner
from compose_rl.ppo.model import (
ComposerHFCriticFreePolicyModel,
ComposerHFPolicyModel,
ComposerMosaicPolicy,
)
from compose_rl.ppo.modeling_utils import CausalLMOutputWithPastAndValues
from compose_rl.ppo.policy_configuration import HFPolicyConfig, MPTPolicyConfig
from compose_rl.registry import kl_controllers

kl_controllers.register('adaptive', func=AdaptiveKLController)
kl_controllers.register('fixed', func=FixedKLController)
kl_controllers.register('pid', func=KLPIDController)
kl_controllers.register('ball', func=BallKLController)

__all__ = [
'PPOCallback',
'ComposerMosaicPolicy',
'ComposerHFPolicyModel',
'ComposerHFCriticFreePolicyModel',
'HFPolicyConfig',
'MPTPolicyConfig',
'CausalLMOutputWithPastAndValues',
'PPOModelLoadPlanner',
]
from compose_rl.algorithms.online.callback import \
OnPolicyCallback as PPOCallback # pyright: ignore
4 changes: 2 additions & 2 deletions compose_rl/registry_builders.py
Original file line number Diff line number Diff line change
@@ -12,8 +12,8 @@
)

from compose_rl import registry
from compose_rl.ppo.kl_controller import BaseKLController
from compose_rl.reward_learning import BaseReward
from compose_rl.algorithms.online.kl_controller import BaseKLController
from compose_rl.algorithms.reward_modeling import BaseReward

__all__ = ['build_kl_controller', 'build_reward']

26 changes: 18 additions & 8 deletions compose_rl/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

from compose_rl.utils.rlvr_utils import (
extract_gsm8k_answer,
extract_math_answer,
is_equiv,
last_boxed_only_string,
normalize_final_answer,
prepare_gsm8k_prompt,
prepare_math_prompt,
remove_boxed,
)
from compose_rl.utils.utils import (
add_right_padding,
approx_kl,
@@ -39,11 +49,6 @@
split_text_to_subsentences,
switch_left_to_right_padding,
)
from compose_rl.utils.vllm_utils import (
broadcast_to_vllm,
create_vllm_engines,
init_process_group,
)

__all__ = [
'get_mb_load_balancing_loss',
@@ -80,9 +85,14 @@
'make_padded_tensor',
'get_batch_logp',
'make_action_mask',
'create_vllm_engines',
'init_process_group',
'broadcast_to_vllm',
'flatten',
'sample_wise_masked_mean',
'extract_gsm8k_answer',
'extract_math_answer',
'is_equiv',
'last_boxed_only_string',
'normalize_final_answer',
'prepare_gsm8k_prompt',
'prepare_math_prompt',
'remove_boxed',
]
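
The RLVR helpers formerly re-exported from compose_rl.data are now part of compose_rl.utils; downstream imports look like this (names taken from the __all__ above):

from compose_rl.utils import (
    extract_gsm8k_answer,
    extract_math_answer,
    prepare_gsm8k_prompt,
    prepare_math_prompt,
)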
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
log = logging.getLogger(__name__)


class PPOModelLoadPlanner(DefaultLoadPlanner):
class ActorCriticModelLoadPlanner(DefaultLoadPlanner):

def create_local_plan(self):
self.metadata_has_critic_key = False # type: ignore
File renamed without changes.
32 changes: 21 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -40,30 +40,40 @@ cpu_released = [

# Registry entry points
[project.entry-points."llmfoundry_models"]
mpt_dpo_lm = "compose_rl.dpo:ComposerDPOLM"
hf_dpo_lm = "compose_rl.dpo:ComposerHFDPOLM"
mpt_pairwise_rm = "compose_rl.reward_learning:ComposerMPTPairwiseRewardModel"
hf_pairwise_rm = "compose_rl.reward_learning:ComposerHFPairwiseRewardModel"
hf_classifier_rm = "compose_rl.reward_learning:ComposerHFClassifierRewardModel"
mpt_ppo_lm = "compose_rl.ppo:ComposerMosaicPolicy"
hf_ppo_lm = "compose_rl.ppo:ComposerHFPolicyModel"
hf_critic_free_lm = "compose_rl.ppo:ComposerHFCriticFreePolicyModel"
mpt_pairwise_rm = "compose_rl.algorithms.reward_modeling:ComposerMPTPairwiseRewardModel"
hf_pairwise_rm = "compose_rl.algorithms.reward_modeling:ComposerHFPairwiseRewardModel"
hf_classifier_rm = "compose_rl.algorithms.reward_modeling:ComposerHFClassifierRewardModel"
mpt_pairwise_offline_lm = "compose_rl.algorithms.offline:ComposerMPTPairwiseOfflinePolicyLM"
hf_pairwise_offline_lm = "compose_rl.algorithms.offline:ComposerHFPairwiseOfflinePolicyLM"
mpt_actor_critic_lm = "compose_rl.algorithms.online:ComposerMPTPolicyLM"
hf_actor_critic_lm = "compose_rl.algorithms.online:ComposerHFPolicyLM"
hf_critic_free_lm = "compose_rl.algorithms.online:ComposerHFCriticFreePolicyLM"
# Backwards Compatibility
mpt_dpo_lm = "compose_rl.algorithms.offline:ComposerMPTPairwiseOfflinePolicyLM"
hf_dpo_lm = "compose_rl.algorithms.offline:ComposerHFPairwiseOfflinePolicyLM"
mpt_ppo_lm = "compose_rl.algorithms.online:ComposerMPTPolicyLM"
hf_ppo_lm = "compose_rl.algorithms.online:ComposerHFPolicyLM"

[project.entry-points."llmfoundry_dataloaders"]
pairwise_preference = "compose_rl.data:build_pairwise_preference_dataloader"
finegrained_preference = "compose_rl.data:build_finegrained_preference_dataloader"
prompt = "compose_rl.data:build_prompt_dataloader"

[project.entry-points."llmfoundry_callbacks_with_config"]
dpo = "compose_rl.dpo:DPOCallback"
ppo = "compose_rl.ppo:PPOCallback"
offline_rl = "compose_rl.algorithms.offline:ReferencePolicyCallback"
on_policy_rl = "compose_rl.algorithms.online:OnPolicyCallback"
# Backwards Compatibility
dpo = "compose_rl.algorithms.offline:ReferencePolicyCallback"
ppo = "compose_rl.algorithms.online:OnPolicyCallback"

[project.entry-points."llmfoundry_metrics"]
pairwise_rm_accuracy = "compose_rl.metrics.reward_model_metrics:PairwiseRewardClassificationAccuracy"
classifier_accuracy = "compose_rl.metrics.reward_model_metrics:BinaryRewardClassificationAccuracy"

[project.entry-points."llmfoundry_load_planners"]
ppo_load_planner = "compose_rl.ppo.load_planner:PPOModelLoadPlanner"
actor_critic_load_planner = "compose_rl.utils.load_planner:ActorCriticModelLoadPlanner"
# Backwards Compatibility
ppo_load_planner = "compose_rl.utils.load_planner:ActorCriticModelLoadPlanner"

# iSort
[tool.isort]
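
Both the new and the legacy entry points resolve to the same classes, so existing llm-foundry configs keep working. A sketch of how this can be verified, assuming compose_rl is installed (so the entry points are registered) and Python >= 3.10 for the group= selector:

from importlib.metadata import entry_points

models = entry_points(group='llmfoundry_models')
assert models['hf_ppo_lm'].load() is models['hf_actor_critic_lm'].load()

callbacks = entry_points(group='llmfoundry_callbacks_with_config')
assert callbacks['ppo'].load() is callbacks['on_policy_rl'].load()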
2 changes: 1 addition & 1 deletion scripts/data/unified_tokenize_dataset.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from compose_rl.data.rlvr_utils import (
from compose_rl.utils.rlvr_utils import (
extract_gsm8k_answer,
extract_math_answer,
prepare_gsm8k_prompt,
2 changes: 1 addition & 1 deletion tests/functional_rewards/test_bad_generation.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
import torch
from transformers import AutoTokenizer

from compose_rl.reward_learning import BadGenerationEndReward
from compose_rl.algorithms.reward_modeling import BadGenerationEndReward


@pytest.fixture
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
import torch
from transformers import AutoTokenizer

from compose_rl.reward_learning import GSM8KFormatVeriferReward
from compose_rl.algorithms.reward_modeling import GSM8KFormatVeriferReward


@pytest.fixture
2 changes: 1 addition & 1 deletion tests/functional_rewards/test_gsm8k_verifier_reward.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
import torch
from transformers import AutoTokenizer

from compose_rl.reward_learning import GSM8KVeriferReward
from compose_rl.algorithms.reward_modeling import GSM8KVeriferReward


@pytest.fixture
2 changes: 1 addition & 1 deletion tests/functional_rewards/test_increasing_numbers.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
import torch
from transformers import AutoTokenizer

from compose_rl.reward_learning import IncreasingNumbersReward
from compose_rl.algorithms.reward_modeling import IncreasingNumbersReward


@pytest.fixture
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
import torch
from transformers import AutoTokenizer

from compose_rl.reward_learning import MATHFormatVerifierReward
from compose_rl.algorithms.reward_modeling import MATHFormatVerifierReward


@pytest.fixture
2 changes: 1 addition & 1 deletion tests/functional_rewards/test_math_verifier_reward.py
@@ -9,7 +9,7 @@
import torch
from transformers import AutoTokenizer

from compose_rl.reward_learning import MATHVerifierReward
from compose_rl.algorithms.reward_modeling import MATHVerifierReward


@pytest.fixture
2 changes: 1 addition & 1 deletion tests/functional_rewards/test_output_length_reward.py
@@ -7,7 +7,7 @@
import torch
from transformers import AutoTokenizer

from compose_rl.reward_learning import OutputLengthReward
from compose_rl.algorithms.reward_modeling import OutputLengthReward


@pytest.fixture
2 changes: 1 addition & 1 deletion tests/functional_rewards/test_short_response_reward.py
@@ -9,7 +9,7 @@
import torch
from transformers import AutoTokenizer

from compose_rl.reward_learning import ShortResponseReward
from compose_rl.algorithms.reward_modeling import ShortResponseReward


@pytest.fixture
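Taken together, the functional-reward test changes above amount to a single import-path move; the reward classes themselves are untouched. For reference, the new import surface as it appears in these diffs (names copied verbatim):

# Previously: from compose_rl.reward_learning import ...
from compose_rl.algorithms.reward_modeling import (
    BadGenerationEndReward,
    GSM8KFormatVeriferReward,
    GSM8KVeriferReward,
    IncreasingNumbersReward,
    MATHFormatVerifierReward,
    MATHVerifierReward,
    OutputLengthReward,
    ShortResponseReward,
)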
36 changes: 19 additions & 17 deletions tests/test_dpo.py → tests/test_offline.py
@@ -15,13 +15,15 @@
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer

from compose_rl.algorithms.offline import ComposerMPTPairwiseOfflinePolicyLM
from compose_rl.algorithms.offline.callback import ReferencePolicyCallback
from compose_rl.data import pairwise_preference_dataset_collate_fn
from compose_rl.dpo import ComposerDPOLM
from compose_rl.dpo.callback import DPOCallback
from tests.common import PairwisePreference, world_size


def test_dpo_callback_forward(tiny_gpt2_tokenizer: PreTrainedTokenizer):
def test_reference_policy_callback_forward(
tiny_gpt2_tokenizer: PreTrainedTokenizer,
):
# Build DataLoader
max_seq_len = 10
dataset = PairwisePreference(max_seq_len=max_seq_len)
@@ -45,14 +47,14 @@ def test_dpo_callback_forward(tiny_gpt2_tokenizer: PreTrainedTokenizer):
'loss_fn': 'torch_crossentropy',
'tokenizer': tiny_gpt2_tokenizer,
}
model = ComposerDPOLM(**model_config)
model_config['name'] = 'mpt_dpo_lm'
model = ComposerMPTPairwiseOfflinePolicyLM(**model_config)
model_config['name'] = 'mpt_pairwise_offline_lm'
train_config = {
'model': model_config,
'fsdp_config': {},
'seed': 17,
}
callback = DPOCallback(train_config=train_config)
callback = ReferencePolicyCallback(train_config=train_config)
Trainer(
model=model,
callbacks=callback,
@@ -92,7 +94,7 @@ def test_model_forward(tiny_gpt2_tokenizer: PreTrainedTokenizer):
'loss_fn': 'torch_crossentropy',
'tokenizer': tiny_gpt2_tokenizer,
}
model = ComposerDPOLM(**model_config)
model = ComposerMPTPairwiseOfflinePolicyLM(**model_config)
for sample in dataloader:
output = model(sample)
assert output is not None
@@ -125,8 +127,8 @@ def test_train(
},
'tokenizer': tiny_gpt2_tokenizer,
}
model = ComposerDPOLM(**model_config)
model_config['name'] = 'mpt_dpo_lm'
model = ComposerMPTPairwiseOfflinePolicyLM(**model_config)
model_config['name'] = 'mpt_pairwise_offline_lm'
fsdp_config = {}
train_config = {
'model': model_config,
@@ -136,7 +138,7 @@ def test_train(
trainer = Trainer(
model=model,
train_dataloader=dataloader,
callbacks=DPOCallback(train_config=train_config),
callbacks=ReferencePolicyCallback(train_config=train_config),
parallelism_config={'fsdp': fsdp_config},
max_duration='1ep',
)
@@ -176,9 +178,9 @@ def test_checkpoint_reloading(
}

# Making a dummy reference model so we can make sure the KL is 0
tmp_model = ComposerDPOLM(**model_config)
tmp_model = ComposerMPTPairwiseOfflinePolicyLM(**model_config)
tmp_optimizer = DecoupledAdamW(tmp_model.parameters(), lr=1e-6)
model_config['name'] = 'mpt_dpo_lm'
model_config['name'] = 'mpt_pairwise_offline_lm'
fsdp_config = {}
parallelism_config = {'fsdp': fsdp_config}

@@ -201,9 +203,9 @@ def test_checkpoint_reloading(
# After making the reference model, we can proceed with the DPO training
init_checkpoint_path = os.path.join(init_checkpoint_dir, 'latest-rank0.pt')

model = ComposerDPOLM(**model_config)
model = ComposerMPTPairwiseOfflinePolicyLM(**model_config)
# Add more model_config specific to DPO
model_config['name'] = 'mpt_dpo_lm'
model_config['name'] = 'mpt_pairwise_offline_lm'
model_config['loss_type'] = 'dpo'
model_config['beta'] = 0.1
model_config['sft_alpha'] = 0.2
@@ -222,7 +224,7 @@ def test_checkpoint_reloading(
model=model,
train_dataloader=dataloader,
loggers=in_memory_logger,
callbacks=DPOCallback(train_config=train_config),
callbacks=ReferencePolicyCallback(train_config=train_config),
parallelism_config={'fsdp': fsdp_config},
max_duration='8ba',
autoresume=True,
@@ -239,12 +241,12 @@ def test_checkpoint_reloading(

# Restart the training from the intermediate checkpoint
in_memory_logger = InMemoryLogger()
model = ComposerDPOLM(**model_config)
model = ComposerMPTPairwiseOfflinePolicyLM(**model_config)
trainer2 = Trainer(
model=model,
train_dataloader=dataloader,
loggers=in_memory_logger,
callbacks=DPOCallback(train_config=train_config),
callbacks=ReferencePolicyCallback(train_config=train_config),
parallelism_config={'fsdp': fsdp_config},
max_duration='8ba',
save_overwrite=True,
37 changes: 19 additions & 18 deletions tests/test_ppo.py → tests/test_online.py
@@ -17,13 +17,14 @@
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.models.gpt2 import GPT2LMHeadModel

from compose_rl.data import prompt_dataset_collate_fn
from compose_rl.ppo import (
ComposerHFPolicyModel,
ComposerMosaicPolicy,
PPOCallback,
from compose_rl.algorithms.online import (
ComposerHFPolicyLM,
ComposerMPTPolicyLM,
OnPolicyCallback,
)
from compose_rl.ppo.modeling_hf import ComposerHFPolicy
from compose_rl.algorithms.online.model_methods import OnPolicyEnum
from compose_rl.algorithms.online.modeling_hf import ComposerHFPolicy
from compose_rl.data import prompt_dataset_collate_fn
from tests.common import PromptDataset, VerifiablePromptDataset, world_size


@@ -41,13 +42,13 @@ def test_hf_ppo_model_construction(
'pretrained_model_name_or_path': local_save_path,
'pretrained': False,
'attn_implementation': 'sdpa',
'loss_type': 'hi',
'loss_type': 'ppo',
}
model = ComposerHFPolicyModel(**model_config)
assert isinstance(model, ComposerHFPolicyModel)
model = ComposerHFPolicyLM(**model_config)
assert isinstance(model, ComposerHFPolicyLM)
assert isinstance(model.model.lm_backbone, GPT2LMHeadModel)

assert model.loss_type == 'hi'
assert model.loss_type == OnPolicyEnum.PPO
assert model.model.lm_backbone.config._attn_implementation == 'sdpa'
assert model.shift_labels is True

@@ -104,15 +105,15 @@ def test_model_forward(
},
'tokenizer': tiny_gpt2_tokenizer,
}
model = ComposerMosaicPolicy(**model_config)
model = ComposerMPTPolicyLM(**model_config)
elif model_type == 'hf':
model_name = 'gpt2'
model_config = {
'tokenizer': tiny_gpt2_tokenizer,
'pretrained_model_name_or_path': model_name,
'pretrained': True,
}
model = ComposerHFPolicyModel(**model_config)
model = ComposerHFPolicyLM(**model_config)
else:
raise ValueError(f'Unknown model type: {model_type}')

@@ -217,9 +218,9 @@ def test_ppo_train(
tmp_ref_path = os.path.join(tmp_ref_path, 'latest-rank0.pt')

if model_type == 'mpt':
model = ComposerMosaicPolicy(**model_config)
model = ComposerMPTPolicyLM(**model_config)
elif model_type == 'hf':
model = ComposerHFPolicyModel(**model_config)
model = ComposerHFPolicyLM(**model_config)

optimizer = DecoupledAdamW(model.parameters(), lr=1e-8)

@@ -278,7 +279,7 @@ def test_ppo_train(
trainer = Trainer(
model=model,
optimizers=optimizer,
callbacks=PPOCallback(train_config=copy.deepcopy(train_config)),
callbacks=OnPolicyCallback(train_config=copy.deepcopy(train_config)),
train_dataloader=dataloader,
precision=precision,
parallelism_config={'fsdp': fsdp_config},
@@ -301,15 +302,15 @@ def test_ppo_train(

# Continue training for the remaining iterations
if model_type == 'mpt':
model = ComposerMosaicPolicy(**model_config)
model = ComposerMPTPolicyLM(**model_config)
elif model_type == 'hf':
model = ComposerHFPolicyModel(**model_config)
model = ComposerHFPolicyLM(**model_config)

optimizer = DecoupledAdamW(model.parameters(), lr=1e-8)
trainer_2 = Trainer(
model=model,
optimizers=optimizer,
callbacks=PPOCallback(train_config=copy.deepcopy(train_config)),
callbacks=OnPolicyCallback(train_config=copy.deepcopy(train_config)),
train_dataloader=dataloader,
precision=precision,
parallelism_config={'fsdp': fsdp_config},
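The construction test above now passes loss_type as the plain string 'ppo' yet asserts equality with OnPolicyEnum.PPO. A string-backed enum is one way both can hold; the sketch below illustrates the idea only and is not the actual definition in compose_rl.algorithms.online.model_methods, and the GRPO member is an assumption taken from the local_grpo.yaml changes further down:

from enum import Enum

class OnPolicyEnum(str, Enum):
    # The str mixin lets members be built from, and compare equal to, config strings.
    PPO = "ppo"
    GRPO = "grpo"  # assumed member; only PPO is visible in this diff

assert OnPolicyEnum("ppo") is OnPolicyEnum.PPO  # what a model __init__ could do with a config value
assert OnPolicyEnum.PPO == "ppo"                # plain string comparisons keep working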
4 changes: 2 additions & 2 deletions tests/test_registry.py
@@ -7,7 +7,7 @@
from llmfoundry.utils import registry_utils

from compose_rl import registry
from compose_rl.reward_learning.functional import OutputLengthReward
from compose_rl.algorithms.reward_modeling.functional import OutputLengthReward


def test_expected_registries_exist():
@@ -26,7 +26,7 @@ def test_expected_registries_exist():
def test_registry_init_code(tmp_path: pathlib.Path):
register_code = """
from compose_rl.registry import rewards
from compose_rl.reward_learning.functional import OutputLengthReward
from compose_rl.algorithms.reward_modeling.functional import OutputLengthReward
@rewards.register('test_reward')
class TestReward(OutputLengthReward):
19 changes: 12 additions & 7 deletions tests/test_reward_manager_timeout.py
@@ -25,8 +25,11 @@
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

from compose_rl.ppo.reward_manager import RewardManager, RewardOutput
from compose_rl.reward_learning import InferenceRewardModel
from compose_rl.algorithms.online.reward_manager import (
RewardManager,
RewardOutput,
)
from compose_rl.algorithms.reward_modeling import InferenceRewardModel


class MockAsyncResult(AsyncResult):
@@ -89,7 +92,9 @@ def mock_reward_manager(

# Create RewardManager with minimal initialization
with patch.object(RewardManager, 'initialize_composer_model') as mock_init:
with patch('compose_rl.ppo.reward_manager.spacy.load') as mock_spacy:
with patch(
'compose_rl.algorithms.online.reward_manager.spacy.load',
) as mock_spacy:
mock_ref_model = Mock()
mock_init.return_value = mock_ref_model

@@ -99,10 +104,10 @@ def mock_reward_manager(

# Patch the reward registry and build_reward
with patch(
'compose_rl.ppo.reward_manager.rewards_registry',
'compose_rl.algorithms.online.reward_manager.rewards_registry',
) as mock_registry:
with patch(
'compose_rl.ppo.reward_manager.build_reward',
'compose_rl.algorithms.online.reward_manager.build_reward',
) as mock_build:
mock_registry.get.return_value = InferenceRewardModel
mock_reward_model = MockRewardModel()
@@ -148,7 +153,7 @@ def test_async_timeout_creates_zero_reward(
mock_kl_ctl.value = 0.1

# Test resolve_outputs with timeout
with patch('compose_rl.ppo.reward_manager.log') as mock_log:
with patch('compose_rl.algorithms.online.reward_manager.log') as mock_log:
outputs = mock_reward_manager.resolve_outputs(
ref_output=ref_output,
reward_output=reward_output,
@@ -250,7 +255,7 @@ def test_mixed_timeout_and_success(mock_reward_manager: RewardManager) -> None:
mock_kl_ctl.value = 0.1

# Test resolve_outputs
with patch('compose_rl.ppo.reward_manager.log'):
with patch('compose_rl.algorithms.online.reward_manager.log'):
outputs = mock_reward_manager.resolve_outputs(
ref_output=ref_output,
reward_output=reward_output,
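The only substantive change in this test is where mock.patch points: patch targets name the module in which an object is looked up, so every compose_rl.ppo.reward_manager.* string moves with the code to compose_rl.algorithms.online.reward_manager.*. A minimal illustration, assuming compose_rl is importable:

from unittest.mock import patch

# Patch the name where it is used (the relocated reward_manager module),
# not the old compose_rl.ppo path it lived at before this PR.
with patch("compose_rl.algorithms.online.reward_manager.log") as mock_log:
    ...  # call RewardManager.resolve_outputs here, as the test above does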
Original file line number Diff line number Diff line change
@@ -24,11 +24,12 @@
from transformers import AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaAttention

from compose_rl.algorithms.reward_modeling.hf_utils import \
AutoModelForCausalLMWithRM
from compose_rl.data import (
finegrained_preference_dataset_collate_fn,
pairwise_preference_dataset_collate_fn,
)
from compose_rl.reward_learning.hf_utils import AutoModelForCausalLMWithRM
from tests.common import FineGrainedPreference, PairwisePreference, world_size


4 changes: 2 additions & 2 deletions yamls/local_dpo.yaml
@@ -2,7 +2,7 @@ run_name: local-llama8b-dpo
seed: 17
model:
beta: 0.01
name: hf_dpo_lm
name: hf_pairwise_offline_lm
loss_type: dpo
pretrained: true
use_auth_token: true
@@ -14,7 +14,7 @@ loggers:
experiment_name: brandon_dpo_test

callbacks:
dpo: {}
offline_rl: {}
lr_monitor: {}
speed_monitor:
window_size: 10
2 changes: 1 addition & 1 deletion yamls/local_grpo.yaml
@@ -181,7 +181,7 @@ callbacks:
# save_folder: # TODO: add save path for the hf checkpoint
# save_interval: 10iter

ppo: {}
on_policy_rl: {}

save_folder: /tmp/critic_free_models # TODO: fill in with an appropriate value
# only_composer_checkpoint: true
4 changes: 2 additions & 2 deletions yamls/local_ppo.yaml
@@ -71,7 +71,7 @@ variables:

model:
<<: *base_model
name: hf_ppo_lm
name: on_policy_rl
loss_type: ppo
config_overrides:
critic_dropout: 0.0
@@ -179,7 +179,7 @@ callbacks:
# save_folder: # TODO: add save path for the hf checkpoint
# save_interval: 10iter

ppo: {}
on_policy_rl: {}

save_folder: /tmp/ppo_models # TODO: fill in with an appropriate value
# only_composer_checkpoint: true
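The YAML edits above are pure renames, and the backwards-compatibility entry points mean old configs should still resolve; for anyone updating configs anyway, a hypothetical helper (not part of this PR) that applies only the renames visible in this diff might look like:

# Old config keys -> new names, as seen in the yamls and tests above.
OLD_TO_NEW = {
    "hf_dpo_lm": "hf_pairwise_offline_lm",    # model name (yamls/local_dpo.yaml)
    "mpt_dpo_lm": "mpt_pairwise_offline_lm",  # model name (tests/test_offline.py)
    "dpo": "offline_rl",                      # callback key
    "ppo": "on_policy_rl",                    # callback key
}

def migrate(cfg: dict) -> dict:
    """Rename the model entry and callback keys covered by OLD_TO_NEW (assumes a loaded YAML dict)."""
    model = cfg.get("model", {})
    if model.get("name") in OLD_TO_NEW:
        model["name"] = OLD_TO_NEW[model["name"]]
    if "callbacks" in cfg:
        cfg["callbacks"] = {
            OLD_TO_NEW.get(name, name): value
            for name, value in cfg["callbacks"].items()
        }
    return cfg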