From f3264e412e5348e148c9e6a7773367665cf46f95 Mon Sep 17 00:00:00 2001
From: Jonas Schievink <jonasschievink@gmail.com>
Date: Sun, 31 May 2020 03:36:46 +0200
Subject: [PATCH] Validate that all accessed locals are initialized

---
 .../dataflow/impls/init_locals.rs             | 34 +++++++--
 src/librustc_mir/transform/generator.rs       |  2 +-
 src/librustc_mir/transform/validate.rs        | 73 ++++++++++++++++++-
 3 files changed, 99 insertions(+), 10 deletions(-)

diff --git a/src/librustc_mir/dataflow/impls/init_locals.rs b/src/librustc_mir/dataflow/impls/init_locals.rs
index 726330b1f035e..13459f22a054e 100644
--- a/src/librustc_mir/dataflow/impls/init_locals.rs
+++ b/src/librustc_mir/dataflow/impls/init_locals.rs
@@ -8,7 +8,24 @@ use rustc_index::bit_set::BitSet;
 use rustc_middle::mir::visit::{PlaceContext, Visitor};
 use rustc_middle::mir::{self, BasicBlock, Local, Location};
 
-pub struct MaybeInitializedLocals;
+pub struct MaybeInitializedLocals {
+    deinit_on_move: bool,
+}
+
+impl MaybeInitializedLocals {
+    /// Creates the default analysis: a move out of a local marks it as no longer initialized.
+    pub fn new() -> Self {
+        Self { deinit_on_move: true }
+    }
+
+    /// Creates a `MaybeInitializedLocals` analysis in which moving out of a local does not mark
+    /// it as uninitialized.
+    ///
+    /// `StorageDead` still marks the local as no longer initialized.
+    pub fn no_deinit_on_move() -> Self {
+        Self { deinit_on_move: false }
+    }
+}
 
 impl BottomValue for MaybeInitializedLocals {
     /// bottom = uninit
@@ -42,7 +59,8 @@ impl dataflow::GenKillAnalysis<'tcx> for MaybeInitializedLocals {
         statement: &mir::Statement<'tcx>,
         loc: Location,
     ) {
-        TransferFunction { trans }.visit_statement(statement, loc)
+        TransferFunction { trans, deinit_on_move: self.deinit_on_move }
+            .visit_statement(statement, loc)
     }
 
     fn terminator_effect(
@@ -51,7 +69,8 @@ impl dataflow::GenKillAnalysis<'tcx> for MaybeInitializedLocals {
         terminator: &mir::Terminator<'tcx>,
         loc: Location,
     ) {
-        TransferFunction { trans }.visit_terminator(terminator, loc)
+        TransferFunction { trans, deinit_on_move: self.deinit_on_move }
+            .visit_terminator(terminator, loc)
     }
 
     fn call_return_effect(
@@ -77,6 +96,7 @@ impl dataflow::GenKillAnalysis<'tcx> for MaybeInitializedLocals {
 }
 
 struct TransferFunction<'a, T> {
+    deinit_on_move: bool,
     trans: &'a mut T,
 }
 
@@ -95,8 +115,12 @@ where
 
             // If the local is moved out of, or if it gets marked `StorageDead`, consider it no
             // longer initialized.
