From 4285536223705ade783502528914309ca142d959 Mon Sep 17 00:00:00 2001
From: odersky <odersky@gmail.com>
Date: Tue, 17 Dec 2024 17:56:33 +0100
Subject: [PATCH 1/2] Refactor handling of rechecked types

 - Always store new types on rechecking
 - Store them in a hashmap which is associated with the rechecker of the
   current compilation unit
 - After rechecking is done, the map is forgotten, unless keepTypes is true.
   Under keepTypes, then map is kept in an attachment of the unit's root tree.

Change in nomenclature:

    knownType --> nuType
    rememberType --> setNuType
    hasRememberedType --> hasNuType
---
 .../dotty/tools/dotc/cc/CheckCaptures.scala   |  46 +++++---
 compiler/src/dotty/tools/dotc/cc/Setup.scala  |  47 ++++----
 .../dotty/tools/dotc/transform/Recheck.scala  | 101 +++++++++---------
 3 files changed, 105 insertions(+), 89 deletions(-)

diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
index 23e7d8f8ecf8..830d9ad0a4d4 100644
--- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
+++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
@@ -223,6 +223,22 @@ object CheckCaptures:
       checkNotUniversal.traverse(tpe.widen)
   end checkNotUniversalInUnboxedResult
 
+  trait CheckerAPI:
+    /** Complete symbol info of a val or a def */
+    def completeDef(tree: ValOrDefDef, sym: Symbol)(using Context): Type
+
+    extension [T <: Tree](tree: T)
+
+      /** Set new type of the tree if none was installed yet. */
+      def setNuType(tpe: Type): Unit
+
+      /** The new type of the tree, or if none was installed, the original type */
+      def nuType(using Context): Type
+
+      /** Was a new type installed for this tree? */
+      def hasNuType: Boolean
+  end CheckerAPI
+
 class CheckCaptures extends Recheck, SymTransformer:
   thisPhase =>
 
@@ -243,7 +259,7 @@ class CheckCaptures extends Recheck, SymTransformer:
 
   val ccState1 = new CCState // Dotty problem: Rename to ccState ==> Crash in ExplicitOuter
 
-  class CaptureChecker(ictx: Context) extends Rechecker(ictx):
+  class CaptureChecker(ictx: Context) extends Rechecker(ictx), CheckerAPI:
 
     /** The current environment */
     private val rootEnv: Env = inContext(ictx):
@@ -261,10 +277,6 @@ class CheckCaptures extends Recheck, SymTransformer:
      */
     private val todoAtPostCheck = new mutable.ListBuffer[() => Unit]
 
-    override def keepType(tree: Tree) =
-      super.keepType(tree)
-      || tree.isInstanceOf[Try]  // type of `try` needs tp be checked for * escapes
-
     /** Instantiate capture set variables appearing contra-variantly to their
      *  upper approximation.
      */
@@ -286,8 +298,8 @@ class CheckCaptures extends Recheck, SymTransformer:
      */
     private def interpolateVarsIn(tpt: Tree)(using Context): Unit =
       if tpt.isInstanceOf[InferredTypeTree] then
-        interpolator().traverse(tpt.knownType)
-          .showing(i"solved vars in ${tpt.knownType}", capt)
+        interpolator().traverse(tpt.nuType)
+          .showing(i"solved vars in ${tpt.nuType}", capt)
       for msg <- ccState.approxWarnings do
         report.warning(msg, tpt.srcPos)
       ccState.approxWarnings.clear()
@@ -501,11 +513,11 @@ class CheckCaptures extends Recheck, SymTransformer:
             then ("\nThis is often caused by a local capability$where\nleaking as part of its result.", fn.srcPos)
             else if arg.span.exists then ("", arg.srcPos)
             else ("", fn.srcPos)
