python – Training a deep learning network for MRI reconstruction but loss and SSIM values ​​go to NaN after a few iterations

I’m currently working on a project to train a deep learning network to denoise MRI reconstructions in Pytorch but I’m running into issues in the training process where my loss and SSIM becomes NaN after a few iterations. From what I’ve gathered so far, it’s an issue with the gradients becoming too large and thus the loss becoming NaN but I’m not sure how to address that. The way I’m approaching this project is by feeding in pairs of images into a UNet, an input image that’s noisy because it was reconstructed with undersampled data and a target image which is clean because it was reconstructed with all the data, and then using MSELoss to train the network. Each scan I have is a 3D volume with flow data in the x, y, and z directions (256x256x256x4 where the four channels are magnitude, vx, vy, and vz) and I have 80 or so scans so my idea was to feed the data in slice-by-slice as a 4 channel image (256x256x4).

Here’s the relevant code:

def main():
    start = time.time()

    # Parameters
    params = {
        'data_dir': "/data/users/-----/denoising/data",
        'num_epochs': 10,
        'batch_size': 16,
        'lr': 0.0002,
        'step_size': 10,
        'gamma': 0.1,
        'momentum': 0.9,
        'weight_decay': 0.0005,
        'criterion': nn.MSELoss()
    }

    #log_directory
    log_dir = create_log(params)
    print("Log directory: {}".format(log_dir))

    # get the train and validation dataloaders
    dataloaders = get_dataloaders(params['data_dir'],params['batch_size'], shuffle=True)
    model = UNet(in_channels=4, out_channels=4, complex_input=False);
    add_to_log(log_dir, str(model))
    # print(model)

    # CUDA for PyTorch
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True

    # Code for multiple GPUs
    if torch.cuda.device_count()>1:
        num = torch.cuda.device_count()
        model = nn.DataParallel(model,device_ids = range(num))
        print("Using {} available gpus".format(num))
    else:
        print("Multiple gpus not found. Using only 1.")

    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=params['lr'], momentum=params['momentum'], weight_decay=params['weight_decay'])
    scheduler = lr_scheduler.StepLR(optimizer,step_size=params['step_size'],gamma=params['gamma'])

    train_model(dataloaders,model,params['criterion'],optimizer,scheduler,params['num_epochs'],log_dir,device)

    finish = time.time() - start
    add_to_log(log_dir, "Training complete in {:.0f}:{:.0f} min:sec".format(finish/60,finish%60))

main()

This is how I’ve structured my dataset:

class DenoisingDataset(Dataset):
    
    def __init__(self, input_path, val_set, test_set, transform=None):
        super().__init__()
        self.input_path = input_path
        self.val_set = [val_set]
        self.test_set = [test_set]
        self.indices = {"train": [], "val": [], "test": []}
        self.database = []
        self.transform = transform

        self._make_id_list()

    def __getitem__(self, index):
        x, y = self.read_data(index)
        x, y = self.preprocess(x, y)
        return {"input": x, "target": y}

    def __len__(self):
        return len(self.database)

    def _make_id_list(self):
        p = Path(self.input_path)
        assert(p.is_dir())
        for root, dirs, files in os.walk(p, topdown=False):
            for i, f in enumerate(files):
                if f.endswith("full.h5"):
                    parts = root.split("/")
                    # print(parts[-2])
                    # filename = os.path.join(root, f)
                    for j in range(256):
                        id = "{}-{}-{}".format(parts[-2], parts[-1], j)
                        # print(id)
                        self.database.append(id)
                        if parts[-2] in self.val_set:
                            self.indices["val"].append(i*256+j)
                        elif parts[-2] in self.test_set:
                            self.indices["test"].append(i*256+j)
                        else:
                            self.indices["train"].append(i*256+j)

        if len(self.database) < 1:
            raise RuntimeError('No data found.')

    def read_data(self, index):
        id = self.database[index]
        parts = id.split("-")
        input_filepath = "{}/input/{}/{}/Flow_undersampled.h5".format(self.input_path, parts[0], parts[1])
        with h5py.File(input_filepath, 'r') as h5_infile:
            input = self.stack_channels(h5_infile, int(parts[2]))
        target_filepath = "{}/target/{}/{}/Flow_full.h5".format(self.input_path, parts[0], parts[1])
        with h5py.File(target_filepath, 'r') as h5_tarfile:
            target = self.stack_channels(h5_tarfile, int(parts[2]))
        # print(input.shape)
        # print(target.shape)
        return input, target

    def preprocess(self, input, target):
        if self.transform:
            input = self.transform(input)
            target = self.transform(target)
        else:
            input = torch.from_numpy(input)
            target = torch.from_numpy(target)
        return input, target

    def stack_channels(self, h5_file, frame):
        mag = h5_file["Data/MAG"][:,:,frame]
        v1 = h5_file["Data/comp_vd_1"][:,:,frame]
        v2 = h5_file["Data/comp_vd_2"][:,:,frame]
        v3 = h5_file["Data/comp_vd_3"][:,:,frame]
        return np.transpose(np.dstack((mag,v1,v2,v3)), (2, 0, 1))

def get_dataloaders(input_dir, val_set="H3_051713", test_set="P3_090913", transform=None, batch_size=16, shuffle=False):

    print("Initializing dataset.")
    dataset = DenoisingDataset(input_dir, val_set, test_set, transform)
    train_indices, val_indices, test_indices = dataset.indices["train"], dataset.indices["val"], dataset.indices["test"]

    if shuffle:
        seed = 69
        np.random.seed(seed)
        np.random.shuffle(train_indices)
        np.random.shuffle(val_indices)

    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)
    test_sampler = SubsetRandomSampler(test_indices)

    train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=batch_size)
    val_loader = DataLoader(dataset, sampler=val_sampler, batch_size=batch_size)
    test_loader = DataLoader(dataset, sampler=test_sampler, batch_size=batch_size)

    dataloaders = {"train": train_loader, "val": val_loader, "test": test_loader}

    return dataloaders

