In [4]:
%load_ext autoreload
%autoreload 2

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
from torch import nn
from importlib.util import find_spec
if find_spec("text_recognizer") is None:
    import sys
    sys.path.append('..')
    

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
from hydra import compose, initialize
from omegaconf import OmegaConf
from hydra.utils import instantiate

In [3]:
# context initialization
with initialize(config_path="../training/conf/model/", job_name="test_app"):
    cfg = compose(config_name="lit_transformer")
    print(OmegaConf.to_yaml(cfg))
    print(cfg)

_target_: text_recognizer.models.transformer.TransformerLitModel
interval: step
monitor: val/loss
start_token: <s>
end_token: <e>
pad_token: <p>

{'_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'interval': 'step', 'monitor': 'val/loss', 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}


In [None]:
# context initialization
with initialize(config_path="../training/conf/mapping/", job_name="test_app"):
    cfg = compose(config_name="word_piece")
    print(OmegaConf.to_yaml(cfg))
    print(cfg)

In [None]:
# context initialization
with initialize(config_path="../training/conf/network/", job_name="test_app"):
    cfg = compose(config_name="conv_transformer")
    print(OmegaConf.to_yaml(cfg))
    print(cfg)

In [6]:
# context initialization
with initialize(config_path="../training/conf/", job_name="test_app"):
    cfg = compose(config_name="config")
    print(OmegaConf.to_yaml(cfg))
    print(cfg)

callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: val/loss
    save_top_k: 1
    save_last: true
    mode: min
    verbose: false
    dirpath: checkpoints/
    filename:
      epoch:02d: null
  learning_rate_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: step
    log_momentum: false
  watch_model:
    _target_: callbacks.wandb_callbacks.WatchModel
    log: all
    log_freq: 100
  upload_code_as_artifact:
    _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact
    project_dir: ${work_dir}/text_recognizer
  upload_ckpts_as_artifact:
    _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
    ckpt_dir: checkpoints/
    upload_best_only: true
  log_text_predictions:
    _target_: callbacks.wandb_callbacks.LogTextPredictions
    num_samples: 8
criterion:
  _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss
  smoothing: 0.1
  ignore_index: 1002
datamodule:
 

In [10]:
cfg.get("callbacks")

{'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': {'epoch:02d': None}}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_text_predictions': {'_target_': 'callbacks.wandb_callbacks.LogTextPredictions', 'num_samples': 8}}

In [12]:
for l in cfg.callbacks.values():
    print(l.get("_target_"))

pytorch_lightning.callbacks.ModelCheckpoint
pytorch_lightning.callbacks.LearningRateMonitor
callbacks.wandb_callbacks.WatchModel
callbacks.wandb_callbacks.UploadCodeAsArtifact
callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
callbacks.wandb_callbacks.LogTextPredictions


In [4]:
mapping = instantiate(cfg.mapping)

2021-08-03 15:27:02.069 | DEBUG    | text_recognizer.data.word_piece_mapping:__init__:37 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb


In [5]:
network = instantiate(cfg.network)

In [None]:
OmegaConf.set_struct(cfg, False)

In [8]:
datamodule = instantiate(cfg.datamodule, mapping=mapping)

In [9]:
datamodule.prepare_data()
datamodule.setup()

2021-08-03 15:28:22.541 | INFO     | text_recognizer.data.iam_paragraphs:setup:95 - Loading IAM paragraph regions and lines for None...
2021-08-03 15:28:45.280 | INFO     | text_recognizer.data.iam_synthetic_paragraphs:setup:68 - IAM Synthetic dataset steup for stage None...


In [11]:
len(datamodule.train_dataloader())

4992

In [None]:
mapping

In [7]:
config = cfg

In [8]:
loss_fn = instantiate(cfg.criterion)

In [9]:
import hydra

In [12]:
    model = hydra.utils.instantiate(
        config.model,
        mapping=mapping,
        network=network,
        loss_fn=loss_fn,
        optimizer_config=config.optimizer,
        lr_scheduler_config=config.lr_scheduler,
        _recursive_=False,
    )


In [11]:
mapping.get_index

<bound method WordPieceMapping.get_index of <text_recognizer.data.word_piece_mapping.WordPieceMapping object at 0x7fae3b489610>>

In [None]:
net = instantiate(cfg)

In [None]:
net

In [None]:
img = torch.rand(4, 1, 576, 640)

In [None]:
y = torch.randint(0, 1006, (4, 451))

In [None]:
y.shape

In [None]:
net = net.cuda()
img = img.cuda()
y = y.cuda()

In [None]:
net(img, y).shape

In [None]:
from torchsummary import summary

In [None]:
summary(net, [(1, 576, 640), (451,)], device="cpu", depth=2)