Add configs and adapt exporter for RSL-RL distillation #2182

Merged
29 commits merged on Apr 10, 2025
Changes from 22 commits

Commits (29)
0a0450d
sets device to local rank
Mayankm96 Mar 21, 2025
c106c32
adds rsl-rl multi-gpu to docs
Mayankm96 Mar 21, 2025
d0180dc
fixes for rsl-rl library
Mayankm96 Mar 28, 2025
b5a634f
hard fixes to version of rsl-rl
Mayankm96 Mar 28, 2025
762d496
adds other cfgs for rnd and symmetry
Mayankm96 Mar 28, 2025
efe635f
runs formatter
Mayankm96 Mar 28, 2025
da06bd5
updates changelog
Mayankm96 Mar 28, 2025
67f3c11
runs formatter
Mayankm96 Mar 28, 2025
71a8fe5
fixes expectation of symm function
Mayankm96 Mar 28, 2025
67be285
adds more description on symmetry
Mayankm96 Mar 28, 2025
485e1a0
adds more docs
Mayankm96 Mar 28, 2025
9c3d9d4
updates benchmark script as well
Mayankm96 Mar 28, 2025
83a59fc
updates feature table
Mayankm96 Mar 28, 2025
99d89ea
adds version check
Mayankm96 Mar 28, 2025
b744d6a
add configs and adapt exporter for distillation
ClemensSchwarke Mar 28, 2025
281c71f
remove resume argument for better distillation workflow
ClemensSchwarke Feb 5, 2025
2671598
add me to contributors
ClemensSchwarke Mar 28, 2025
9d14bee
Merge branch 'main' into feature/rsl_rl_2_3_0_adaptation
Mayankm96 Mar 31, 2025
63805aa
Revert "remove resume argument for better distillation workflow"
ClemensSchwarke Apr 9, 2025
b91ce9f
remove need for resume flag when training distillation
ClemensSchwarke Apr 9, 2025
a7f5366
separate configs
ClemensSchwarke Apr 9, 2025
4ac6433
restructure exporter
ClemensSchwarke Apr 9, 2025
a172702
Apply suggestions from code review
Mayankm96 Apr 10, 2025
35b0d84
updates version
Mayankm96 Apr 10, 2025
8b9a1e1
makes scripts backwards compatible
Mayankm96 Apr 10, 2025
d208670
bumps rsl-rl version
Mayankm96 Apr 10, 2025
10ec3ec
fixes typo
Mayankm96 Apr 10, 2025
5479358
Merge branch 'main' into feature/rsl_rl_2_3_0_adaptation
Mayankm96 Apr 10, 2025
2041684
fixes extension toml
Mayankm96 Apr 10, 2025
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -46,6 +46,7 @@ Guidelines for modifications:
* Calvin Yu
* Cheng-Rong Lai
* Chenyu Yang
* Clemens Schwarke
* CY (Chien-Ying) Chen
* David Yang
* Dorsa Rohani
4 changes: 2 additions & 2 deletions scripts/reinforcement_learning/rsl_rl/train.py
@@ -138,7 +138,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env = multi_agent_to_single_agent(env)

# save resume path before creating a new log_dir
if agent_cfg.resume:
if agent_cfg.resume or agent_cfg.algorithm.class_name == "Distillation":
resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)

# wrap for video recording
@@ -161,7 +161,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# write git state to logs
runner.add_git_repo_to_log(__file__)
# load the checkpoint
if agent_cfg.resume:
if agent_cfg.resume or agent_cfg.algorithm.class_name == "Distillation":
print(f"[INFO]: Loading model checkpoint from: {resume_path}")
# load previously trained model
runner.load(resume_path)
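
Note on the change above: whenever the configured algorithm class is `Distillation`, the script now resolves a checkpoint from `load_run`/`load_checkpoint` exactly as a resumed run would, so the separate `--resume` flag is no longer required to load the teacher. A rough, self-contained sketch of that idea; the helper below is only a stand-in for `get_checkpoint_path`, and all run and file names are placeholders:

```python
import glob
import os


def find_checkpoint(log_root_path: str, load_run: str, load_checkpoint: str) -> str:
    """Stand-in for get_checkpoint_path(): newest file matching the pattern in the run directory."""
    matches = sorted(glob.glob(os.path.join(log_root_path, load_run, load_checkpoint)))
    if not matches:
        raise FileNotFoundError(f"no checkpoint matching '{load_checkpoint}' under '{load_run}'")
    return matches[-1]


# Hypothetical usage: a previously trained PPO teacher is picked up by the distillation run.
algorithm_class = "Distillation"  # would come from agent_cfg.algorithm.class_name
resume = False                    # --resume is no longer needed for distillation
if resume or algorithm_class == "Distillation":
    teacher_ckpt = find_checkpoint("logs/rsl_rl/anymal_d_flat", "2025-04-09_teacher", "model_*.pt")
```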
3 changes: 2 additions & 1 deletion source/isaaclab_rl/isaaclab_rl/rsl_rl/__init__.py
@@ -15,8 +15,9 @@

"""

from .distillation_cfg import *
from .exporter import export_policy_as_jit, export_policy_as_onnx
from .rl_cfg import RslRlOnPolicyRunnerCfg, RslRlPpoActorCriticCfg, RslRlPpoAlgorithmCfg
from .rl_cfg import *
from .rnd_cfg import RslRlRndCfg
from .symmetry_cfg import RslRlSymmetryCfg
from .vecenv_wrapper import RslRlVecEnvWrapper
80 changes: 80 additions & 0 deletions source/isaaclab_rl/isaaclab_rl/rsl_rl/distillation_cfg.py
@@ -0,0 +1,80 @@
# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from dataclasses import MISSING
from typing import Literal

from isaaclab.utils import configclass

#########################
# Policy configurations #
#########################


@configclass
class RslRlDistillationStudentTeacherCfg:
"""Configuration for the distillation student-teacher networks."""

class_name: str = "StudentTeacher"
"""The policy class name. Default is StudentTeacher."""

init_noise_std: float = MISSING
"""The initial noise standard deviation for the student policy."""

noise_std_type: Literal["scalar", "log"] = "scalar"
"""The type of noise standard deviation for the policy. Default is scalar."""

student_hidden_dims: list[int] = MISSING
"""The hidden dimensions of the student network."""

teacher_hidden_dims: list[int] = MISSING
"""The hidden dimensions of the teacher network."""

activation: str = MISSING
"""The activation function for the student and teacher networks."""


@configclass
class RslRlDistillationStudentTeacherRecurrentCfg(RslRlDistillationStudentTeacherCfg):
"""Configuration for the distillation student-teacher recurrent networks."""

class_name: str = "StudentTeacherRecurrent"
"""The policy class name. Default is StudentTeacherRecurrent."""

rnn_type: str = MISSING
"""The type of the RNN network. Either "lstm" or "gru"."""

rnn_hidden_dim: int = MISSING
"""The hidden dimension of the RNN network."""

rnn_num_layers: int = MISSING
"""The number of layers of the RNN network."""

teacher_recurrent: bool = MISSING
"""Whether the teacher network is recurrent too."""


############################
# Algorithm configurations #
############################


@configclass
class RslRlDistillationAlgorithmCfg:
"""Configuration for the distillation algorithm."""

class_name: str = "Distillation"
"""The algorithm class name. Default is Distillation."""

num_learning_epochs: int = MISSING
"""The number of updates performed with each sample."""

learning_rate: float = MISSING
"""The learning rate for the student policy."""

gradient_length: float = MISSING
"""The number of environment steps the gradient flows back."""
Review comment (Contributor):

@ClemensSchwarke should the parameter by default be 1?

Suggested change
"""The number of environment steps the gradient flows back."""
"""The number of rollout steps for gradient propagation.
This is useful for sequential training of recurrent student network.
"""

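To make the new dataclasses concrete, here is a rough sketch of how a task could combine them into a runner configuration. The class name, experiment name, and all numeric values are illustrative placeholders (not part of this PR), and runner fields not shown in this diff are omitted:

```python
from isaaclab.utils import configclass

from isaaclab_rl.rsl_rl import (
    RslRlDistillationAlgorithmCfg,
    RslRlDistillationStudentTeacherCfg,
    RslRlOnPolicyRunnerCfg,
)


@configclass
class CartpoleDistillationRunnerCfg(RslRlOnPolicyRunnerCfg):
    """Hypothetical task config; values are illustrative only."""

    experiment_name = "cartpole_distillation"
    empirical_normalization = False
    save_interval = 50
    load_run = ".*_teacher"  # run directory of the PPO teacher to distill from
    policy = RslRlDistillationStudentTeacherCfg(
        init_noise_std=0.1,
        student_hidden_dims=[128, 128],
        teacher_hidden_dims=[256, 256],
        activation="elu",
    )
    algorithm = RslRlDistillationAlgorithmCfg(
        num_learning_epochs=1,
        learning_rate=1.0e-3,
        gradient_length=15,
    )
```

Because `algorithm.class_name` defaults to `"Distillation"`, the updated `train.py` shown earlier resolves the teacher checkpoint through `load_run`/`load_checkpoint` automatically.
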
48 changes: 34 additions & 14 deletions source/isaaclab_rl/isaaclab_rl/rsl_rl/exporter.py
@@ -8,34 +8,34 @@
import torch


def export_policy_as_jit(actor_critic: object, normalizer: object | None, path: str, filename="policy.pt"):
def export_policy_as_jit(policy: object, normalizer: object | None, path: str, filename="policy.pt"):
"""Export policy into a Torch JIT file.

Args:
actor_critic: The actor-critic torch module.
policy: The policy torch module.
normalizer: The empirical normalizer module. If None, Identity is used.
path: The path to the saving directory.
filename: The name of exported JIT file. Defaults to "policy.pt".
"""
policy_exporter = _TorchPolicyExporter(actor_critic, normalizer)
policy_exporter = _TorchPolicyExporter(policy, normalizer)
policy_exporter.export(path, filename)


def export_policy_as_onnx(
actor_critic: object, path: str, normalizer: object | None = None, filename="policy.onnx", verbose=False
policy: object, path: str, normalizer: object | None = None, filename="policy.onnx", verbose=False
):
"""Export policy into a Torch ONNX file.

Args:
actor_critic: The actor-critic torch module.
policy: The policy torch module.
normalizer: The empirical normalizer module. If None, Identity is used.
path: The path to the saving directory.
filename: The name of exported ONNX file. Defaults to "policy.onnx".
verbose: Whether to print the model summary. Defaults to False.
"""
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose)
policy_exporter = _OnnxPolicyExporter(policy, normalizer, verbose)
policy_exporter.export(path, filename)


@@ -47,12 +47,22 @@ def export_policy_as_onnx(
class _TorchPolicyExporter(torch.nn.Module):
"""Exporter of actor-critic into JIT file."""

def __init__(self, actor_critic, normalizer=None):
def __init__(self, policy, normalizer=None):
super().__init__()
self.actor = copy.deepcopy(actor_critic.actor)
self.is_recurrent = actor_critic.is_recurrent
self.is_recurrent = policy.is_recurrent
# copy policy parameters
if hasattr(policy, "actor"):
self.actor = copy.deepcopy(policy.actor)
if self.is_recurrent:
self.rnn = copy.deepcopy(policy.memory_a.rnn)
elif hasattr(policy, "student"):
self.actor = copy.deepcopy(policy.student)
if self.is_recurrent:
self.rnn = copy.deepcopy(policy.memory_s.rnn)
else:
raise ValueError("Policy does not have an actor/student module.")
# set up recurrent network
if self.is_recurrent:
self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
self.rnn.cpu()
self.register_buffer("hidden_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
@@ -94,13 +104,23 @@ def export(self, path, filename):
class _OnnxPolicyExporter(torch.nn.Module):
"""Exporter of actor-critic into ONNX file."""

def __init__(self, actor_critic, normalizer=None, verbose=False):
def __init__(self, policy, normalizer=None, verbose=False):
super().__init__()
self.verbose = verbose
self.actor = copy.deepcopy(actor_critic.actor)
self.is_recurrent = actor_critic.is_recurrent
self.is_recurrent = policy.is_recurrent
# copy policy parameters
if hasattr(policy, "actor"):
self.actor = copy.deepcopy(policy.actor)
if self.is_recurrent:
self.rnn = copy.deepcopy(policy.memory_a.rnn)
elif hasattr(policy, "student"):
self.actor = copy.deepcopy(policy.student)
if self.is_recurrent:
self.rnn = copy.deepcopy(policy.memory_s.rnn)
else:
raise ValueError("Policy does not have an actor/student module.")
# set up recurrent network
if self.is_recurrent:
self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
self.rnn.cpu()
self.forward = self.forward_lstm
# copy normalizer if exists
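
Since the exporter now looks for either an `actor` (PPO) or a `student` (distillation) sub-module on the policy object, the same export calls cover both cases. A small usage sketch, assuming a trained rsl-rl `OnPolicyRunner` that exposes `alg.policy` and `obs_normalizer` the way the existing play script uses them; the output path is a placeholder:

```python
from isaaclab_rl.rsl_rl import export_policy_as_jit, export_policy_as_onnx


def export_trained_policy(runner, export_dir: str) -> None:
    """Export a trained policy (PPO actor or distilled student) to JIT and ONNX."""
    policy_module = runner.alg.policy    # ActorCritic(Recurrent) or StudentTeacher(Recurrent)
    normalizer = runner.obs_normalizer   # empirical normalizer, or None if unused
    export_policy_as_jit(policy_module, normalizer, export_dir, filename="policy.pt")
    export_policy_as_onnx(policy_module, export_dir, normalizer=normalizer, filename="policy.onnx")


# Example with a placeholder directory:
# export_trained_policy(runner, "logs/rsl_rl/my_run/exported")
```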
74 changes: 48 additions & 26 deletions source/isaaclab_rl/isaaclab_rl/rsl_rl/rl_cfg.py
@@ -3,14 +3,21 @@
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from dataclasses import MISSING
from typing import Literal

from isaaclab.utils import configclass

from .distillation_cfg import RslRlDistillationAlgorithmCfg, RslRlDistillationStudentTeacherCfg
from .rnd_cfg import RslRlRndCfg
from .symmetry_cfg import RslRlSymmetryCfg

#########################
# Policy configurations #
#########################


@configclass
class RslRlPpoActorCriticCfg:
@@ -36,23 +43,33 @@ class RslRlPpoActorCriticCfg:


@configclass
class RslRlPpoAlgorithmCfg:
"""Configuration for the PPO algorithm."""
class RslRlPpoActorCriticRecurrentCfg(RslRlPpoActorCriticCfg):
"""Configuration for the PPO actor-critic networks with recurrent layers."""

class_name: str = "PPO"
"""The algorithm class name. Default is PPO."""
class_name: str = "ActorCriticRecurrent"
"""The policy class name. Default is ActorCriticRecurrent."""

value_loss_coef: float = MISSING
"""The coefficient for the value loss."""
rnn_type: str = MISSING
"""The type of RNN to use. Either "lstm" or "gru"."""

use_clipped_value_loss: bool = MISSING
"""Whether to use clipped value loss."""
rnn_hidden_dim: int = MISSING
"""The dimension of the RNN layers."""

clip_param: float = MISSING
"""The clipping parameter for the policy."""
rnn_num_layers: int = MISSING
"""The number of RNN layers."""

entropy_coef: float = MISSING
"""The coefficient for the entropy loss."""

############################
# Algorithm configurations #
############################


@configclass
class RslRlPpoAlgorithmCfg:
"""Configuration for the PPO algorithm."""

class_name: str = "PPO"
"""The algorithm class name. Default is PPO."""

num_learning_epochs: int = MISSING
"""The number of learning epochs per update."""
@@ -72,12 +89,24 @@ class RslRlPpoAlgorithmCfg:
lam: float = MISSING
"""The lambda parameter for Generalized Advantage Estimation (GAE)."""

entropy_coef: float = MISSING
"""The coefficient for the entropy loss."""

desired_kl: float = MISSING
"""The desired KL divergence."""

max_grad_norm: float = MISSING
"""The maximum gradient norm."""

value_loss_coef: float = MISSING
"""The coefficient for the value loss."""

use_clipped_value_loss: bool = MISSING
"""Whether to use clipped value loss."""

clip_param: float = MISSING
"""The clipping parameter for the policy."""

normalize_advantage_per_mini_batch: bool = False
"""Whether to normalize the advantage per mini-batch. Default is False.

@@ -94,6 +123,11 @@ class RslRlPpoAlgorithmCfg:
"""


#########################
# Runner configurations #
#########################


@configclass
class RslRlOnPolicyRunnerCfg:
"""Configuration of the runner for on-policy algorithms."""
@@ -113,10 +147,10 @@ class RslRlOnPolicyRunnerCfg:
empirical_normalization: bool = MISSING
"""Whether to use empirical normalization."""

policy: RslRlPpoActorCriticCfg = MISSING
policy: RslRlPpoActorCriticCfg | RslRlDistillationStudentTeacherCfg = MISSING
"""The policy configuration."""

algorithm: RslRlPpoAlgorithmCfg = MISSING
algorithm: RslRlPpoAlgorithmCfg | RslRlDistillationAlgorithmCfg = MISSING
"""The algorithm configuration."""

clip_actions: float | None = None
@@ -126,10 +160,6 @@ class RslRlOnPolicyRunnerCfg:
This clipping is performed inside the :class:`RslRlVecEnvWrapper` wrapper.
"""

##
# Checkpointing parameters
##

save_interval: int = MISSING
"""The number of iterations between saves."""

@@ -144,10 +174,6 @@
``{time-stamp}_{run_name}``.
"""

##
# Logging parameters
##

logger: Literal["tensorboard", "neptune", "wandb"] = "tensorboard"
"""The logger to use. Default is tensorboard."""

@@ -157,10 +183,6 @@
wandb_project: str = "isaaclab"
"""The wandb project name. Default is "isaaclab"."""

##
# Loading parameters
##

resume: bool = False
"""Whether to resume. Default is False."""

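
The restructured `rl_cfg.py` also introduces `RslRlPpoActorCriticRecurrentCfg`, which is re-exported through the wildcard import in `__init__.py`. A brief sketch of filling it in; the base-class fields (hidden dims, activation, noise) are assumed to follow the existing `RslRlPpoActorCriticCfg`, and every value is illustrative:

```python
from isaaclab_rl.rsl_rl import RslRlPpoActorCriticRecurrentCfg

# Illustrative values only; all of these fields are MISSING in the dataclass.
policy = RslRlPpoActorCriticRecurrentCfg(
    init_noise_std=1.0,
    actor_hidden_dims=[256, 128],
    critic_hidden_dims=[256, 128],
    activation="elu",
    rnn_type="lstm",  # or "gru"
    rnn_hidden_dim=256,
    rnn_num_layers=1,
)
```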