Commit 36e1d53

Author: Le Horizon

Commit message: minimum change for lower bounded value target (for episodic return, goal distance return, and n-step bootstrapped return)

1 parent: 4d09f62
File tree: 11 files changed, +648 -119 lines

11 files changed

+648
-119
lines changed
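The commit message describes lower bounding the value target with observed returns. As a rough, self-contained sketch of that idea in plain PyTorch (an illustration, not the ALF implementation touched by this commit): in a deterministic setting, the discounted return actually realized after a step cannot exceed the optimal value from that step, so the bootstrapped TD target can be clamped from below by it.

    import torch

    gamma = 0.99
    # Rewards observed for the rest of an episode after some step t.
    rewards = torch.tensor([0.0, 0.0, 1.0])
    discounts = gamma ** torch.arange(len(rewards), dtype=torch.float32)
    # Monte Carlo discounted return-to-go, available at replay time.
    discounted_return = torch.sum(discounts * rewards)

    td_target = torch.tensor(0.5)  # 1-step bootstrapped target for step t
    lower_bounded_target = torch.maximum(td_target, discounted_return)
    print(lower_bounded_target)    # tensor(0.9801)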

alf/algorithms/data_transformer.py

Lines changed: 39 additions & 18 deletions
@@ -734,14 +734,25 @@ class HindsightExperienceTransformer(DataTransformer):
     of the current timestep.
     The exact field names can be provided via arguments to the class ``__init__``.
 
+    NOTE: The HindsightExperienceTransformer has to happen before any transformer which changes
+    reward or achieved_goal fields, e.g. observation normalizer, reward clipper, etc..
+    See `documentation <../../docs/notes/knowledge_base.rst#datatransformers>`_ for details.
+
     To use this class, add it to any existing data transformers, e.g. use this config if
     ``ObservationNormalizer`` is an existing data transformer:
 
     .. code-block:: python
 
-        ReplayBuffer.keep_episodic_info=True
-        HindsightExperienceTransformer.her_proportion=0.8
-        TrainerConfig.data_transformer_ctor=[@HindsightExperienceTransformer, @ObservationNormalizer]
+        alf.config('ReplayBuffer', keep_episodic_info=True)
+        alf.config(
+            'HindsightExperienceTransformer',
+            her_proportion=0.8
+        )
+        alf.config(
+            'TrainerConfig',
+            data_transformer_ctor=[
+                HindsightExperienceTransformer, ObservationNormalizer
+            ])
 
     See unit test for more details on behavior.
     """
@@ -818,9 +829,10 @@ def transform_experience(self, experience: Experience):
         # relabel only these sampled indices
         her_cond = torch.rand(batch_size) < her_proportion
         (her_indices, ) = torch.where(her_cond)
+        has_her = torch.any(her_cond)
 
-        last_step_pos = start_pos[her_indices] + batch_length - 1
-        last_env_ids = env_ids[her_indices]
+        last_step_pos = start_pos + batch_length - 1
+        last_env_ids = env_ids
         # Get x, y indices of LAST steps
         dist = buffer.steps_to_episode_end(last_step_pos, last_env_ids)
         if alf.summary.should_record_summaries():
@@ -829,22 +841,24 @@ def transform_experience(self, experience: Experience):
                 torch.mean(dist.type(torch.float32)))
 
         # get random future state
-        future_idx = last_step_pos + (torch.rand(*dist.shape) *
-                                      (dist + 1)).to(torch.int64)
+        future_dist = (torch.rand(*dist.shape) * (dist + 1)).to(
+            torch.int64)
+        future_idx = last_step_pos + future_dist
         future_ag = buffer.get_field(self._achieved_goal_field,
                                      last_env_ids, future_idx).unsqueeze(1)
 
         # relabel desired goal
         result_desired_goal = alf.nest.get_field(result,
                                                  self._desired_goal_field)
-        relabed_goal = result_desired_goal.clone()
+        relabeled_goal = result_desired_goal.clone()
         her_batch_index_tuple = (her_indices.unsqueeze(1),
                                  torch.arange(batch_length).unsqueeze(0))
-        relabed_goal[her_batch_index_tuple] = future_ag
+        if has_her:
+            relabeled_goal[her_batch_index_tuple] = future_ag[her_indices]
 
         # recompute rewards
         result_ag = alf.nest.get_field(result, self._achieved_goal_field)
-        relabeled_rewards = self._reward_fn(result_ag, relabed_goal)
+        relabeled_rewards = self._reward_fn(result_ag, relabeled_goal)
 
         non_her_or_fst = ~her_cond.unsqueeze(1) & (result.step_type !=
                                                    StepType.FIRST)
@@ -874,21 +888,28 @@ def transform_experience(self, experience: Experience):
             alf.summary.scalar(
                 "replayer/" + buffer._name + ".reward_mean_before_relabel",
                 torch.mean(result.reward[her_indices][:-1]))
-            alf.summary.scalar(
-                "replayer/" + buffer._name + ".reward_mean_after_relabel",
-                torch.mean(relabeled_rewards[her_indices][:-1]))
+            if has_her:
+                alf.summary.scalar(
+                    "replayer/" + buffer._name + ".reward_mean_after_relabel",
+                    torch.mean(relabeled_rewards[her_indices][:-1]))
+            alf.summary.scalar("replayer/" + buffer._name + ".future_distance",
+                               torch.mean(future_dist.float()))
 
         result = alf.nest.transform_nest(
-            result, self._desired_goal_field, lambda _: relabed_goal)
-
+            result, self._desired_goal_field, lambda _: relabeled_goal)
         result = result.update_time_step_field('reward', relabeled_rewards)
-
+        info = info._replace(her=her_cond, future_distance=future_dist)
         if alf.get_default_device() != buffer.device:
             for f in accessed_fields:
                 result = alf.nest.transform_nest(
                     result, f, lambda t: convert_device(t))
-            result = alf.nest.transform_nest(
-                result, "batch_info.replay_buffer", lambda _: buffer)
+            info = convert_device(info)
+        info = info._replace(
+            her=info.her.unsqueeze(1).expand(result.reward.shape[:2]),
+            future_distance=info.future_distance.unsqueeze(1).expand(
+                result.reward.shape[:2]),
+            replay_buffer=buffer)
+        result = alf.data_structures.add_batch_info(result, info)
         return result
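For readers unfamiliar with the relabeling above: steps_to_episode_end gives, for each sampled trajectory, the number of steps dist between its last step and the episode end, and the new future_dist is drawn uniformly from {0, ..., dist}, so future_idx points at a uniformly chosen future step whose achieved goal becomes the relabeled desired goal. A minimal stand-alone sketch of that sampling (plain PyTorch, illustrative values only):

    import torch

    torch.manual_seed(0)
    last_step_pos = torch.tensor([10, 42, 7])  # last step of each sampled trajectory
    dist = torch.tensor([0, 5, 3])             # steps from that step to the episode end

    # rand() in [0, 1) scaled by (dist + 1) and truncated gives a uniform draw in {0, ..., dist}.
    future_dist = (torch.rand(*dist.shape) * (dist + 1)).to(torch.int64)
    future_idx = last_step_pos + future_dist

    assert torch.all((future_dist >= 0) & (future_dist <= dist))

This same future_dist is what the diff summarizes and attaches to the batch info as future_distance, so a downstream loss can use it as a goal-distance bound (see the sketch after the one_step_loss.py diff below).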
alf/algorithms/ddpg_algorithm.py

Lines changed: 14 additions & 3 deletions
@@ -40,9 +40,20 @@
 DdpgActorState = namedtuple("DdpgActorState", ['actor', 'critics'])
 DdpgState = namedtuple("DdpgState", ['actor', 'critics'])
 DdpgInfo = namedtuple(
-    "DdpgInfo", [
-        "reward", "step_type", "discount", "action", "action_distribution",
-        "actor_loss", "critic", "discounted_return"
+    "DdpgInfo",
+    [
+        "reward",
+        "step_type",
+        "discount",
+        "action",
+        "action_distribution",
+        "actor_loss",
+        "critic",
+        # Optional fields for value target lower bounding or Hindsight relabeling.
+        # TODO: Extract these into a HerAlgorithm wrapper for easier adoption of HER.
+        "discounted_return",
+        "future_distance",
+        "her"
     ],
     default_value=())
 DdpgLossInfo = namedtuple('DdpgLossInfo', ('actor', 'critic'))

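A note on the default_value=() pattern: the new fields (future_distance, her) default to an empty tuple, so algorithm code that never populates them keeps constructing DdpgInfo unchanged, and consumers can simply check for an empty field before using the HER metadata. The same applies to SacInfo further below. A plain-Python illustration of that behavior (standard collections.namedtuple as a hypothetical stand-in, not ALF's namedtuple helper):

    import collections

    # Trimmed-down, hypothetical stand-in for DdpgInfo.
    InfoSketch = collections.namedtuple(
        "InfoSketch",
        ["reward", "critic", "discounted_return", "future_distance", "her"],
        defaults=[(), (), (), (), ()])

    info = InfoSketch(reward=1.0, critic=0.5)
    print(info.her == ())  # True -> HER metadata was never filled in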
alf/algorithms/one_step_loss.py

Lines changed: 2 additions & 2 deletions
@@ -16,12 +16,12 @@
 from typing import Union, List, Callable
 
 import alf
-from alf.algorithms.td_loss import TDLoss, TDQRLoss
+from alf.algorithms.td_loss import LowerBoundedTDLoss, TDQRLoss
 from alf.utils import losses
 
 
 @alf.configurable
-class OneStepTDLoss(TDLoss):
+class OneStepTDLoss(LowerBoundedTDLoss):
     def __init__(self,
                  gamma: Union[float, List[float]] = 0.99,
                  td_error_loss_fn: Callable = losses.element_wise_squared_loss,

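OneStepTDLoss now derives from LowerBoundedTDLoss (imported from alf.algorithms.td_loss), which, per the commit message, supports lower bounding the value target by episodic return, goal distance return, or n-step bootstrapped return. Purely as an illustration of what a goal-distance bound can look like for hindsight-relabeled samples, and assuming the common sparse convention of -1 reward per step before the (relabeled) goal and 0 afterwards, a relabeled sample that is d = future_distance steps from its goal has a value of at least -(1 - gamma^d) / (1 - gamma). A hedged sketch under that assumption, not ALF's LowerBoundedTDLoss:

    import torch

    def lower_bound_td_target(td_target: torch.Tensor, her_mask: torch.Tensor,
                              future_distance: torch.Tensor,
                              gamma: float = 0.99) -> torch.Tensor:
        # Goal-distance return bound, valid only under the -1/0 reward assumption above.
        goal_distance_return = -(1.0 - gamma**future_distance.float()) / (1.0 - gamma)
        # Clamp the bootstrapped target from below, but only for relabeled samples.
        return torch.where(her_mask,
                           torch.maximum(td_target, goal_distance_return),
                           td_target)

    td = torch.tensor([-5.0, -5.0])
    her = torch.tensor([True, False])
    dist = torch.tensor([3, 3])
    print(lower_bound_td_target(td, her, dist))  # tensor([-2.9701, -5.0000])

In the actual code path, the her mask and future_distance reach the loss through the batch info attached by HindsightExperienceTransformer and the new DdpgInfo/SacInfo fields.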
alf/algorithms/sac_algorithm.py

Lines changed: 16 additions & 3 deletions
@@ -54,9 +54,22 @@
     "SacActorInfo", ["actor_loss", "neg_entropy"], default_value=())
 
 SacInfo = namedtuple(
-    "SacInfo", [
-        "reward", "step_type", "discount", "action", "action_distribution",
-        "actor", "critic", "alpha", "log_pi", "discounted_return"
+    "SacInfo",
+    [
+        "reward",
+        "step_type",
+        "discount",
+        "action",
+        "action_distribution",
+        "actor",
+        "critic",
+        "alpha",
+        "log_pi",
+        # Optional fields for value target lower bounding or Hindsight relabeling.
+        # TODO: Extract these into a HerAlgorithm wrapper for easier adoption of HER.
+        "discounted_return",
+        "future_distance",
+        "her"
     ],
     default_value=())
