diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst
index e3a8edccc3a44..9e45320e591a1 100644
--- a/docs/source-pytorch/common/lightning_module.rst
+++ b/docs/source-pytorch/common/lightning_module.rst
@@ -286,6 +286,9 @@ Under the hood, Lightning does the following (pseudocode):

         # ...
         if validate_at_some_point:
+            # capture .training mode of every submodule
+            capture_training_mode()
+
             # disable grads + batchnorm + dropout
             torch.set_grad_enabled(False)
             model.eval()
@@ -295,9 +298,11 @@ Under the hood, Lightning does the following (pseudocode):
                 val_out = model.validation_step(val_batch, val_batch_idx)
             # ----------------- VAL LOOP ---------------

-            # enable grads + batchnorm + dropout
+            # enable grads
             torch.set_grad_enabled(True)
-            model.train()
+
+            # restore .training mode of every submodule
+            restore_training_mode()

 You can also run just the validation loop on your validation dataloaders by overriding :meth:`~lightning.pytorch.core.LightningModule.validation_step`
 and calling :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`.
@@ -368,7 +373,7 @@ The only difference is that the test loop is only called when :meth:`~lightning
     trainer = L.Trainer()
     trainer.fit(model=model, train_dataloaders=dataloader)

-    # automatically loads the best weights for you
+    # use the current weights
     trainer.test(model)

 There are two ways to call ``test()``:
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index b8624daac3fa3..4764aad5ebaa3 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -814,9 +814,10 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
         If you don't need to validate you don't need to implement this method.

         Note:
-            When the :meth:`validation_step` is called, the model has been put in eval mode
-            and PyTorch gradients have been disabled. At the end of validation,
-            the model goes back to training mode and gradients are enabled.
+            When the :meth:`validation_step` is called, the model has been put
+            in eval mode and PyTorch gradients have been disabled. At the end
+            of the validation epoch, the ``.training`` mode of every submodule
+            is restored to what it was before and gradients are enabled.

         """

@@ -881,9 +882,10 @@ def test_step(self, batch, batch_idx, dataloader_idx=0):
         If you don't need to test you don't need to implement this method.

         Note:
-            When the :meth:`test_step` is called, the model has been put in eval mode and
-            PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
-            to training mode and gradients are enabled.
+            When the :meth:`test_step` is called, the model has been put in
+            eval mode and PyTorch gradients have been disabled. At the end of
+            the test epoch, the ``.training`` mode of every submodule is
+            restored to what it was before and gradients are enabled.

         """

@@ -922,6 +924,12 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
             trainer = Trainer(accelerator="gpu", devices=2)
             predictions = trainer.predict(model, dm)

+        Note:
+            When the :meth:`predict_step` is called, the model has been put in
+            eval mode and PyTorch gradients have been disabled. At the end of
+            the predict epoch, the ``.training`` mode of every submodule is
+            restored to what it was before and gradients are enabled.
+
         """
         # For backwards compatibility
         batch = kwargs.get("batch", args[0])
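
Note on the pseudocode above: ``capture_training_mode()`` and ``restore_training_mode()`` are illustrative names used in the docs, not real Lightning API. A minimal sketch of how such capture/restore could look in plain PyTorch, assuming the per-submodule ``.training`` flags are recorded before ``model.eval()`` and reapplied afterwards:

    import torch
    from torch import nn


    def capture_training_mode(model: nn.Module) -> dict:
        # Record the .training flag of every submodule before evaluation.
        return {module: module.training for module in model.modules()}


    def restore_training_mode(modes: dict) -> None:
        # Reassign the flag directly: calling nn.Module.train() here would
        # recurse into children and overwrite modes not yet restored.
        for module, training in modes.items():
            module.training = training


    # Usage: a submodule that was deliberately kept in eval mode (e.g. a
    # frozen backbone) stays in eval mode after the validation loop ends.
    model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(), nn.BatchNorm1d(4))
    model[2].eval()  # pretend this submodule was intentionally frozen

    modes = capture_training_mode(model)
    torch.set_grad_enabled(False)
    model.eval()
    # ... run validation_step on every batch ...
    torch.set_grad_enabled(True)
    restore_training_mode(modes)

    assert model[0].training and model[1].training and not model[2].training

This is only a sketch of the behaviour the diff documents; Lightning's actual implementation may differ.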