Source code for pytorch_lightning.trainer.trainer

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# THIS FILE MUST READ EASILY, FOR UNDERSTANDING AND DEBUGGING PURPOSES.
# DO NOT OBSCURE THE TRAINING LOOP
# THIS IS A HARD REQUIREMENT TO CONTRIBUTING TO LIGHTNING
# WE FAVOR READABILITY OVER ENGINEERING-CONSTRUCTS BY DESIGN
# DO NOT REMOVE THIS NOTICE
# - WILLIAM FALCON
"""Trainer to automate the training."""
import logging
import math
import os
import warnings
from datetime import timedelta
from typing import Any, Dict, Iterable, List, Optional, Union
from weakref import proxy

import torch
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
from lightning_fabric.utilities.cloud_io import get_filesystem
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBar
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.loggers import Logger
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loggers.utilities import _log_hyperparams
from pytorch_lightning.loops import _PredictionLoop, _TrainingEpochLoop
from pytorch_lightning.loops.evaluation_loop import _EvaluationLoop
from pytorch_lightning.loops.fit_loop import _FitLoop
from pytorch_lightning.loops.utilities import _parse_loop_limits, _reset_progress
from pytorch_lightning.plugins import PLUGIN_INPUT, PrecisionPlugin
from pytorch_lightning.profilers import Profiler
from pytorch_lightning.strategies import ParallelStrategy, Strategy
from pytorch_lightning.trainer import call, setup
from pytorch_lightning.trainer.configuration_validator import _verify_loop_configurations
from pytorch_lightning.trainer.connectors.accelerator_connector import (
    _AcceleratorConnector,
    _LITERAL_WARN,
    _PRECISION_INPUT,
    _PRECISION_INPUT_STR,
)
from pytorch_lightning.trainer.connectors.callback_connector import _CallbackConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector
from pytorch_lightning.trainer.connectors.data_connector import _DataConnector
from pytorch_lightning.trainer.connectors.logger_connector import _LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _PBAR_DICT, _ResultCollection
from pytorch_lightning.trainer.connectors.signal_connector import _SignalConnector
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
from pytorch_lightning.utilities import GradClipAlgorithmType, parsing
from pytorch_lightning.utilities.argparse import _defaults_from_env_vars
from pytorch_lightning.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.seed import isolate_rng
from pytorch_lightning.utilities.types import (
    _EVALUATE_OUTPUT,
    _PREDICT_OUTPUT,
    EVAL_DATALOADERS,
    LRSchedulerConfig,
    TRAIN_DATALOADERS,
)

log = logging.getLogger(__name__)
# warnings to ignore in trainer
warnings.filterwarnings(
    "ignore", message="torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead"
)