-          disallowRootCapabilitiesIn(arg.knownType, NoSymbol,
+          disallowRootCapabilitiesIn(arg.nuType, NoSymbol,
             i"Type variable $pname of $sym", "be instantiated to", addendum, pos)
 
           val param = fn.symbol.paramNamed(pname)
-          if param.isUseParam then markFree(arg.knownType.deepCaptureSet, pos)
+          if param.isUseParam then markFree(arg.nuType.deepCaptureSet, pos)
     end disallowCapInTypeArgs
 
     override def recheckIdent(tree: Ident, pt: Type)(using Context): Type =
@@ -769,8 +781,8 @@ class CheckCaptures extends Recheck, SymTransformer:
      */
     def checkContains(tree: TypeApply)(using Context): Unit = tree match
       case ContainsImpl(csArg, refArg) =>
-        val cs = csArg.knownType.captureSet
-        val ref = refArg.knownType
+        val cs = csArg.nuType.captureSet
+        val ref = refArg.nuType
         capt.println(i"check contains $cs , $ref")
         ref match
           case ref: CaptureRef if ref.isTracked =>
@@ -852,7 +864,7 @@ class CheckCaptures extends Recheck, SymTransformer:
               case _ =>
                 (sym, "")
             disallowRootCapabilitiesIn(
-              tree.tpt.knownType, carrier, i"Mutable $sym", "have type", addendum, sym.srcPos)
+              tree.tpt.nuType, carrier, i"Mutable $sym", "have type", addendum, sym.srcPos)
           checkInferredResult(super.recheckValDef(tree, sym), tree)
       finally
         if !sym.is(Param) then
@@ -1533,7 +1545,7 @@ class CheckCaptures extends Recheck, SymTransformer:
     private val setup: SetupAPI = thisPhase.prev.asInstanceOf[Setup]
 
     override def checkUnit(unit: CompilationUnit)(using Context): Unit =
-      setup.setupUnit(unit.tpdTree, completeDef)
+      setup.setupUnit(unit.tpdTree, this)
       collectCapturedMutVars.traverse(unit.tpdTree)
 
       if ctx.settings.YccPrintSetup.value then
@@ -1676,7 +1688,7 @@ class CheckCaptures extends Recheck, SymTransformer:
               traverseChildren(tp)
 
       if tree.isInstanceOf[InferredTypeTree] then
-        checker.traverse(tree.knownType)
+        checker.traverse(tree.nuType)
     end healTypeParam
 
     /** Under the unsealed policy: Arrays are like vars, check that their element types
@@ -1716,10 +1728,10 @@ class CheckCaptures extends Recheck, SymTransformer:
             check(tree)
         def check(tree: Tree)(using Context) = tree match
           case TypeApply(fun, args) =>
-            fun.knownType.widen match
+            fun.nuType.widen match
               case tl: PolyType =>
                 val normArgs = args.lazyZip(tl.paramInfos).map: (arg, bounds) =>
-                  arg.withType(arg.knownType.forceBoxStatus(
+                  arg.withType(arg.nuType.forceBoxStatus(
                     bounds.hi.isBoxedCapturing | bounds.lo.isBoxedCapturing))
                 checkBounds(normArgs, tl)
                 args.lazyZip(tl.paramNames).foreach(healTypeParam(_, _, fun.symbol))
@@ -1739,7 +1751,7 @@ class CheckCaptures extends Recheck, SymTransformer:
           def traverse(t: Tree)(using Context) = t match
             case tree: InferredTypeTree =>
             case tree: New =>
-            case tree: TypeTree => checkAppliedTypesIn(tree.withKnownType)
+            case tree: TypeTree => checkAppliedTypesIn(tree.withType(tree.nuType))
             case _ => traverseChildren(t)
         checkApplied.traverse(unit)
     end postCheck
diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala
index c5c362dbe8dc..3ce68792088a 100644
--- a/compiler/src/dotty/tools/dotc/cc/Setup.scala
+++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala
@@ -19,6 +19,7 @@ import printing.{Printer, Texts}, Texts.{Text, Str}
 import collection.mutable
 import CCState.*
 import dotty.tools.dotc.util.NoSourcePosition
+import CheckCaptures.CheckerAPI
 
 /** Operations accessed from CheckCaptures */
 trait SetupAPI:
@@ -28,10 +29,9 @@ trait SetupAPI:
 
   /** Setup procedure to run for each compilation unit
    *   @param tree       the typed tree of the unit to check
-   *   @param recheckDef the recheck method to run on completion of symbols with
-   *                     inferred (result-) types
+   *   @param checker    the capture checker which will run subsequently.
    */
