From bd4bd443f339e95007bfdabf3e060db720f4d4b9 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 3 Aug 2021 18:18:48 +0200 Subject: Training working, multiple bug fixes --- training/utils.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) (limited to 'training/utils.py') diff --git a/training/utils.py b/training/utils.py index ef74f61..d23396e 100644 --- a/training/utils.py +++ b/training/utils.py @@ -17,6 +17,10 @@ from tqdm import tqdm import wandb +def print_config(config: DictConfig) -> None: + print(OmegaConf.to_yaml(config)) + + @rank_zero_only def configure_logging(config: DictConfig) -> None: """Configure the loguru logger for output to terminal and disk.""" @@ -30,7 +34,7 @@ def configure_callbacks(config: DictConfig,) -> List[Type[Callback]]: callbacks = [] if config.get("callbacks"): for callback_config in config.callbacks.values(): - if config.get("_target_"): + if callback_config.get("_target_"): log.info(f"Instantiating callback <{callback_config._target_}>") callbacks.append(hydra.utils.instantiate(callback_config)) return callbacks @@ -41,8 +45,8 @@ def configure_logger(config: DictConfig) -> List[Type[LightningLoggerBase]]: logger = [] if config.get("logger"): for logger_config in config.logger.values(): - if config.get("_target_"): - log.info(f"Instantiating callback <{logger_config._target_}>") + if logger_config.get("_target_"): + log.info(f"Instantiating logger <{logger_config._target_}>") logger.append(hydra.utils.instantiate(logger_config)) return logger @@ -67,17 +71,8 @@ def extras(config: DictConfig) -> None: # Debuggers do not like GPUs and multiprocessing. if config.trainer.get("gpus"): config.trainer.gpus = 0 - if config.datamodule.get("pin_memory"): - config.datamodule.pin_memory = False - if config.datamodule.get("num_workers"): - config.datamodule.num_workers = 0 - - # Force multi-gpu friendly config. - accelerator = config.trainer.get("accelerator") - if accelerator in ["ddp", "ddp_spawn", "dp", "ddp2"]: - log.info( - f"Forcing ddp friendly configuration! " - ) + if config.trainer.get("precision"): + config.trainer.precision = 32 if config.datamodule.get("pin_memory"): config.datamodule.pin_memory = False if config.datamodule.get("num_workers"): -- cgit v1.2.3-70-g09d2