[docs] class Trainer: @_defaults_from_env_vars def __init__( self, *, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", devices: Union[List[int], str, int] = "auto", num_nodes: int = 1, precision: _PRECISION_INPUT = "32-true", logger: Optional[Union[Logger, Iterable[Logger], bool]] = None, callbacks: Optional[Union[List[Callback], Callback]] = None, fast_dev_run: Union[int, bool] = False, max_epochs: Optional[int] = None, min_epochs: Optional[int] = None, max_steps: int = -1, min_steps: Optional[int] = None, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, limit_train_batches: Optional[Union[int, float]] = None, limit_val_batches: Optional[Union[int, float]] = None, limit_test_batches: Optional[Union[int, float]] = None, limit_predict_batches: Optional[Union[int, float]] = None, overfit_batches: Union[int, float] = 0.0, val_check_interval: Optional[Union[int, float]] = None, check_val_every_n_epoch: Optional[int] = 1, num_sanity_val_steps: Optional[int] = None, log_every_n_steps: Optional[int] = None, enable_checkpointing: Optional[bool] = None, enable_progress_bar: Optional[bool] = None, enable_model_summary: Optional[bool] = None, accumulate_grad_batches: int = 1, gradient_clip_val: Optional[Union[int, float]] = None, gradient_clip_algorithm: Optional[str] = None, deterministic: Optional[Union[bool, _LITERAL_WARN]] = None, benchmark: Optional[bool] = None, inference_mode: bool = True, use_distributed_sampler: bool = True, profiler: Optional[Union[Profiler, str]] = None, detect_anomaly: bool = False, barebones: bool = False, plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None, sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, default_root_dir: Optional[_PATH] = None, ) -> None: r"""Customize every aspect of training via flags. Args: accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto") as well as custom accelerator instances. strategy: Supports different training strategies with aliases as well custom strategies. Default: ``"auto"``. devices: The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for automatic selection based on the chosen accelerator. Default: ``"auto"``. num_nodes: Number of GPU nodes for distributed training. Default: ``1``. precision: Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). Can be used on CPU, GPU, TPUs, HPUs or IPUs. Default: ``'32-true'``. logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``. ``False`` will disable logging. If multiple loggers are provided, local files (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of he first logger. Default: ``True``. callbacks: Add a callback or list of callbacks. Default: ``None``. fast_dev_run: Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) of train, val and test to find any bugs (ie: a sort of unit test). Default: ``False``. max_epochs: Stop training once this number of epochs is reached. Disabled by default (None). If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``. To enable infinite training, set ``max_epochs = -1``. min_epochs: Force training for at least these many epochs. Disabled by default (None). max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1`` and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set ``max_epochs`` to ``-1``. min_steps: Force training for at least these number of steps. Disabled by default (``None``). max_time: Stop training after this amount of time has passed. Disabled by default (``None``). The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a :class:`datetime.timedelta`, or a dictionary with keys that will be passed to :class:`datetime.timedelta`. limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches). Default: ``1.0``. limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches). Default: ``1.0``. limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches). Default: ``1.0``. limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches). Default: ``1.0``. overfit_batches: Overfit a fraction of training/validation data (float) or a set number of batches (int). Default: ``0.0``. val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number of training batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches across epochs or during iteration-based training. Default: ``1.0``. check_val_every_n_epoch: Perform a validation loop every after every `N` training epochs. If ``None``, validation will be done solely based on the number of training batches, requiring ``val_check_interval`` to be an integer value. Default: ``1``. num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine. Set it to `-1` to run all batches in all validation dataloaders. Default: ``2``. log_every_n_steps: How often to log within steps. Default: ``50``. enable_checkpointing: If ``True``, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``. enable_progress_bar: Whether to enable to progress bar by default. Default: ``True``. enable_model_summary: Whether to enable model summarization by default. Default: ``True``. accumulate_grad_batches: Accumulates gradients over k batches before stepping the optimizer. Default: 1. gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before. Default: ``None``. gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"`` to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will be set to ``"norm"``. deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms. Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations that don't support deterministic mode (requires PyTorch 1.11+). If not set, defaults to ``False``. Default: ``None``. benchmark: The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to. The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used (``False`` if not manually set). If :paramref:`~pytorch_lightning.trainer.trainer.Trainer.deterministic` is set to ``True``, this will default to ``False``. Override to manually set a different value. Default: ``None``. inference_mode: Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during evaluation (``validate``/``test``/``predict``). use_distributed_sampler: Whether to wrap the DataLoader's sampler with :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed sampler was already added, Lightning will not replace the existing one. For iterable-style datasets, we don't do this automatically. profiler: To profile individual steps during training and assist in identifying bottlenecks. Default: ``None``. detect_anomaly: Enable anomaly detection for the autograd engine. Default: ``False``. barebones: Whether to run in "barebones mode", where all features that may impact raw speed are disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training runs. The following features are deactivated: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.enable_checkpointing`, :paramref:`~pytorch_lightning.trainer.trainer.Trainer.logger`, :paramref:`~pytorch_lightning.trainer.trainer.Trainer.enable_progress_bar`, :paramref:`~pytorch_lightning.trainer.trainer.Trainer.log_every_n_steps`, :paramref:`~pytorch_lightning.trainer.trainer.Trainer.enable_model_summary`, :paramref:`~pytorch_lightning.trainer.trainer.Trainer.num_sanity_val_steps`, :paramref:`~pytorch_lightning.trainer.trainer.Trainer.fast_dev_run`, :paramref:`~pytorch_lightning.trainer.trainer.Trainer.detect_anomaly`, :paramref:`~pytorch_lightning.trainer.trainer.Trainer.profiler`, :meth:`~pytorch_lightning.core.module.LightningModule.log`, :meth:`~pytorch_lightning.core.module.LightningModule.log_dict`. plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins. Default: ``None``. sync_batchnorm: Synchronize batch norm layers between process groups/whole world. Default: ``False``. reload_dataloaders_every_n_epochs: Set to a non-negative integer to reload dataloaders every n epochs. Default: ``0``. default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed. Default: ``os.getcwd()``. Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' Raises: TypeError: If ``gradient_clip_val`` is not an int or float. MisconfigurationException: If ``gradient_clip_algorithm`` is invalid. If ``track_grad_norm`` is not a positive number or inf. """ super().__init__() log.debug(f"{self.__class__.__name__}: Initializing trainer with parameters: {locals()}") self.state = TrainerState() if default_root_dir is not None: default_root_dir = os.fspath(default_root_dir) self.barebones = barebones if barebones: # opt-outs if enable_checkpointing: raise ValueError( f"`Trainer(barebones=True, enable_checkpointing={enable_checkpointing!r})` was passed." " Checkpointing can impact raw speed so it is disabled in barebones mode." ) enable_checkpointing = False if logger is not None and logger is not False: raise ValueError( f"`Trainer(barebones=True, logger={logger!r})` was passed." " Logging can impact raw speed so it is disabled in barebones mode." ) logger = False if enable_progress_bar: raise ValueError( f"`Trainer(barebones=True, enable_progress_bar={enable_progress_bar!r})` was passed." " The progress bar can impact raw speed so it is disabled in barebones mode." ) enable_progress_bar = False if log_every_n_steps is not None and log_every_n_steps != 0: raise ValueError( f"`Trainer(barebones=True, log_every_n_steps={log_every_n_steps!r})` was passed." " Logging can impact raw speed so it is disabled in barebones mode." ) log_every_n_steps = 0 if enable_model_summary: raise ValueError( f"`Trainer(barebones=True, enable_model_summary={enable_model_summary!r})` was passed." " Model summary can impact raw speed so it is disabled in barebones mode." ) enable_model_summary = False if num_sanity_val_steps is not None and num_sanity_val_steps != 0: raise ValueError( f"`Trainer(barebones=True, num_sanity_val_steps={num_sanity_val_steps!r})` was passed." " Sanity checking can impact raw speed so it is disabled in barebones mode." ) num_sanity_val_steps = 0 # opt-ins if fast_dev_run is not False and fast_dev_run != 0: raise ValueError( f"`Trainer(barebones=True, fast_dev_run={fast_dev_run!r})` was passed." " Development run is not meant for raw speed evaluation so it is disabled in barebones mode." ) if detect_anomaly: raise ValueError( f"`Trainer(barebones=True, detect_anomaly={detect_anomaly!r})` was passed." " Anomaly detection can impact raw speed so it is disabled in barebones mode." ) if profiler is not None: raise ValueError( f"`Trainer(barebones=True, profiler={profiler!r})` was passed." " Profiling can impact raw speed so it is disabled in barebones mode." ) deactivated = ( " - Checkpointing: `Trainer(enable_checkpointing=True)`", " - Progress bar: `Trainer(enable_progress_bar=True)`", " - Model summary: `Trainer(enable_model_summary=True)`", " - Logging: `Trainer(logger=True)`, `Trainer(log_every_n_steps>0)`," " `LightningModule.log(...)`, `LightningModule.log_dict(...)`", " - Sanity checking: `Trainer(num_sanity_val_steps>0)`", " - Development run: `Trainer(fast_dev_run=True)`", " - Anomaly detection: `Trainer(detect_anomaly=True)`", " - Profiling: `Trainer(profiler=...)`", ) rank_zero_info( "You are running in `Trainer(barebones=True)` mode. All features that may impact raw speed have been" " disabled to facilitate analyzing the Trainer overhead. Specifically, the following features are" f" deactivated:{os.linesep}{os.linesep.join(deactivated)}" ) else: # set the opt-out defaults if enable_checkpointing is None: enable_checkpointing = True if logger is None: logger = True if enable_progress_bar is None: enable_progress_bar = True if log_every_n_steps is None: log_every_n_steps = 50 if enable_model_summary is None: enable_model_summary = True if num_sanity_val_steps is None: num_sanity_val_steps = 2 # init connectors self._data_connector = _DataConnector(self) self._accelerator_connector = _AcceleratorConnector( devices=devices, accelerator=accelerator, strategy=strategy, num_nodes=num_nodes, sync_batchnorm=sync_batchnorm, benchmark=benchmark, use_distributed_sampler=use_distributed_sampler, deterministic=deterministic, precision=precision, plugins=plugins, ) self._logger_connector = _LoggerConnector(self) self._callback_connector = _CallbackConnector(self) self._checkpoint_connector = _CheckpointConnector(self) self._signal_connector = _SignalConnector(self) # init loops self.fit_loop = _FitLoop(self, min_epochs=min_epochs, max_epochs=max_epochs) self.fit_loop.epoch_loop = _TrainingEpochLoop(self, min_steps=min_steps, max_steps=max_steps) self.validate_loop = _EvaluationLoop(self, inference_mode=inference_mode) self.test_loop = _EvaluationLoop(self, inference_mode=inference_mode) self.predict_loop = _PredictionLoop(self, inference_mode=inference_mode) self.accumulate_grad_batches = accumulate_grad_batches # init callbacks # Declare attributes to be set in _callback_connector on_trainer_init self._callback_connector.on_trainer_init( callbacks, enable_checkpointing, enable_progress_bar, default_root_dir, enable_model_summary, max_time, ) # init data flags self.check_val_every_n_epoch: Optional[int] self._data_connector.on_trainer_init( val_check_interval, reload_dataloaders_every_n_epochs, check_val_every_n_epoch, ) # gradient clipping if gradient_clip_val is not None and not isinstance(gradient_clip_val, (int, float)): raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.") if gradient_clip_algorithm is not None and not GradClipAlgorithmType.supported_type( gradient_clip_algorithm.lower() ): raise MisconfigurationException( f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. " f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}." ) self.gradient_clip_val: Optional[Union[int, float]] = gradient_clip_val self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = ( GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None ) self._detect_anomaly: bool = detect_anomaly setup._log_device_info(self) self.should_stop = False self.state = TrainerState() # configure profiler setup._init_profiler(self, profiler) # init logger flags self._loggers: List[Logger] self._logger_connector.on_trainer_init(logger, log_every_n_steps) # init debugging flags self.val_check_batch: Union[int, float] self.val_check_interval: Union[int, float] self.num_sanity_val_steps: Union[int, float] self.limit_train_batches: Union[int, float] self.limit_val_batches: Union[int, float] self.limit_test_batches: Union[int, float] self.limit_predict_batches: Union[int, float] setup._init_debugging_flags( self, limit_train_batches, limit_val_batches, limit_test_batches, limit_predict_batches, fast_dev_run, overfit_batches, val_check_interval, num_sanity_val_steps, ) def fit( self, model: "pl.LightningModule", train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, ckpt_path: Optional[str] = None, ) -> None: r"""Runs the full optimization routine. Args: model: Model to fit. train_dataloaders: An iterable or collection of iterables specifying training samples. Alternatively, a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` that defines the :class:`~pytorch_lightning.core.hooks.DataHooks.train_dataloader` hook. val_dataloaders: An iterable or collection of iterables specifying validation samples. datamodule: A :class:`~pytorch_lightning.core.datamodule.LightningDataModule` that defines the :class:`~pytorch_lightning.core.hooks.DataHooks.train_dataloader` hook. ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of two special keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised. Raises: TypeError: If ``model`` is not :class:`~pytorch_lightning.core.module.LightningModule` for torch version less than 2.0.0 and if ``model`` is not :class:`~pytorch_lightning.core.module.LightningModule` or :class:`torch._dynamo.OptimizedModule` for torch versions greater than or equal to 2.0.0 . For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`. """ model = _maybe_unwrap_optimized(model) self.strategy._lightning_module = model _verify_strategy_supports_compile(model, self.strategy) call._call_and_handle_interrupt( self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path ) def _fit_impl( self, model: "pl.LightningModule", train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, ckpt_path: Optional[str] = None, ) -> None: log.debug(f"{self.__class__.__name__}: trainer fit stage") self.state.fn = TrainerFn.FITTING self.state.status = TrainerStatus.RUNNING self.training = True # if a datamodule comes in as the second arg, then fix it for the user if isinstance(train_dataloaders, LightningDataModule): datamodule = train_dataloaders train_dataloaders = None # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None: raise MisconfigurationException( "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`" ) # links data to the trainer self._data_connector.attach_data( model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule ) ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=True, model_connected=self.lightning_module is not None, ) self._run(model, ckpt_path=ckpt_path) assert self.state.stopped self.training = False return def validate( self, model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> _EVALUATE_OUTPUT: r"""Perform one evaluation epoch over the validation set. Args: model: The model to validate. dataloaders: An iterable or collection of iterables specifying validation samples. Alternatively, a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` that defines the :class:`~pytorch_lightning.core.hooks.DataHooks.val_dataloader` hook. ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate. If ``None`` and the model instance was passed, use the current weights. Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded if a checkpoint callback is configured. verbose: If True, prints the validation results. datamodule: A :class:`~pytorch_lightning.core.datamodule.LightningDataModule` that defines the :class:`~pytorch_lightning.core.hooks.DataHooks.val_dataloader` hook. For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`. Returns: List of dictionaries with metrics logged during the validation phase, e.g., in model- or callback hooks like :meth:`~pytorch_lightning.LightningModule.validation_step` etc. The length of the list corresponds to the number of validation dataloaders used. Raises: TypeError: If no ``model`` is passed and there was no ``LightningModule`` passed in the previous run. If ``model`` passed is not `LightningModule` or `torch._dynamo.OptimizedModule`. MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are passed. Pass only one of these. RuntimeError: If a compiled ``model`` is passed and the strategy is not supported. """ if model is None: # do we still have a reference from a previous call? if self.lightning_module is None: raise TypeError( "`Trainer.validate()` requires a `LightningModule` when it hasn't been passed in a previous run" ) else: model = _maybe_unwrap_optimized(model) self.strategy._lightning_module = model _verify_strategy_supports_compile(self.lightning_module, self.strategy) return call._call_and_handle_interrupt( self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule ) def _validate_impl( self, model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: # -------------------- # SETUP HOOK # -------------------- log.debug(f"{self.__class__.__name__}: trainer validate stage") self.state.fn = TrainerFn.VALIDATING self.state.status = TrainerStatus.RUNNING self.validating = True # if a datamodule comes in as the second arg, then fix it for the user if isinstance(dataloaders, LightningDataModule): datamodule = dataloaders dataloaders = None # If you supply a datamodule you can't supply val_dataloaders if dataloaders is not None and datamodule: raise MisconfigurationException("You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`") if model is None: model = self.lightning_module model_provided = False else: model_provided = True self.validate_loop.verbose = verbose # links data to the trainer self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule) ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) results = self._run(model, ckpt_path=ckpt_path) # remove the tensors from the validation results results = convert_tensors_to_scalars(results) assert self.state.stopped self.validating = False return results def test( self, model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> _EVALUATE_OUTPUT: r"""Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your test set until you want to. Args: model: The model to test. dataloaders: An iterable or collection of iterables specifying test samples. Alternatively, a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` that defines the :class:`~pytorch_lightning.core.hooks.DataHooks.test_dataloader` hook. ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test. If ``None`` and the model instance was passed, use the current weights. Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded if a checkpoint callback is configured. verbose: If True, prints the test results. datamodule: A :class:`~pytorch_lightning.core.datamodule.LightningDataModule` that defines the :class:`~pytorch_lightning.core.hooks.DataHooks.test_dataloader` hook. For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`. Returns: List of dictionaries with metrics logged during the test phase, e.g., in model- or callback hooks like :meth:`~pytorch_lightning.LightningModule.test_step` etc. The length of the list corresponds to the number of test dataloaders used. Raises: TypeError: If no ``model`` is passed and there was no ``LightningModule`` passed in the previous run. If ``model`` passed is not `LightningModule` or `torch._dynamo.OptimizedModule`. MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are passed. Pass only one of these. RuntimeError: If a compiled ``model`` is passed and the strategy is not supported. """ if model is None: # do we still have a reference from a previous call? if self.lightning_module is None: raise TypeError( "`Trainer.test()` requires a `LightningModule` when it hasn't been passed in a previous run" ) else: model = _maybe_unwrap_optimized(model) self.strategy._lightning_module = model _verify_strategy_supports_compile(self.lightning_module, self.strategy) return call._call_and_handle_interrupt( self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule ) def _test_impl( self, model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: # -------------------- # SETUP HOOK # -------------------- log.debug(f"{self.__class__.__name__}: trainer test stage") self.state.fn = TrainerFn.TESTING self.state.status = TrainerStatus.RUNNING self.testing = True # if a datamodule comes in as the second arg, then fix it for the user if isinstance(dataloaders, LightningDataModule): datamodule = dataloaders dataloaders = None # If you supply a datamodule you can't supply test_dataloaders if dataloaders is not None and datamodule: raise MisconfigurationException("You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`") if model is None: model = self.lightning_module model_provided = False else: model_provided = True self.test_loop.verbose = verbose # links data to the trainer self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) results = self._run(model, ckpt_path=ckpt_path) # remove the tensors from the test results results = convert_tensors_to_scalars(results) assert self.state.stopped self.testing = False return results def predict( self, model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, ckpt_path: Optional[str] = None, ) -> Optional[_PREDICT_OUTPUT]: r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to perform distributed and batched predictions. Logging is disabled in the predict hooks. Args: model: The model to predict with. dataloaders: An iterable or collection of iterables specifying predict samples. Alternatively, a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` that defines the :class:`~pytorch_lightning.core.hooks.DataHooks.predict_dataloader` hook. datamodule: A :class:`~pytorch_lightning.core.datamodule.LightningDataModule` that defines the :class:`~pytorch_lightning.core.hooks.DataHooks.predict_dataloader` hook. return_predictions: Whether to return predictions. ``True`` by default except when an accelerator that spawns processes is used (not supported). ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict. If ``None`` and the model instance was passed, use the current weights. Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded if a checkpoint callback is configured. For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`. Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. Raises: TypeError: If no ``model`` is passed and there was no ``LightningModule`` passed in the previous run. If ``model`` passed is not `LightningModule` or `torch._dynamo.OptimizedModule`. MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are passed. Pass only one of these. RuntimeError: If a compiled ``model`` is passed and the strategy is not supported. See :ref:`Lightning inference section<deploy/production_basic:Predict step with your LightningModule>` for more. """ if model is None: # do we still have a reference from a previous call? if self.lightning_module is None: raise TypeError( "`Trainer.predict()` requires a `LightningModule` when it hasn't been passed in a previous run" ) else: model = _maybe_unwrap_optimized(model) self.strategy._lightning_module = model _verify_strategy_supports_compile(self.lightning_module, self.strategy) return call._call_and_handle_interrupt( self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path ) def _predict_impl( self, model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, ckpt_path: Optional[str] = None, ) -> Optional[_PREDICT_OUTPUT]: # -------------------- # SETUP HOOK # -------------------- log.debug(f"{self.__class__.__name__}: trainer predict stage") self.state.fn = TrainerFn.PREDICTING self.state.status = TrainerStatus.RUNNING self.predicting = True self.predict_loop.return_predictions = return_predictions # type: ignore[assignment] # if a datamodule comes in as the second arg, then fix it for the user if isinstance(dataloaders, LightningDataModule): datamodule = dataloaders dataloaders = None if dataloaders is not None and datamodule: raise MisconfigurationException("You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`") if model is None: model = self.lightning_module model_provided = False else: model_provided = True # links data to the trainer self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) results = self._run(model, ckpt_path=ckpt_path) assert self.state.stopped self.predicting = False return results def _run( self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: if self.state.fn == TrainerFn.FITTING: min_epochs, max_epochs = _parse_loop_limits( self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self ) self.fit_loop.min_epochs = min_epochs self.fit_loop.max_epochs = max_epochs if self.barebones: # no progress bar in barebones can make it look like the Trainer hung rank_zero_info( "`Trainer(barebones=True)` started running. The progress bar is disabled so you might want to" " manually print the progress in your model." ) # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) # attach model to the strategy self.strategy.connect(model) self._callback_connector._attach_model_callbacks() self._callback_connector._attach_model_logging_functions() _verify_loop_configurations(self) # hook log.debug(f"{self.__class__.__name__}: preparing data") self._data_connector.prepare_data() # ---------------------------- # SET UP THE TRAINER # ---------------------------- log.debug(f"{self.__class__.__name__}: setting up strategy environment") self.strategy.setup_environment() self.__setup_profiler() call._call_setup_hook(self) # allow user to setup lightning_module in accelerator environment # check if we should delay restoring checkpoint till later if not self.strategy.restore_checkpoint_after_setup: log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}") self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path) log.debug(f"{self.__class__.__name__}: configuring sharded model") call._call_configure_sharded_model(self) # allow user to setup in model sharded environment # reset logger connector self._logger_connector.reset_results() self._logger_connector.reset_metrics() # strategy will configure model and move it to the device self.strategy.setup(self) # hook if self.state.fn == TrainerFn.FITTING: call._call_callback_hooks(self, "on_fit_start") call._call_lightning_module_hook(self, "on_fit_start") _log_hyperparams(self) if self.strategy.restore_checkpoint_after_setup: log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}") self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path) # restore optimizers, etc. log.debug(f"{self.__class__.__name__}: restoring training state") self._checkpoint_connector.restore_training_state() self._checkpoint_connector.resume_end() self._signal_connector.register_signal_handlers() # ---------------------------- # RUN THE TRAINER # ---------------------------- results = self._run_stage() # ---------------------------- # POST-Training CLEAN UP # ---------------------------- log.debug(f"{self.__class__.__name__}: trainer tearing down") self._teardown() if self.state.fn == TrainerFn.FITTING: call._call_callback_hooks(self, "on_fit_end") call._call_lightning_module_hook(self, "on_fit_end") log.debug(f"{self.__class__.__name__}: calling teardown hooks") call._call_teardown_hook(self) self.state.status = TrainerStatus.FINISHED self.state.stage = None return results def _teardown(self) -> None: """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and Callback; those are handled by :meth:`_call_teardown_hook`.""" self.strategy.teardown() loop = self._active_loop # loop should never be `None` here but it can because we don't know the trainer stage with `ddp_spawn` if loop is not None: loop.teardown() self._logger_connector.teardown() self._signal_connector.teardown() def _run_stage(self) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: # wait for all to join if on distributed self.strategy.barrier("run-stage") if self.evaluating: return self._evaluation_loop.run() if self.predicting: return self.predict_loop.run() if self.training: with isolate_rng(): self._run_sanity_check() with torch.autograd.set_detect_anomaly(self._detect_anomaly): self.fit_loop.run() return None raise RuntimeError(f"Unexpected state {self.state}") def _run_sanity_check(self) -> None: val_loop = self.fit_loop.epoch_loop.val_loop should_sanity_check = ( self.enable_validation and self.num_sanity_val_steps > 0 # do not sanity check if restarting because it would mess up the loaded state and not val_loop.restarting ) # run tiny validation (if validation defined) # to make sure program won't crash during val if should_sanity_check: stage = self.state.stage self.sanity_checking = True # reset logger connector self._logger_connector.reset_results() self._logger_connector.reset_metrics() call._call_callback_hooks(self, "on_sanity_check_start") # run eval step val_loop.run() call._call_callback_hooks(self, "on_sanity_check_end") # reset logger connector self._logger_connector.reset_results() self._logger_connector.reset_metrics() # reset the progress tracking state after sanity checking. we don't need to set the state before # because sanity check only runs when we are not restarting _reset_progress(val_loop) # restore the previous stage when the sanity check if finished self.state.stage = stage def __setup_profiler(self) -> None: assert self.state.fn is not None local_rank = self.local_rank if self.world_size > 1 else None self.profiler._lightning_module = proxy(self.lightning_module) self.profiler.setup(stage=self.state.fn, local_rank=local_rank, log_dir=self.log_dir) """ Accelerator properties """ @property def accelerator(self) -> Accelerator: assert self.strategy.accelerator return self.strategy.accelerator @property def strategy(self) -> Strategy: return self._accelerator_connector.strategy @property def precision_plugin(self) -> PrecisionPlugin: return self.strategy.precision_plugin @property def global_rank(self) -> int: return self.strategy.global_rank @property def local_rank(self) -> int: # some strategies define a local rank return getattr(self.strategy, "local_rank", 0) @property def node_rank(self) -> int: # some strategies define a node rank return getattr(self.strategy, "node_rank", 0) @property def world_size(self) -> int: # some strategies define a world size return getattr(self.strategy, "world_size", 1) @property def num_nodes(self) -> int: return getattr(self.strategy, "num_nodes", 1) @property def device_ids(self) -> List[int]: """List of device indexes per node.""" devices = ( self.strategy.parallel_devices if isinstance(self.strategy, ParallelStrategy) else [self.strategy.root_device] ) assert devices is not None device_ids = [] for idx, device in enumerate(devices): if isinstance(device, torch.device): device_ids.append(device.index or idx) elif isinstance(device, int): device_ids.append(device) return device_ids @property def num_devices(self) -> int: """Number of devices the trainer uses per node.""" return len(self.device_ids) @property def lightning_module(self) -> "pl.LightningModule": # TODO: this is actually an optional return return self.strategy.lightning_module # type: ignore[return-value] @property def optimizers(self) -> List[Optimizer]: return self.strategy.optimizers @optimizers.setter def optimizers(self, new_optims: List[Optimizer]) -> None: self.strategy.optimizers = new_optims @property def lr_scheduler_configs(self) -> List[LRSchedulerConfig]: return self.strategy.lr_scheduler_configs @property def precision(self) -> _PRECISION_INPUT_STR: return self.strategy.precision_plugin.precision @property def scaler(self) -> Optional[Any]: return getattr(self.precision_plugin, "scaler", None) @property def model(self) -> Optional[torch.nn.Module]: """The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel. To access the pure LightningModule, use :meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead. """ return self.strategy.model """ General properties """ @property def log_dir(self) -> Optional[str]: """The directory for the current experiment. Use this to save images to, etc... .. note:: You must call this on all processes. Failing to do so will cause your program to stall forever. .. code-block:: python def training_step(self, batch, batch_idx): img = ... save_img(img, self.trainer.log_dir) """ if len(self.loggers) > 0: if not isinstance(self.loggers[0], TensorBoardLogger): dirpath = self.loggers[0].save_dir else: dirpath = self.loggers[0].log_dir else: dirpath = self.default_root_dir dirpath = self.strategy.broadcast(dirpath) return dirpath @property def is_global_zero(self) -> bool: """Whether this process is the global zero in multi-node training. .. code-block:: python def training_step(self, batch, batch_idx): if self.trainer.is_global_zero: print("in node 0, accelerator 0") """ return self.strategy.is_global_zero @property def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]: if isinstance(self.strategy, ParallelStrategy): return self.strategy.distributed_sampler_kwargs return None @property def enable_validation(self) -> bool: """Check if we should run validation during training.""" return ( self.fit_loop.epoch_loop.val_loop._data_source.is_defined() and is_overridden("validation_step", self.lightning_module) and self.limit_val_batches > 0 ) @property def default_root_dir(self) -> str: """The default location to save artifacts of loggers, checkpoints etc. It is used as a fallback if logger or checkpoint callback do not define specific save paths. """ if get_filesystem(self._default_root_dir).protocol == "file": return os.path.normpath(self._default_root_dir) return self._default_root_dir @property def early_stopping_callback(self) -> Optional[EarlyStopping]: """The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.""" callbacks = self.early_stopping_callbacks return callbacks[0] if len(callbacks) > 0 else None @property def early_stopping_callbacks(self) -> List[EarlyStopping]: """A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` found in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, EarlyStopping)] @property def checkpoint_callback(self) -> Optional[Checkpoint]: """The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.""" callbacks = self.checkpoint_callbacks return callbacks[0] if len(callbacks) > 0 else None @property def checkpoint_callbacks(self) -> List[Checkpoint]: """A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, Checkpoint)] @property def progress_bar_callback(self) -> Optional[ProgressBar]: """An instance of :class:`~pytorch_lightning.callbacks.progress.progress_bar.ProgressBar` found in the Trainer.callbacks list, or ``None`` if one doesn't exist.""" for c in self.callbacks: if isinstance(c, ProgressBar): return c return None @property def ckpt_path(self) -> Optional[_PATH]: """Set to the path/URL of a checkpoint loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`, :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`, :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`, or :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. ``None`` otherwise.""" return self._checkpoint_connector._ckpt_path @ckpt_path.setter def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None: """Allows you to manage which checkpoint is loaded statefully. .. code-block:: python trainer = Trainer() trainer.ckpt_path = "my/checkpoint/file.ckpt" trainer.fit(model) ... # you will be in charge of resetting this trainer.ckpt_path = None trainer.test(model) """ self._checkpoint_connector._ckpt_path = ckpt_path self._checkpoint_connector._user_managed = bool(ckpt_path) def save_checkpoint( self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None ) -> None: r"""Runs routine to create a checkpoint. Args: filepath: Path where checkpoint is saved. weights_only: If ``True``, will only save the model weights. storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin Raises: AttributeError: If the model is not attached to the Trainer before calling this method. """ if self.model is None: raise AttributeError( "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call" " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?" ) self._checkpoint_connector.save_checkpoint(filepath, weights_only=weights_only, storage_options=storage_options) """ State properties """ @property def interrupted(self) -> bool: return self.state.status == TrainerStatus.INTERRUPTED @property def training(self) -> bool: return self.state.stage == RunningStage.TRAINING @training.setter def training(self, val: bool) -> None: if val: self.state.stage = RunningStage.TRAINING elif self.training: self.state.stage = None @property def testing(self) -> bool: return self.state.stage == RunningStage.TESTING @testing.setter def testing(self, val: bool) -> None: if val: self.state.stage = RunningStage.TESTING elif self.testing: self.state.stage = None @property def predicting(self) -> bool: return self.state.stage == RunningStage.PREDICTING @predicting.setter def predicting(self, val: bool) -> None: if val: self.state.stage = RunningStage.PREDICTING elif self.predicting: self.state.stage = None @property def validating(self) -> bool: return self.state.stage == RunningStage.VALIDATING @validating.setter def validating(self, val: bool) -> None: if val: self.state.stage = RunningStage.VALIDATING elif self.validating: self.state.stage = None @property def evaluating(self) -> bool: return self.state.stage is not None and self.state.stage.evaluating @property def sanity_checking(self) -> bool: """Whether sanity checking is running. Useful to disable some hooks, logging or callbacks during the sanity checking. """ return self.state.stage == RunningStage.SANITY_CHECKING @sanity_checking.setter def sanity_checking(self, val: bool) -> None: if val: self.state.stage = RunningStage.SANITY_CHECKING elif self.sanity_checking: self.state.stage = None @property def received_sigterm(self) -> bool: """Whether a ``signal.SIGTERM`` signal was received. For example, this can be checked to exit gracefully. """ return self._signal_connector.received_sigterm """ Loop properties """ @property def global_step(self) -> int: """The number of optimizer steps taken (does not reset each epoch). This includes multiple optimizers (if enabled). """ return self.fit_loop.epoch_loop.global_step @property def current_epoch(self) -> int: """The current epoch, updated after the epoch end hooks are run.""" return self.fit_loop.epoch_progress.current.completed @property def max_epochs(self) -> Optional[int]: return self.fit_loop.max_epochs @property def min_epochs(self) -> Optional[int]: return self.fit_loop.min_epochs @property def max_steps(self) -> int: return self.fit_loop.max_steps @property def min_steps(self) -> Optional[int]: return self.fit_loop.min_steps @property def is_last_batch(self) -> bool: """Whether trainer is executing the last batch.""" return self.fit_loop.epoch_loop.batch_progress.is_last_batch @property def train_dataloader(self) -> Optional[TRAIN_DATALOADERS]: """The training dataloader(s) used during ``trainer.fit()``.""" if (combined_loader := self.fit_loop._combined_loader) is not None: return combined_loader.iterables return None @property def val_dataloaders(self) -> Optional[EVAL_DATALOADERS]: """The validation dataloader(s) used during ``trainer.fit()`` or ``trainer.validate()``.""" if (combined_loader := self.fit_loop.epoch_loop.val_loop._combined_loader) is not None or ( combined_loader := self.validate_loop._combined_loader ) is not None: return combined_loader.iterables return None @property def test_dataloaders(self) -> Optional[EVAL_DATALOADERS]: """The test dataloader(s) used during ``trainer.test()``.""" if (combined_loader := self.test_loop._combined_loader) is not None: return combined_loader.iterables return None @property def predict_dataloaders(self) -> Optional[EVAL_DATALOADERS]: """The prediction dataloader(s) used during ``trainer.predict()``.""" if (combined_loader := self.predict_loop._combined_loader) is not None: return combined_loader.iterables return None @property def num_training_batches(self) -> Union[int, float]: """The number of training batches that will be used during ``trainer.fit()``.""" return self.fit_loop.max_batches @property def num_sanity_val_batches(self) -> List[Union[int, float]]: """The number of validation batches that will be used during the sanity-checking part of ``trainer.fit()``.""" max_batches = self.fit_loop.epoch_loop.val_loop.max_batches return [min(self.num_sanity_val_steps, batches) for batches in max_batches] @property def num_val_batches(self) -> List[Union[int, float]]: """The number of validation batches that will be used during ``trainer.fit()`` or ``trainer.validate()``.""" if self.state.fn == TrainerFn.VALIDATING: return self.validate_loop.max_batches # if no trainer.fn is set, assume fit's validation # use the protected access, because it shouldn't return the sanity_val batches return self.fit_loop.epoch_loop.val_loop._max_batches @property def num_test_batches(self) -> List[Union[int, float]]: """The number of test batches that will be used during ``trainer.test()``.""" return self.test_loop.max_batches @property def num_predict_batches(self) -> List[Union[int, float]]: """The number of prediction batches that will be used during ``trainer.predict()``.""" return self.predict_loop.max_batches @property def _evaluation_loop(self) -> _EvaluationLoop: if self.state.fn == TrainerFn.FITTING: return self.fit_loop.epoch_loop.val_loop if self.state.fn == TrainerFn.VALIDATING: return self.validate_loop if self.state.fn == TrainerFn.TESTING: return self.test_loop raise RuntimeError("The `Trainer._evaluation_loop` property isn't defined. Accessed outside of scope") @property def _active_loop(self) -> Optional[Union[_FitLoop, _EvaluationLoop, _PredictionLoop]]: if self.training: return self.fit_loop if self.sanity_checking or self.evaluating: return self._evaluation_loop if self.predicting: return self.predict_loop return None """ Logging properties """ @property def logger(self) -> Optional[Logger]: """The first :class:`~pytorch_lightning.loggers.logger.Logger` being used.""" return self.loggers[0] if len(self.loggers) > 0 else None @logger.setter def logger(self, logger: Optional[Logger]) -> None: if not logger: self.loggers = [] else: self.loggers = [logger] @property def loggers(self) -> List[Logger]: """The list of :class:`~pytorch_lightning.loggers.logger.Logger` used. .. code-block:: python for logger in trainer.loggers: logger.log_metrics({"foo": 1.0}) """ return self._loggers @loggers.setter def loggers(self, loggers: Optional[List[Logger]]) -> None: self._loggers = loggers if loggers else [] @property def callback_metrics(self) -> _OUT_DICT: """The metrics available to callbacks. .. code-block:: python def training_step(self, batch, batch_idx): self.log("a_val", 2.0) callback_metrics = trainer.callback_metrics assert callback_metrics["a_val"] == 2.0 """ return self._logger_connector.callback_metrics @property def logged_metrics(self) -> _OUT_DICT: """The metrics sent to the loggers. This includes metrics logged via :meth:`~pytorch_lightning.core.module.LightningModule.log` with the :paramref:`~pytorch_lightning.core.module.LightningModule.log.logger` argument set. """ return self._logger_connector.logged_metrics @property def progress_bar_metrics(self) -> _PBAR_DICT: """The metrics sent to the progress bar. This includes metrics logged via :meth:`~pytorch_lightning.core.module.LightningModule.log` with the :paramref:`~pytorch_lightning.core.module.LightningModule.log.prog_bar` argument set. """ return self._logger_connector.progress_bar_metrics @property def _results(self) -> Optional[_ResultCollection]: active_loop = self._active_loop if active_loop is not None: return active_loop._results return None """ Other """ @property def estimated_stepping_batches(self) -> Union[int, float]: r""" The estimated number of batches that will ``optimizer.step()`` during training. This accounts for gradient accumulation and the current trainer configuration. This might sets up your training dataloader if hadn't been set up already. .. code-block:: python def configure_optimizers(self): optimizer = ... stepping_batches = self.trainer.estimated_stepping_batches scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=stepping_batches) return [optimizer], [scheduler] Raises: MisconfigurationException: If estimated stepping batches cannot be computed due to different `accumulate_grad_batches` at different epochs. """ # infinite training if self.max_epochs == -1: return float("inf") if self.max_steps == -1 else self.max_steps if self.train_dataloader is None: rank_zero_info("Loading `train_dataloader` to estimate number of stepping batches.") state = self.state self.state.fn = TrainerFn.FITTING self.training = True self.fit_loop.setup_data() self.state = state total_batches = self.num_training_batches # iterable dataset if total_batches == float("inf"): return self.max_steps assert self.max_epochs is not None max_estimated_steps = math.ceil(total_batches / self.accumulate_grad_batches) * max(self.max_epochs, 1) max_estimated_steps = min(max_estimated_steps, self.max_steps) if self.max_steps != -1 else max_estimated_steps return max_estimated_steps