-  def setupUnit(tree: Tree, recheckDef: DefRecheck)(using Context): Unit
+  def setupUnit(tree: Tree, checker: CheckerAPI)(using Context): Unit
 
   /** Symbol is a term member of a class that was not capture checked
    *  The info of these symbols is made fluid.
@@ -378,15 +378,6 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
     tp2
   end transformExplicitType
 
-  /** Transform type of tree, and remember the transformed type as the type the tree */
-  private def transformTT(tree: TypeTree, boxed: Boolean)(using Context): Unit =
-    if !tree.hasRememberedType then
-      val transformed =
-        if tree.isInferred
-        then transformInferredType(tree.tpe)
-        else transformExplicitType(tree.tpe, tptToCheck = tree)
-      tree.rememberType(if boxed then box(transformed) else transformed)
-
   /** Substitute parameter symbols in `from` to paramRefs in corresponding
    *  method or poly types `to`. We use a single BiTypeMap to do everything.
    *  @param from  a list of lists of type or term parameter symbols of a curried method
@@ -436,7 +427,17 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
     atPhase(thisPhase.next)(sym.info)
 
   /** A traverser that adds knownTypes and updates symbol infos */
-  def setupTraverser(recheckDef: DefRecheck) = new TreeTraverserWithPreciseImportContexts:
+  def setupTraverser(checker: CheckerAPI) = new TreeTraverserWithPreciseImportContexts:
+    import checker.*
+
+    /** Transform type of tree, and remember the transformed type as the type the tree */
+    private def transformTT(tree: TypeTree, boxed: Boolean)(using Context): Unit =
+      if !tree.hasNuType then
+        val transformed =
+          if tree.isInferred
+          then transformInferredType(tree.tpe)
+          else transformExplicitType(tree.tpe, tptToCheck = tree)
+        tree.setNuType(if boxed then box(transformed) else transformed)
 
     /** Transform the type of a val or var or the result type of a def */
     def transformResultType(tpt: TypeTree, sym: Symbol)(using Context): Unit =
@@ -464,7 +465,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
             traverse(parent)
           case _ =>
             traverseChildren(tp)
-      addDescription.traverse(tpt.knownType)
+      addDescription.traverse(tpt.nuType)
     end transformResultType
 
     def traverse(tree: Tree)(using Context): Unit =
@@ -504,7 +505,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
 
         case tree @ SeqLiteral(elems, tpt: TypeTree) =>
           traverse(elems)
-          tpt.rememberType(box(transformInferredType(tpt.tpe)))
+          tpt.setNuType(box(transformInferredType(tpt.tpe)))
 
         case tree: Block =>
           inNestedLevel(traverseChildren(tree))
@@ -537,22 +538,22 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
         // with special treatment for constructors.
         def localReturnType =
           if sym.isConstructor then constrReturnType(sym.info, sym.paramSymss)
-          else tree.tpt.knownType
+          else tree.tpt.nuType
 
         // A test whether parameter signature might change. This returns true if one of
-        // the parameters has a remembered type. The idea here is that we store a remembered
+        // the parameters has a new type installee. The idea here is that we store a new
         // type only if the transformed type is different from the original.
         def paramSignatureChanges = tree.match
           case tree: DefDef =>
             tree.paramss.nestedExists:
-              case param: ValDef => param.tpt.hasRememberedType
-              case param: TypeDef => param.rhs.hasRememberedType
+              case param: ValDef => param.tpt.hasNuType
+              case param: TypeDef => param.rhs.hasNuType
           case _ => false
 
         // A symbol's signature changes if some of its parameter types or its result type
         // have a new type installed here (meaning hasRememberedType is true)
         def signatureChanges =
-          tree.tpt.hasRememberedType && !sym.isConstructor || paramSignatureChanges
+          tree.tpt.hasNuType && !sym.isConstructor || paramSignatureChanges
 
         // Replace an existing symbol info with inferred types where capture sets of
         // TypeParamRefs and TermParamRefs are put in correspondence by BiTypeMaps with the
@@ -616,7 +617,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
                   capt.println(i"forcing $sym, printing = ${ctx.mode.is(Mode.Printing)}")
                   //if ctx.mode.is(Mode.Printing) then new Error().printStackTrace()
                   denot.info = newInfo
-                  recheckDef(tree, sym)
+                  completeDef(tree, sym)
             updateInfo(sym, updatedInfo)
 
       case tree: Bind =>
@@ -833,8 +834,8 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
   /** Run setup on a compilation unit with given `tree`.
    *  @param recheckDef   the function to run for completing a val or def
    */
