Skip to content

Commit 9fbe0b8

Browse files
Add configs and adapt exporter for RSL-RL distillation (isaac-sim#2182)
# Description This PR adds configuration classes for Student-Teacher Distillation and adapts the policy exporters to be able to export student policies. ## Type of change - Non-breaking change ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [ ] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there --------- Signed-off-by: Mayank Mittal <[email protected]> Co-authored-by: Mayank Mittal <[email protected]>
1 parent 5bccd80 commit 9fbe0b8

File tree

10 files changed

+212
-77
lines changed

10 files changed

+212
-77
lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Guidelines for modifications:
4646
* Calvin Yu
4747
* Cheng-Rong Lai
4848
* Chenyu Yang
49+
* Clemens Schwarke
4950
* CY (Chien-Ying) Chen
5051
* David Yang
5152
* Dorsa Rohani

scripts/reinforcement_learning/rsl_rl/play.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,6 @@
55

66
"""Script to play a checkpoint if an RL agent from RSL-RL."""
77

8-
import platform
9-
from importlib.metadata import version
10-
11-
if version("rsl-rl-lib") != "2.3.0":
12-
if platform.system() == "Windows":
13-
cmd = [r".\isaaclab.bat", "-p", "-m", "pip", "install", "rsl-rl-lib==2.3.0"]
14-
else:
15-
cmd = ["./isaaclab.sh", "-p", "-m", "pip", "install", "rsl-rl-lib==2.3.0"]
16-
print(
17-
f"Please install the correct version of RSL-RL.\nExisting version is: '{version('rsl-rl-lib')}'"
18-
" and required version is: '2.3.0'.\nTo install the correct version, run:"
19-
f"\n\n\t{' '.join(cmd)}\n"
20-
)
21-
exit(1)
22-
238
"""Launch Isaac Sim Simulator first."""
249

2510
import argparse
@@ -133,11 +118,20 @@ def main():
133118
# obtain the trained policy for inference
134119
policy = ppo_runner.get_inference_policy(device=env.unwrapped.device)
135120

121+
# extract the neural network module
122+
# we do this in a try-except to maintain backwards compatibility.
123+
try:
124+
# version 2.3 onwards
125+
policy_nn = ppo_runner.alg.policy
126+
except AttributeError:
127+
# version 2.2 and below
128+
policy_nn = ppo_runner.alg.actor_critic
129+
136130
# export policy to onnx/jit
137131
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
138-
export_policy_as_jit(ppo_runner.alg.policy, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt")
132+
export_policy_as_jit(policy_nn, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt")
139133
export_policy_as_onnx(
140-
ppo_runner.alg.policy, normalizer=ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx"
134+
policy_nn, normalizer=ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx"
141135
)
142136

143137
dt = env.unwrapped.step_dt

scripts/reinforcement_learning/rsl_rl/train.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,6 @@
55

66
"""Script to train RL agent with RSL-RL."""
77

8-
import platform
9-
from importlib.metadata import version
10-
11-
if version("rsl-rl-lib") != "2.3.0":
12-
if platform.system() == "Windows":
13-
cmd = [r".\isaaclab.bat", "-p", "-m", "pip", "install", "rsl-rl-lib==2.3.0"]
14-
else:
15-
cmd = ["./isaaclab.sh", "-p", "-m", "pip", "install", "rsl-rl-lib==2.3.0"]
16-
print(
17-
f"Please install the correct version of RSL-RL.\nExisting version is: '{version('rsl-rl-lib')}'"
18-
" and required version is: '2.3.0'.\nTo install the correct version, run:"
19-
f"\n\n\t{' '.join(cmd)}\n"
20-
)
21-
exit(1)
22-
238
"""Launch Isaac Sim Simulator first."""
249

2510
import argparse
@@ -60,6 +45,28 @@
6045
app_launcher = AppLauncher(args_cli)
6146
simulation_app = app_launcher.app
6247

48+
"""Check for minimum supported RSL-RL version."""
49+
50+
import importlib.metadata as metadata
51+
import platform
52+
53+
from packaging import version
54+
55+
# for distributed training, check minimum supported rsl-rl version
56+
RSL_RL_VERSION = "2.3.1"
57+
installed_version = metadata.version("rsl-rl-lib")
58+
if args_cli.distributed and version.parse(installed_version) < version.parse(RSL_RL_VERSION):
59+
if platform.system() == "Windows":
60+
cmd = [r".\isaaclab.bat", "-p", "-m", "pip", "install", f"rsl-rl-lib=={RSL_RL_VERSION}"]
61+
else:
62+
cmd = ["./isaaclab.sh", "-p", "-m", "pip", "install", f"rsl-rl-lib=={RSL_RL_VERSION}"]
63+
print(
64+
f"Please install the correct version of RSL-RL.\nExisting version is: '{installed_version}'"
65+
f" and required version is: '{RSL_RL_VERSION}'.\nTo install the correct version, run:"
66+
f"\n\n\t{' '.join(cmd)}\n"
67+
)
68+
exit(1)
69+
6370
"""Rest everything follows."""
6471

6572
import gymnasium as gym
@@ -138,7 +145,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
138145
env = multi_agent_to_single_agent(env)
139146

140147
# save resume path before creating a new log_dir
141-
if agent_cfg.resume:
148+
if agent_cfg.resume or agent_cfg.algorithm.class_name == "Distillation":
142149
resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
143150

144151
# wrap for video recording
@@ -161,7 +168,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
161168
# write git state to logs
162169
runner.add_git_repo_to_log(__file__)
163170
# load the checkpoint
164-
if agent_cfg.resume:
171+
if agent_cfg.resume or agent_cfg.algorithm.class_name == "Distillation":
165172
print(f"[INFO]: Loading model checkpoint from: {resume_path}")
166173
# load previously trained model
167174
runner.load(resume_path)

source/isaaclab_rl/config/extension.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22

33
# Note: Semantic Versioning is used: https://semver.org/
4-
version = "0.1.3"
4+
version = "0.1.4"
55

66
# Description
77
title = "Isaac Lab RL"

source/isaaclab_rl/docs/CHANGELOG.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
Changelog
22
---------
33

4+
0.1.4 (2025-04-10)
5+
~~~~~~~~~~~~~~~~~~
6+
7+
Added
8+
^^^^^
9+
10+
* Added configurations for distillation implementation in RSL-RL.
11+
* Added configuration for recurrent actor-critic in RSL-RL.
12+
13+
414
0.1.3 (2025-03-31)
515
~~~~~~~~~~~~~~~~~~
616

source/isaaclab_rl/isaaclab_rl/rsl_rl/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
1616
"""
1717

18+
from .distillation_cfg import *
1819
from .exporter import export_policy_as_jit, export_policy_as_onnx
19-
from .rl_cfg import RslRlOnPolicyRunnerCfg, RslRlPpoActorCriticCfg, RslRlPpoAlgorithmCfg
20+
from .rl_cfg import *
2021
from .rnd_cfg import RslRlRndCfg
2122
from .symmetry_cfg import RslRlSymmetryCfg
2223
from .vecenv_wrapper import RslRlVecEnvWrapper
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
2+
# All rights reserved.
3+
#
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
6+
from __future__ import annotations
7+
8+
from dataclasses import MISSING
9+
from typing import Literal
10+
11+
from isaaclab.utils import configclass
12+
13+
#########################
14+
# Policy configurations #
15+
#########################
16+
17+
18+
@configclass
19+
class RslRlDistillationStudentTeacherCfg:
20+
"""Configuration for the distillation student-teacher networks."""
21+
22+
class_name: str = "StudentTeacher"
23+
"""The policy class name. Default is StudentTeacher."""
24+
25+
init_noise_std: float = MISSING
26+
"""The initial noise standard deviation for the student policy."""
27+
28+
noise_std_type: Literal["scalar", "log"] = "scalar"
29+
"""The type of noise standard deviation for the policy. Default is scalar."""
30+
31+
student_hidden_dims: list[int] = MISSING
32+
"""The hidden dimensions of the student network."""
33+
34+
teacher_hidden_dims: list[int] = MISSING
35+
"""The hidden dimensions of the teacher network."""
36+
37+
activation: str = MISSING
38+
"""The activation function for the student and teacher networks."""
39+
40+
41+
@configclass
42+
class RslRlDistillationStudentTeacherRecurrentCfg(RslRlDistillationStudentTeacherCfg):
43+
"""Configuration for the distillation student-teacher recurrent networks."""
44+
45+
class_name: str = "StudentTeacherRecurrent"
46+
"""The policy class name. Default is StudentTeacherRecurrent."""
47+
48+
rnn_type: str = MISSING
49+
"""The type of the RNN network. Either "lstm" or "gru"."""
50+
51+
rnn_hidden_dim: int = MISSING
52+
"""The hidden dimension of the RNN network."""
53+
54+
rnn_num_layers: int = MISSING
55+
"""The number of layers of the RNN network."""
56+
57+
teacher_recurrent: bool = MISSING
58+
"""Whether the teacher network is recurrent too."""
59+
60+
61+
############################
62+
# Algorithm configurations #
63+
############################
64+
65+
66+
@configclass
67+
class RslRlDistillationAlgorithmCfg:
68+
"""Configuration for the distillation algorithm."""
69+
70+
class_name: str = "Distillation"
71+
"""The algorithm class name. Default is Distillation."""
72+
73+
num_learning_epochs: int = MISSING
74+
"""The number of updates performed with each sample."""
75+
76+
learning_rate: float = MISSING
77+
"""The learning rate for the student policy."""
78+
79+
gradient_length: int = MISSING
80+
"""The number of environment steps the gradient flows back."""

source/isaaclab_rl/isaaclab_rl/rsl_rl/exporter.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,34 @@
88
import torch
99

1010

11-
def export_policy_as_jit(actor_critic: object, normalizer: object | None, path: str, filename="policy.pt"):
11+
def export_policy_as_jit(policy: object, normalizer: object | None, path: str, filename="policy.pt"):
1212
"""Export policy into a Torch JIT file.
1313
1414
Args:
15-
actor_critic: The actor-critic torch module.
15+
policy: The policy torch module.
1616
normalizer: The empirical normalizer module. If None, Identity is used.
1717
path: The path to the saving directory.
1818
filename: The name of exported JIT file. Defaults to "policy.pt".
1919
"""
20-
policy_exporter = _TorchPolicyExporter(actor_critic, normalizer)
20+
policy_exporter = _TorchPolicyExporter(policy, normalizer)
2121
policy_exporter.export(path, filename)
2222

2323

2424
def export_policy_as_onnx(
25-
actor_critic: object, path: str, normalizer: object | None = None, filename="policy.onnx", verbose=False
25+
policy: object, path: str, normalizer: object | None = None, filename="policy.onnx", verbose=False
2626
):
2727
"""Export policy into a Torch ONNX file.
2828
2929
Args:
30-
actor_critic: The actor-critic torch module.
30+
policy: The policy torch module.
3131
normalizer: The empirical normalizer module. If None, Identity is used.
3232
path: The path to the saving directory.
3333
filename: The name of exported ONNX file. Defaults to "policy.onnx".
3434
verbose: Whether to print the model summary. Defaults to False.
3535
"""
3636
if not os.path.exists(path):
3737
os.makedirs(path, exist_ok=True)
38-
policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose)
38+
policy_exporter = _OnnxPolicyExporter(policy, normalizer, verbose)
3939
policy_exporter.export(path, filename)
4040

4141

@@ -47,12 +47,22 @@ def export_policy_as_onnx(
4747
class _TorchPolicyExporter(torch.nn.Module):
4848
"""Exporter of actor-critic into JIT file."""
4949

50-
def __init__(self, actor_critic, normalizer=None):
50+
def __init__(self, policy, normalizer=None):
5151
super().__init__()
52-
self.actor = copy.deepcopy(actor_critic.actor)
53-
self.is_recurrent = actor_critic.is_recurrent
52+
self.is_recurrent = policy.is_recurrent
53+
# copy policy parameters
54+
if hasattr(policy, "actor"):
55+
self.actor = copy.deepcopy(policy.actor)
56+
if self.is_recurrent:
57+
self.rnn = copy.deepcopy(policy.memory_a.rnn)
58+
elif hasattr(policy, "student"):
59+
self.actor = copy.deepcopy(policy.student)
60+
if self.is_recurrent:
61+
self.rnn = copy.deepcopy(policy.memory_s.rnn)
62+
else:
63+
raise ValueError("Policy does not have an actor/student module.")
64+
# set up recurrent network
5465
if self.is_recurrent:
55-
self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
5666
self.rnn.cpu()
5767
self.register_buffer("hidden_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
5868
self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
@@ -94,13 +104,23 @@ def export(self, path, filename):
94104
class _OnnxPolicyExporter(torch.nn.Module):
95105
"""Exporter of actor-critic into ONNX file."""
96106

97-
def __init__(self, actor_critic, normalizer=None, verbose=False):
107+
def __init__(self, policy, normalizer=None, verbose=False):
98108
super().__init__()
99109
self.verbose = verbose
100-
self.actor = copy.deepcopy(actor_critic.actor)
101-
self.is_recurrent = actor_critic.is_recurrent
110+
self.is_recurrent = policy.is_recurrent
111+
# copy policy parameters
112+
if hasattr(policy, "actor"):
113+
self.actor = copy.deepcopy(policy.actor)
114+
if self.is_recurrent:
115+
self.rnn = copy.deepcopy(policy.memory_a.rnn)
116+
elif hasattr(policy, "student"):
117+
self.actor = copy.deepcopy(policy.student)
118+
if self.is_recurrent:
119+
self.rnn = copy.deepcopy(policy.memory_s.rnn)
120+
else:
121+
raise ValueError("Policy does not have an actor/student module.")
122+
# set up recurrent network
102123
if self.is_recurrent:
103-
self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
104124
self.rnn.cpu()
105125
self.forward = self.forward_lstm
106126
# copy normalizer if exists

0 commit comments

Comments
 (0)