@@ -734,14 +734,25 @@ class HindsightExperienceTransformer(DataTransformer):
     of the current timestep.
     The exact field names can be provided via arguments to the class ``__init__``.
 
+    NOTE: The HindsightExperienceTransformer has to be applied before any transformer that
+    changes the reward or achieved_goal fields, e.g. observation normalizer, reward clipper, etc.
+    See `documentation <../../docs/notes/knowledge_base.rst#datatransformers>`_ for details.
+
     To use this class, add it to any existing data transformers, e.g. use this config if
     ``ObservationNormalizer`` is an existing data transformer:
 
     .. code-block:: python
 
-        ReplayBuffer.keep_episodic_info=True
-        HindsightExperienceTransformer.her_proportion=0.8
-        TrainerConfig.data_transformer_ctor=[@HindsightExperienceTransformer, @ObservationNormalizer]
+        alf.config('ReplayBuffer', keep_episodic_info=True)
+        alf.config(
+            'HindsightExperienceTransformer',
+            her_proportion=0.8
+        )
+        alf.config(
+            'TrainerConfig',
+            data_transformer_ctor=[
+                HindsightExperienceTransformer, ObservationNormalizer
+            ])
 
     See unit test for more details on behavior.
     """
@@ -818,9 +829,10 @@ def transform_experience(self, experience: Experience):
         # relabel only these sampled indices
         her_cond = torch.rand(batch_size) < her_proportion
         (her_indices, ) = torch.where(her_cond)
+        has_her = torch.any(her_cond)
 
-        last_step_pos = start_pos[her_indices] + batch_length - 1
-        last_env_ids = env_ids[her_indices]
+        last_step_pos = start_pos + batch_length - 1
+        last_env_ids = env_ids
         # Get x, y indices of LAST steps
         dist = buffer.steps_to_episode_end(last_step_pos, last_env_ids)
         if alf.summary.should_record_summaries():
@@ -829,22 +841,24 @@ def transform_experience(self, experience: Experience):
                 torch.mean(dist.type(torch.float32)))
 
         # get random future state
-        future_idx = last_step_pos + (torch.rand(*dist.shape) *
-                                      (dist + 1)).to(torch.int64)
+        future_dist = (torch.rand(*dist.shape) * (dist + 1)).to(
+            torch.int64)
+        future_idx = last_step_pos + future_dist
         future_ag = buffer.get_field(self._achieved_goal_field,
                                      last_env_ids, future_idx).unsqueeze(1)
 
         # relabel desired goal
         result_desired_goal = alf.nest.get_field(result,
                                                  self._desired_goal_field)
-        relabed_goal = result_desired_goal.clone()
+        relabeled_goal = result_desired_goal.clone()
         her_batch_index_tuple = (her_indices.unsqueeze(1),
                                  torch.arange(batch_length).unsqueeze(0))
-        relabed_goal[her_batch_index_tuple] = future_ag
+        if has_her:
+            relabeled_goal[her_batch_index_tuple] = future_ag[her_indices]
 
         # recompute rewards
         result_ag = alf.nest.get_field(result, self._achieved_goal_field)
-        relabeled_rewards = self._reward_fn(result_ag, relabed_goal)
+        relabeled_rewards = self._reward_fn(result_ag, relabeled_goal)
 
         non_her_or_fst = ~her_cond.unsqueeze(1) & (result.step_type !=
                                                    StepType.FIRST)
@@ -874,21 +888,28 @@ def transform_experience(self, experience: Experience):
             alf.summary.scalar(
                 "replayer/" + buffer._name + ".reward_mean_before_relabel",
                 torch.mean(result.reward[her_indices][:-1]))
-            alf.summary.scalar(
-                "replayer/" + buffer._name + ".reward_mean_after_relabel",
-                torch.mean(relabeled_rewards[her_indices][:-1]))
+            if has_her:
+                alf.summary.scalar(
+                    "replayer/" + buffer._name + ".reward_mean_after_relabel",
+                    torch.mean(relabeled_rewards[her_indices][:-1]))
+            alf.summary.scalar("replayer/" + buffer._name + ".future_distance",
+                               torch.mean(future_dist.float()))
 
         result = alf.nest.transform_nest(
-            result, self._desired_goal_field, lambda _: relabed_goal)
-
+            result, self._desired_goal_field, lambda _: relabeled_goal)
         result = result.update_time_step_field('reward', relabeled_rewards)
-
+        info = info._replace(her=her_cond, future_distance=future_dist)
         if alf.get_default_device() != buffer.device:
             for f in accessed_fields:
                 result = alf.nest.transform_nest(
                     result, f, lambda t: convert_device(t))
-        result = alf.nest.transform_nest(
-            result, "batch_info.replay_buffer", lambda _: buffer)
+        info = convert_device(info)
+        info = info._replace(
+            her=info.her.unsqueeze(1).expand(result.reward.shape[:2]),
+            future_distance=info.future_distance.unsqueeze(1).expand(
+                result.reward.shape[:2]),
+            replay_buffer=buffer)
+        result = alf.data_structures.add_batch_info(result, info)
         return result
 
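The relabeling in transform_experience above reduces to one advanced-indexing pattern on tensors of shape [batch_size, batch_length, ...]: her_cond marks which sampled trajectories are relabeled, a future achieved goal is fetched for each of them, and that goal overwrites the desired goal at every step of the marked trajectories before the reward is recomputed. The snippet below is a minimal standalone sketch of that pattern in plain PyTorch; the shapes, the stand-in for the future achieved goal, and the distance-threshold reward are illustrative assumptions, not ALF's actual buffer lookup or reward_fn.

    # Toy sketch of the HER relabeling pattern used above (hypothetical shapes,
    # plain PyTorch only; not ALF code).
    import torch

    batch_size, batch_length, goal_dim = 6, 4, 3
    her_proportion = 0.8

    achieved_goal = torch.randn(batch_size, batch_length, goal_dim)
    desired_goal = torch.randn(batch_size, batch_length, goal_dim)

    # Pick which sampled trajectories are relabeled, as in the patch.
    her_cond = torch.rand(batch_size) < her_proportion
    (her_indices, ) = torch.where(her_cond)

    # Stand-in for the "random future achieved goal": here just the last step's
    # achieved goal; the patch instead reads a random future step from the buffer.
    future_ag = achieved_goal[:, -1:, :]  # [batch_size, 1, goal_dim]

    relabeled_goal = desired_goal.clone()
    her_batch_index_tuple = (her_indices.unsqueeze(1),
                             torch.arange(batch_length).unsqueeze(0))
    if torch.any(her_cond):
        # [num_her, 1, goal_dim] broadcasts across all batch_length steps.
        relabeled_goal[her_batch_index_tuple] = future_ag[her_indices]

    # Recompute a sparse reward against the relabeled goal: 0 when the achieved
    # goal is within a threshold, -1 otherwise (an assumed reward_fn).
    relabeled_rewards = torch.where(
        (achieved_goal - relabeled_goal).norm(dim=-1) < 0.05,
        torch.zeros(batch_size, batch_length),
        -torch.ones(batch_size, batch_length))

Only the rows selected by her_cond end up with a relabeled goal; the remaining trajectories keep their original desired_goal and, in the patch, their original reward via the non_her_or_fst mask.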