-  def setupUnit(tree: Tree, recheckDef: DefRecheck)(using Context): Unit =
-    setupTraverser(recheckDef).traverse(tree)(using ctx.withPhase(thisPhase))
+  def setupUnit(tree: Tree, checker: CheckerAPI)(using Context): Unit =
+    setupTraverser(checker).traverse(tree)(using ctx.withPhase(thisPhase))
 
   // ------ Checks to run after main capture checking --------------------------
 
diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala
index d3173cef252d..172ae337d6e6 100644
--- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala
+++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala
@@ -28,18 +28,29 @@ import dotty.tools.dotc.cc.boxed
 object Recheck:
   import tpd.*
 
-  /** Attachment key for rechecked types of TypeTrees */
-  val RecheckedType = Property.Key[Type]
-
-  val addRecheckedTypes = new TreeMap:
-    override def transform(tree: Tree)(using Context): Tree =
-      try
-        val tree1 = super.transform(tree)
-        tree.getAttachment(RecheckedType) match
-          case Some(tpe) => tree1.withType(tpe)
-          case None => tree1
-      catch
-        case _:TypeError => tree
+  /** Attachment key for a toplevel tree of a unit that contains a map
+   *  from nodes in that tree to their rechecked "new" types
+   */
+  val RecheckedTypes = Property.Key[util.EqHashMap[Tree, Type]]
+
+  /** If tree carries a RecheckedTypes attachment, use the associated `nuTypes`
+   *  map to produce a new tree that contains at each node the type in the
+   *  map as the node's .tpe field
+   */
+  def addRecheckedTypes(tree: Tree)(using Context): Tree =
+    tree.getAttachment(RecheckedTypes) match
+      case Some(nuTypes) =>
+        val withNuTypes = new TreeMap:
+          override def transform(tree: Tree)(using Context): Tree =
+            try
+              val tree1 = super.transform(tree)
+              val tpe = nuTypes.lookup(tree)
+              if tpe != null then tree1.withType(tpe) else tree1
+            catch
+              case _: TypeError => tree
+        withNuTypes.transform(tree)
+      case None =>
+        tree
 
   extension (sym: Symbol)(using Context)
 
@@ -61,30 +72,6 @@ object Recheck:
       val symd = sym.denot
       symd.validFor.firstPhaseId == phase.id + 1 && (sym.originDenotation ne symd)
 
-  extension [T <: Tree](tree: T)
-
-    /** Remember `tpe` as the type of `tree`, which might be different from the
-     *  type stored in the tree itself, unless a type was already remembered for `tree`.
-     */
-    def rememberType(tpe: Type)(using Context): Unit =
-      if !tree.hasAttachment(RecheckedType) then rememberTypeAlways(tpe)
-
-    /** Remember `tpe` as the type of `tree`, which might be different from the
-     *  type stored in the tree itself
-     */
-    def rememberTypeAlways(tpe: Type)(using Context): Unit =
-      if tpe ne tree.knownType then tree.putAttachment(RecheckedType, tpe)
-
-    /** The remembered type of the tree, or if none was installed, the original type */
-    def knownType: Type =
-      tree.attachmentOrElse(RecheckedType, tree.tpe)
-
-    def hasRememberedType: Boolean = tree.hasAttachment(RecheckedType)
-
-    def withKnownType(using Context): T = tree.getAttachment(RecheckedType) match
-      case Some(tpe) => tree.withType(tpe).asInstanceOf[T]
-      case None => tree
-
   /** Map ExprType => T to () ?=> T (and analogously for pure versions).
    *  Even though this phase runs after ElimByName, ExprTypes can still occur
    *  as by-name arguments of applied types. See note in doc comment for
@@ -172,17 +159,32 @@ abstract class Recheck extends Phase, SymTransformer:
   class Rechecker(@constructorOnly ictx: Context):
     private val ta = ictx.typeAssigner
 
-    /** If true, remember types of all tree nodes in attachments so that they
-     *  can be retrieved with `knownType`
-     */
-    private val keepAllTypes = inContext(ictx) {
-      ictx.settings.Xprint.value.containsPhase(thisPhase)
-    }
+    private val nuTypes = util.EqHashMap[Tree, Type]()
+
+    extension [T <: Tree](tree: T)
+
+      /** Set new type of the tree if none was installed yet and the new type is different
+       *  from the current type.
+       */
+      def setNuType(tpe: Type): Unit =
+        if nuTypes.lookup(tree) == null && (tpe ne tree.tpe) then nuTypes(tree) = tpe
+
+      /** The new type of the tree, or if none was installed, the original type */
+      def nuType(using Context): Type =
+        val ntpe = nuTypes.lookup(tree)
+        if ntpe != null then ntpe else tree.tpe
+
+      /** Was a new type installed for this tree? */
+      def hasNuType: Boolean =
+        nuTypes.lookup(tree) != null
+    end extension
 