-            PlaceContext::NonUse(NonUseContext::StorageDead)
-            | PlaceContext::NonMutatingUse(NonMutatingUseContext::Move) => self.trans.kill(local),
+            PlaceContext::NonUse(NonUseContext::StorageDead) => self.trans.kill(local),
+            PlaceContext::NonMutatingUse(NonMutatingUseContext::Move) => {
+                if self.deinit_on_move {
+                    self.trans.kill(local)
+                }
+            }
 
             // All other uses do not affect this analysis.
             PlaceContext::NonUse(
diff --git a/src/librustc_mir/transform/generator.rs b/src/librustc_mir/transform/generator.rs
index 461b13c4f6382..c1a54f55b74f8 100644
--- a/src/librustc_mir/transform/generator.rs
+++ b/src/librustc_mir/transform/generator.rs
@@ -452,7 +452,7 @@ fn locals_live_across_suspend_points(
         .iterate_to_fixpoint()
         .into_results_cursor(body);
 
-    let mut init = MaybeInitializedLocals
+    let mut init = MaybeInitializedLocals::new()
         .into_engine(tcx, body, def_id)
         .iterate_to_fixpoint()
         .into_results_cursor(body);
diff --git a/src/librustc_mir/transform/validate.rs b/src/librustc_mir/transform/validate.rs
index a25edd131baa1..3c900d9cd5702 100644
--- a/src/librustc_mir/transform/validate.rs
+++ b/src/librustc_mir/transform/validate.rs
@@ -1,12 +1,16 @@
 //! Validates the MIR to ensure that invariants are upheld.
 
 use super::{MirPass, MirSource};
-use rustc_middle::mir::visit::Visitor;
+use crate::dataflow::{impls::MaybeInitializedLocals, Analysis, ResultsCursor};
+use rustc_index::bit_set::BitSet;
+use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor};
+use rustc_middle::ty;
 use rustc_middle::{
-    mir::{Body, Location, Operand, Rvalue, Statement, StatementKind},
+    mir::{traversal, Body, Local, Location, Operand, Rvalue, Statement, StatementKind},
     ty::{ParamEnv, TyCtxt},
 };
 use rustc_span::{def_id::DefId, Span, DUMMY_SP};
+use ty::Ty;
 
 pub struct Validator {
     /// Describes at which point in the pipeline this validation is happening.
@@ -17,16 +21,32 @@ impl<'tcx> MirPass<'tcx> for Validator {
     fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) {
         let def_id = source.def_id();
         let param_env = tcx.param_env(def_id);
-        TypeChecker { when: &self.when, def_id, body, tcx, param_env }.visit_body(body);
+
+        // Do not treat moves as deinitializing locals: some MIR passes emit MIR that still uses a
+        // local after it was moved from, which would otherwise be flagged as uninitialized data.
+        let init = MaybeInitializedLocals::no_deinit_on_move()
+            .into_engine(tcx, body, def_id)
+            .iterate_to_fixpoint()
+            .into_results_cursor(body);
+
+        let mut checker = TypeChecker { when: &self.when, def_id, body, tcx, param_env, init };
+
+        // Only visit reachable blocks. Unreachable code may access uninitialized locals.
+        for (block, data) in traversal::preorder(body) {
+            checker.visit_basic_block_data(block, data);
+        }
     }
 }
 
 struct TypeChecker<'a, 'tcx> {
     when: &'a str,
+
     def_id: DefId,
     body: &'a Body<'tcx>,
     tcx: TyCtxt<'tcx>,
     param_env: ParamEnv<'tcx>,
+
+    init: ResultsCursor<'a, 'tcx, MaybeInitializedLocals>,
 }
 
 impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
@@ -46,7 +66,7 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
         if let Operand::Copy(place) = operand {
             let ty = place.ty(&self.body.local_decls, self.tcx).ty;
-
+            // `Operand::Copy` may only read from places whose type is `Copy`; keep this enforced.
             if !ty.is_copy_modulo_regions(self.tcx, self.param_env, DUMMY_SP) {
                 self.fail(
                     DUMMY_SP,
                     format!("`Operand::Copy` with non-`Copy` type {} at {:?}", ty, location),
@@ -76,5 +96,50 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
                 _ => {}
             }
         }
+
+        // Every local used by a statement must be initialized before the statement executes.
+        self.init.seek_before_primary_effect(location);
+        UsedLocalsAreInitialized {
+            checker: self,
+            init: self.init.get(),
+            span: statement.source_info.span,
+        }
+        .visit_statement(statement, location);
+    }
+}
+
+struct UsedLocalsAreInitialized<'a, 'tcx> {
+    checker: &'a TypeChecker<'a, 'tcx>,
+    init: &'a BitSet<Local>,
+    span: Span,
+}
+
+impl Visitor<'tcx> for UsedLocalsAreInitialized<'a, 'tcx> {
+    fn visit_local(&mut self, local: &Local, context: PlaceContext, location: Location) {
+        if context.is_use() && !context.is_place_assignment() && !self.init.contains(*local) {
+            if context == PlaceContext::MutatingUse(MutatingUseContext::Projection) {
+                // A projection such as `_1.b` is also reported for assignment destinations, where
+                // the base local `_1` does not have to be initialized yet, so do not flag it.
+                return;
+            }
+
+            if is_zst(
+                self.checker.tcx,
+                self.checker.def_id,
+                self.checker.body.local_decls[*local].ty,
+            ) {
+                // Locals of a known zero-sized type never need to be initialized; skip them.
+                return;
+            }
+
+            self.checker.fail(
+                self.span,
+                format!("use of uninitialized local {:?} at {:?}", local, location),
+            );
+        }
     }
 }
+
+fn is_zst<'tcx>(tcx: TyCtxt<'tcx>, did: DefId, ty: Ty<'tcx>) -> bool {
+    tcx.layout_of(tcx.param_env(did).and(ty)).map_or(false, |layout| layout.is_zst())
+}