And here’s my training loop:

def train_model(dataloaders,model,criterion,optimizer,scheduler,num_epochs,log_dir,device):

    num_images = 0
    best_ssim = 0

    for epoch in range(num_epochs):
        start = time.time()
        running_loss = 0.0
        ssim_vals = []

        ## set model in training or validation mode

        if ((epoch+1)%5==0 and epoch!=0):
            phase = "val" 
            model.eval()
        else:
            phase = "train"
            model.train()

        epoch_msg = "Epoch {}/{} - {}n".format(epoch,num_epochs-1, phase) + "-"*50
        print(epoch_msg)
        add_to_log(log_dir, epoch_msg)

        running_loss = 0.0
        #iterate over the data
        for i,sampled_batch in enumerate(tqdm(dataloaders[phase])):
            # print("Batch {}".format(i))
            inputs = sampled_batch["input"].float().to(device)  # N,C,H,W
            # print(inputs.shape)
            targets = sampled_batch["target"].float().to(device)  # N,C,H,W
            # print(targets.shape)

            num_images += inputs.size(0)
            # zero the parameter gradients
            optimizer.zero_grad()

            #forward
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                # print(predicted)
                loss = criterion(outputs,targets)
                # print(loss)

                #backward + optimize only in training phase
                if phase == "train":
                    loss.backward()
                    optimizer.step()
                    scheduler.step()


            #metric evaluation
            running_loss += loss.item()
            ssims = ms_ssim(outputs, targets, data_range=1, size_average=False).detach().cpu().numpy()
            # print(ssims)
            ssim_vals.append(ssims)

        # calculate the minibatch loss and SSIM
        epoch_loss =  running_loss/num_images
        epoch_ssim = np.nanmean(np.asarray(ssim_vals))
        print("Epoch: {}, Phase: {}, Loss: {:.4f}, SSIM: {}".format(phase,epoch,epoch_loss,epoch_ssim))
        add_to_log(log_dir, "Epoch: {}, Phase: {}, Loss: {:.4f}, SSIM: {}".format(phase,epoch,epoch_loss,epoch_ssim))
        
        # saving the model
        if phase == "val" and epoch_ssim > best_ssim:
            best_ssim = epoch_ssim
            torch.save({"epoch": epoch,
                        "model_state_dict":model.state_dict(),
                        "optimizer_state_dict":optimizer.state_dict(),
                        "loss":epoch_loss,
                        "SSIM":epoch_ssim
                    },(os.path.join(log_dir,'train_exp-epoch{}.pt'.format(epoch))))
        else:
            pass

        elapsed = time.time()-start
        msg = "Epoch complete in {:.0f}:{:.0f} min:secs".format(elapsed/60,elapsed%60)
        print(msg)
        add_to_log(log_dir, msg)

This is what the structure of my network looks like:

UNet(
  (encoders): ModuleList(
    (0): Encoder(
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (conv0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
        (SingleConv2): SingleConv(
          (conv0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
      )
      (downsampler): Identity()
    )
    (1): Encoder(
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (conv0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
        (SingleConv2): SingleConv(
          (conv0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
      )
      (downsampler): Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2), groups=64, bias=False)
    )
    (2): Encoder(
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (conv0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
        (SingleConv2): SingleConv(
          (conv0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
      )
      (downsampler): Conv2d(128, 128, kernel_size=(2, 2), stride=(2, 2), groups=128, bias=False)
    )
    (3): Encoder(
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (conv0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
        (SingleConv2): SingleConv(
          (conv0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
      )
      (downsampler): Conv2d(256, 256, kernel_size=(2, 2), stride=(2, 2), groups=256, bias=False)
    )
  )
  (decoders): ModuleList(
    (0): Decoder(
      (upsample): ConvTranspose2d(512, 512, kernel_size=(2, 2), stride=(2, 2), groups=512, bias=False)
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (conv0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
        (SingleConv2): SingleConv(
          (conv0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
      )
    )
    (1): Decoder(
      (upsample): ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2), groups=256, bias=False)
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (conv0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
        (SingleConv2): SingleConv(
          (conv0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
      )
    )
    (2): Decoder(
      (upsample): ConvTranspose2d(128, 128, kernel_size=(2, 2), stride=(2, 2), groups=128, bias=False)
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (conv0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
        (SingleConv2): SingleConv(
          (conv0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (ReLU1): ReLU(inplace=True)
        )
      )
    )
  )
  (final_conv): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (residual_conv): ScaleLayer()
)

This code runs no problem, it’s just the output that is no good.

Epoch 0/9
--------------------------------------------------
Epoch: train, Phase: 0, Loss: nan, SSIM: 0.18735374510288239
Epoch complete in 59:18 min:sec
Epoch 1/9
--------------------------------------------------
Epoch: train, Phase: 1, Loss: nan, SSIM: nan
Epoch complete in 59:6 min:sec
Epoch 2/9
--------------------------------------------------
Epoch: train, Phase: 2, Loss: nan, SSIM: nan
Epoch complete in 59:7 min:sec
Epoch 3/9
--------------------------------------------------
Epoch: train, Phase: 3, Loss: nan, SSIM: nan
Epoch complete in 58:40 min:sec
Epoch 4/9

So would anybody more knowledgeable on deep learning stuff be able to help point out what might be the issue here?

Leave a Comment