-    /** Should type of `tree` be kept in an attachment so that it can be retrieved with
-     *  `knownType`? By default true only is `keepAllTypes` hold, but can be overridden.
+    /** If true, remember the new types of nodes in this compilation unit
+     *  as an attachment in the unit's tpdTree node. By default, this is
+     *  enabled when -Xprint:cc is set. Can be overridden.
      */
-    def keepType(tree: Tree): Boolean = keepAllTypes
+    def keepNuTypes(using Context): Boolean =
+      ctx.settings.Xprint.value.containsPhase(thisPhase)
 
     /** A map from NamedTypes to the denotations they had before this phase.
      *  Needed so that we can `reset` them after this phase.
@@ -343,7 +345,6 @@ abstract class Recheck extends Phase, SymTransformer:
 
     def recheckTypeApply(tree: TypeApply, pt: Type)(using Context): Type =
       val funtpe = recheck(tree.fun)
-      tree.fun.rememberType(funtpe) // remember type to support later bounds checks
       funtpe.widen match
         case fntpe: PolyType =>
           assert(fntpe.paramInfos.hasSameLengthAs(tree.args))
@@ -459,7 +460,7 @@ abstract class Recheck extends Phase, SymTransformer:
       seqLitType(tree, TypeComparer.lub(declaredElemType :: elemTypes))
 
     def recheckTypeTree(tree: TypeTree)(using Context): Type =
-      tree.knownType  // allows to install new types at Setup
+      tree.nuType  // allows to install new types at Setup
 
     def recheckAnnotated(tree: Annotated)(using Context): Type =
       tree.tpe match
@@ -558,7 +559,7 @@ abstract class Recheck extends Phase, SymTransformer:
      */
     def recheckFinish(tpe: Type, tree: Tree, pt: Type)(using Context): Type =
       val tpe1 = checkConforms(tpe, pt, tree)
-      if keepType(tree) then tree.rememberType(tpe1)
+      tree.setNuType(tpe1)
       tpe1
 
     def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type =
@@ -617,6 +618,7 @@ abstract class Recheck extends Phase, SymTransformer:
 
     def checkUnit(unit: CompilationUnit)(using Context): Unit =
       recheck(unit.tpdTree)
+      if keepNuTypes then unit.tpdTree.putAttachment(RecheckedTypes, nuTypes)
 
   end Rechecker
 
@@ -624,7 +626,8 @@ abstract class Recheck extends Phase, SymTransformer:
   override def show(tree: untpd.Tree)(using Context): String =
     atPhase(thisPhase):
       withMode(Mode.Printing):
-        super.show(addRecheckedTypes.transform(tree.asInstanceOf[tpd.Tree]))
+        super.show:
+          addRecheckedTypes(tree.asInstanceOf[tpd.Tree])
 end Recheck
 
 /** A class that can be used to test basic rechecking without any customaization */

From 5ac4d735b75a3b3efe42542ea63cdcb396a2f5ab Mon Sep 17 00:00:00 2001
From: odersky <odersky@gmail.com>
Date: Mon, 23 Dec 2024 18:31:04 +0100
Subject: [PATCH 2/2] Drop unused type alias

---
 compiler/src/dotty/tools/dotc/cc/Setup.scala | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala
index 3ce68792088a..ebe128d7776c 100644
--- a/compiler/src/dotty/tools/dotc/cc/Setup.scala
+++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala
@@ -24,9 +24,6 @@ import CheckCaptures.CheckerAPI
 /** Operations accessed from CheckCaptures */
 trait SetupAPI:
 
-  /** The operation to recheck a ValDef or DefDef */
-  type DefRecheck = (tpd.ValOrDefDef, Symbol) => Context ?=> Type
-
   /** Setup procedure to run for each compilation unit
    *   @param tree       the typed tree of the unit to check
    *   @param checker    the capture checker which will run subsequently.