diff --git a/src/librustc/traits/fulfill.rs b/src/librustc/traits/fulfill.rs
index 0aac6fb81e4a3..07352a3f9478a 100644
--- a/src/librustc/traits/fulfill.rs
+++ b/src/librustc/traits/fulfill.rs
@@ -18,10 +18,13 @@ use super::{FulfillmentError, FulfillmentErrorCode};
 use super::{ObligationCause, PredicateObligation};
 
 impl<'tcx> ForestObligation for PendingPredicateObligation<'tcx> {
-    type Predicate = ty::Predicate<'tcx>;
+    /// Note that we include both the `ParamEnv` and the `Predicate`,
+    /// as the `ParamEnv` can influence whether fulfillment succeeds
+    /// or fails.
+    type CacheKey = ty::ParamEnvAnd<'tcx, ty::Predicate<'tcx>>;
 
-    fn as_predicate(&self) -> &Self::Predicate {
-        &self.obligation.predicate
+    fn as_cache_key(&self) -> Self::CacheKey {
+        self.obligation.param_env.and(self.obligation.predicate)
     }
 }
 
diff --git a/src/librustc_data_structures/obligation_forest/graphviz.rs b/src/librustc_data_structures/obligation_forest/graphviz.rs
index ddf89d99621ca..96ee72d187b34 100644
--- a/src/librustc_data_structures/obligation_forest/graphviz.rs
+++ b/src/librustc_data_structures/obligation_forest/graphviz.rs
@@ -51,7 +51,7 @@ impl<'a, O: ForestObligation + 'a> dot::Labeller<'a> for &'a ObligationForest<O>
 
     fn node_label(&self, index: &Self::Node) -> dot::LabelText<'_> {
         let node = &self.nodes[*index];
-        let label = format!("{:?} ({:?})", node.obligation.as_predicate(), node.state.get());
+        let label = format!("{:?} ({:?})", node.obligation.as_cache_key(), node.state.get());
 
         dot::LabelText::LabelStr(label.into())
     }
diff --git a/src/librustc_data_structures/obligation_forest/mod.rs b/src/librustc_data_structures/obligation_forest/mod.rs
index 974d9dcfae408..500ce5c71f37a 100644
--- a/src/librustc_data_structures/obligation_forest/mod.rs
+++ b/src/librustc_data_structures/obligation_forest/mod.rs
@@ -86,9 +86,13 @@ mod graphviz;
 mod tests;
 
 pub trait ForestObligation: Clone + Debug {
-    type Predicate: Clone + hash::Hash + Eq + Debug;
+    type CacheKey: Clone + hash::Hash + Eq + Debug;
 
-    fn as_predicate(&self) -> &Self::Predicate;
+    /// Converts this `ForestObligation` suitable for use as a cache key.
+    /// If two distinct `ForestObligations`s return the same cache key,
+    /// then it must be sound to use the result of processing one obligation
+    /// (e.g. success for error) for the other obligation
+    fn as_cache_key(&self) -> Self::CacheKey;
 }
 
 pub trait ObligationProcessor {
@@ -138,12 +142,12 @@ pub struct ObligationForest<O: ForestObligation> {
     nodes: Vec<Node<O>>,
 
     /// A cache of predicates that have been successfully completed.
-    done_cache: FxHashSet<O::Predicate>,
+    done_cache: FxHashSet<O::CacheKey>,
 
     /// A cache of the nodes in `nodes`, indexed by predicate. Unfortunately,
     /// its contents are not guaranteed to match those of `nodes`. See the
     /// comments in `process_obligation` for details.
-    active_cache: FxHashMap<O::Predicate, usize>,
+    active_cache: FxHashMap<O::CacheKey, usize>,
 
     /// A vector reused in compress(), to avoid allocating new vectors.
     node_rewrites: RefCell<Vec<usize>>,
@@ -157,7 +161,7 @@ pub struct ObligationForest<O: ForestObligation> {
     /// See [this][details] for details.
     ///
     /// [details]: https://github.com/rust-lang/rust/pull/53255#issuecomment-421184780
-    error_cache: FxHashMap<ObligationTreeId, FxHashSet<O::Predicate>>,
+    error_cache: FxHashMap<ObligationTreeId, FxHashSet<O::CacheKey>>,
 }
 
 #[derive(Debug)]
@@ -305,11 +309,12 @@ impl<O: ForestObligation> ObligationForest<O> {
 
     // Returns Err(()) if we already know this obligation failed.
     fn register_obligation_at(&mut self, obligation: O, parent: Option<usize>) -> Result<(), ()> {
-        if self.done_cache.contains(obligation.as_predicate()) {
+        if self.done_cache.contains(&obligation.as_cache_key()) {
+            debug!("register_obligation_at: ignoring already done obligation: {:?}", obligation);
             return Ok(());
         }
 
-        match self.active_cache.entry(obligation.as_predicate().clone()) {
+        match self.active_cache.entry(obligation.as_cache_key().clone()) {
             Entry::Occupied(o) => {
                 let node = &mut self.nodes[*o.get()];
                 if let Some(parent_index) = parent {
@@ -333,7 +338,7 @@ impl<O: ForestObligation> ObligationForest<O> {
                     && self
                         .error_cache
                         .get(&obligation_tree_id)
-                        .map(|errors| errors.contains(obligation.as_predicate()))
+                        .map(|errors| errors.contains(&obligation.as_cache_key()))
                         .unwrap_or(false);
 
                 if already_failed {
@@ -380,7 +385,7 @@ impl<O: ForestObligation> ObligationForest<O> {
         self.error_cache
             .entry(node.obligation_tree_id)
             .or_default()
-            .insert(node.obligation.as_predicate().clone());
+            .insert(node.obligation.as_cache_key().clone());
     }
 
     /// Performs a pass through the obligation list. This must
@@ -618,11 +623,11 @@ impl<O: ForestObligation> ObligationForest<O> {
                     // `self.nodes`. See the comment in `process_obligation`
                     // for more details.
                     if let Some((predicate, _)) =
-                        self.active_cache.remove_entry(node.obligation.as_predicate())
+                        self.active_cache.remove_entry(&node.obligation.as_cache_key())
                     {
                         self.done_cache.insert(predicate);
                     } else {
-                        self.done_cache.insert(node.obligation.as_predicate().clone());
+                        self.done_cache.insert(node.obligation.as_cache_key().clone());
                     }
                     if do_completed == DoCompleted::Yes {
                         // Extract the success stories.
@@ -635,7 +640,7 @@ impl<O: ForestObligation> ObligationForest<O> {
                     // We *intentionally* remove the node from the cache at this point. Otherwise
                     // tests must come up with a different type on every type error they
                     // check against.
-                    self.active_cache.remove(node.obligation.as_predicate());
+                    self.active_cache.remove(&node.obligation.as_cache_key());
                     self.insert_into_error_cache(index);
                     node_rewrites[index] = orig_nodes_len;
                     dead_nodes += 1;
diff --git a/src/librustc_data_structures/obligation_forest/tests.rs b/src/librustc_data_structures/obligation_forest/tests.rs
index e29335aab2808..01652465eea2c 100644
--- a/src/librustc_data_structures/obligation_forest/tests.rs
+++ b/src/librustc_data_structures/obligation_forest/tests.rs
@@ -4,9 +4,9 @@ use std::fmt;
 use std::marker::PhantomData;
 
 impl<'a> super::ForestObligation for &'a str {
-    type Predicate = &'a str;
+    type CacheKey = &'a str;
 
-    fn as_predicate(&self) -> &Self::Predicate {
+    fn as_cache_key(&self) -> Self::CacheKey {
         self
     }
 }