Visualising CNN Models Using PyTorch*

Published: 02/09/2018  

Last Updated: 02/09/2018

Before deep learning systems came along, researchers spent a painstaking amount of time understanding their data, looking for visual cues before handing it off to an algorithm. Today, we almost always feed our data into a transfer learning pipeline and hope it works, often without even tuning the hyper-parameters. And very often, it does. Current Convolutional Neural Network (CNN) models are very powerful and generalize well to new datasets. So training is quick and everyone is happy, until you run the model on your test set and it bombs. You tune hyper-parameters and try a different pre-trained model, but nothing works. This might be the right time to check whether the data itself is right.

But then again, who has the time to go through all the data and make sure that everything is right? Or the compute to try out many hyper-parameter settings and fine-tune the model? An easier alternative is to visualize the model and check which parts of the image are causing the activations. This gives a very good understanding of the features the model relies on.

There is an urban legend that, back in the 1990s, the US government commissioned a project to detect tanks in pictures. The researchers built a neural network and used it to classify the images. Once the product was actually put to the test, it did not work at all. On further inspection, they noticed that the model had learned the weather instead of the tanks: the training images with tanks had been taken on cloudy days, and the images without tanks on sunny days. This is a prime example of why we need to understand what a neural network has learned.

Usually, once a deep learning model is trained, developers use ROC curves or some other metric to measure its performance. If the results are not particularly good, fine-tuning the hyper-parameters is often the go-to solution. But this is a painstakingly long process. A good practice before re-training a model with different hyper-parameters is to understand where the current model is going wrong. Had the researchers in the tank project kept re-training models with different hyper-parameters without understanding the data, it would have taken forever to find the mistake. Visualising the model is a great way to gain insight into which features it has learned. Any Intel-powered CPU can easily run this task: we only need a forward pass to store the activations of the layer of interest, and a backward pass to get the gradients for the required class. If you want to visualise many images, Intel AI DevCloud has you covered.

Back in 2012, when AlexNet took the world by storm by winning the ImageNet challenge, its authors gave a brief description of what the convolutional kernels had learned, showing the kernels of the first convolutional layer.

In that figure, you can observe that the initial layers learn simple patterns such as lines and edges. Deeper into the network, more intricate patterns are learned. Check out the homepage of Stanford's CS231n course, where a simple CNN runs live in your browser and displays its activations.
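To reproduce a figure like this for your own model, the first-layer weights can be tiled into a single image. Below is a minimal NumPy sketch; `kernels_to_grid` is a hypothetical helper, and it assumes the layer's weights have shape (out_channels, 3, k, k), for example `model.conv1.weight.detach().cpu().numpy()` for a torchvision ResNet50 (64 kernels of 7x7). Random weights stand in for a trained layer.

```python
import numpy as np


def kernels_to_grid(weights, ncols=8, pad=1):
    """Tile first-layer conv kernels of shape (out, 3, k, k) into one RGB image.

    Each kernel is min-max scaled to [0, 1] on its own so that its pattern
    is visible regardless of the weight magnitude. Padding cells are white.
    """
    n, c, kh, kw = weights.shape
    assert c == 3, "expects kernels that operate on RGB input"
    nrows = int(np.ceil(n / ncols))
    grid = np.ones((nrows * (kh + pad) + pad, ncols * (kw + pad) + pad, 3),
                   dtype=np.float32)
    for idx in range(n):
        k = weights[idx].transpose(1, 2, 0)  # CHW -> HWC
        k = (k - k.min()) / (k.max() - k.min() + 1e-8)
        row, col = divmod(idx, ncols)
        top = pad + row * (kh + pad)
        left = pad + col * (kw + pad)
        grid[top:top + kh, left:left + kw] = k
    return grid


# Random weights standing in for a trained 64 x 3 x 7 x 7 first layer
w = np.random.randn(64, 3, 7, 7).astype(np.float32)
grid = kernels_to_grid(w)
print(grid.shape)  # (65, 65, 3)
```

The resulting array can be shown with any image viewer, for example `matplotlib.pyplot.imshow(grid)`.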

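In 2014, Karen Simonyan and his team at Oxford's Visual Geometry Group won the localisation task of the ImageNet challenge and placed second in classification. The same group also worked on understanding what CNNs learn: they plotted saliency maps, which show how sensitive a class score is to each pixel of the input image, and used them to understand the network better.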
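A vanilla gradient saliency map in the spirit of Simonyan et al. takes only a few lines in PyTorch: backpropagate the target class score to the input, then take, for each pixel, the largest absolute gradient over the colour channels. This is a sketch rather than their reference code; the `saliency_map` helper and the tiny `toy` network are hypothetical stand-ins for your own trained model.

```python
import torch
import torch.nn as nn


def saliency_map(model, image, target_class):
    """Vanilla gradient saliency map.

    image: tensor of shape (1, 3, H, W). Returns an (H, W) tensor holding,
    for each pixel, the largest absolute gradient of the class score over
    the colour channels.
    """
    model.eval()
    image = image.clone().requires_grad_(True)
    scores = model(image)
    model.zero_grad()
    # Backpropagate the (pre-softmax) score of the target class only
    scores[0, target_class].backward()
    return image.grad.abs().max(dim=1).values[0]


# A tiny stand-in classifier; swap in your trained model
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))
sal = saliency_map(toy, torch.rand(1, 3, 32, 32), target_class=1)
print(sal.shape)  # torch.Size([32, 32])
```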

Over time, the visualisations have gotten better. One of the most useful and easy-to-interpret techniques is Grad-CAM (Gradient-weighted Class Activation Mapping). It uses the class-specific gradient information flowing into the last convolutional layer of the CNN to produce a coarse localisation map of the important regions in the image.
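For reference, Grad-CAM can be written in two steps. Here y^c is the score of class c before the softmax, A^k is the k-th feature map of the chosen convolutional layer, and Z is the number of spatial positions (i, j) in that map. The first step averages the gradients into one weight per channel; the second takes a ReLU of the weighted sum of the feature maps, keeping only the regions that increase the class score.

```latex
\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A_{ij}^k}
\qquad
L_{\text{Grad-CAM}}^c = \mathrm{ReLU}\!\left( \sum_k \alpha_k^c A^k \right)
```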

One of the biggest advantages of using visualisations is that we can understand which features are causing the activations. Recently, I was working on the ISIC challenge for skin cancer detection. My model was scoring around 70%, but when I inspected a few images run through Grad-CAM, I realised that the network was concentrating on the wrong features: it was giving greater weight to the skin colour than to the lesion.

Let’s dive into the code now

It is pretty straightforward. First, we load our trained model and define the target class. Next, we run a forward pass, storing the activations of the target layer, then zero the gradients and backpropagate from the target class to obtain the gradients at that layer. Averaging these gradients gives a weight for each activation map, and the weighted sum of the maps is resized to the original image. Finally, we plot it as a heat map on top of the original image. This helps identify the exact features that the model has learned.

Required dependencies:

  • OpenCV*
  • PyTorch*
  • Torchvision* (optional)
  • NumPy*

We load the model into memory, followed by the image. I trained a ResNet50 on the ISIC 2017 challenge data, and that is the model loaded here. If you have a different pre-trained model, or a model that you have defined yourself, build that architecture instead and load its checkpoint. Note the extra map_location argument passed to torch.load: it remaps the stored tensors so that a model trained on a graphics processing unit (GPU) can still be used for inference on a central processing unit (CPU).

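import torch
import torch.nn as nn
from torchvision import models


def load_checkpoint(model_path):
    """
        Loads the checkpoint of the trained model and returns the model.
    """
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        checkpoint = torch.load(model_path)
    else:
        # Remap tensors saved from a GPU onto the CPU
        checkpoint = torch.load(
            model_path, map_location=lambda storage, loc: storage)

    # Rebuild the training architecture: a ResNet50 with a 2-class head.
    # The weights come from the checkpoint, so no ImageNet download is needed.
    pretrained_model = models.resnet50()
    num_ftrs = pretrained_model.fc.in_features
    pretrained_model.fc = nn.Linear(num_ftrs, 2)
    pretrained_model.load_state_dict(checkpoint)

    if use_gpu:
        pretrained_model = pretrained_model.cuda()
    pretrained_model.eval()  # Inference mode for batch norm and dropout

    return pretrained_model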
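If the checkpoint only ever needs to run on the CPU, `torch.load` also accepts a device string for `map_location`. The sketch below uses a hypothetical `load_for_cpu_inference` helper and a tiny `nn.Linear` standing in for the ResNet50, and round-trips a state_dict through an in-memory buffer to show the pattern.

```python
import io

import torch
import torch.nn as nn


def load_for_cpu_inference(build_model, checkpoint_file):
    """Rebuild an architecture and load a saved state_dict onto the CPU.

    build_model: zero-argument callable returning the untrained architecture
    (a hypothetical convention for this sketch).
    """
    state_dict = torch.load(checkpoint_file, map_location="cpu")
    model = build_model()
    model.load_state_dict(state_dict)
    return model.eval()


def build():
    return nn.Linear(4, 2)  # stand-in for the ResNet50 with a 2-class head


original = build()
buf = io.BytesIO()
torch.save(original.state_dict(), buf)
buf.seek(0)
restored = load_for_cpu_inference(build, buf)
print(next(restored.parameters()).device)  # cpu
```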

Next, we preprocess the image. Any transforms applied during training must be applied again here: if the training pipeline subtracted the mean (or divided by the standard deviation), do the same. The result is then converted into a tensor for the forward pass.

import cv2
import numpy as np
import torch


def preprocess_image(cv2im, resize_im=True):
    """
        Resizes the image if requested, converts it to a torch tensor and
        returns a tensor that tracks gradients.
    """
    if resize_im:
        cv2im = cv2.resize(cv2im, (224, 224))
    im_as_arr = np.float32(cv2im)
    # OpenCV loads images as BGR; reverse the channel order to get RGB
    im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])
    # Apply the same scaling / mean subtraction used at training time here
    im_as_arr = im_as_arr.transpose(2, 0, 1)  # HWC -> CHW
    im_as_ten = torch.from_numpy(im_as_arr).float()
    # Add a batch dimension at the beginning. Tensor shape = 1,3,224,224
    im_as_ten.unsqueeze_(0)
    # Variable is deprecated since PyTorch 0.4; tensors track gradients directly
    im_as_ten.requires_grad_(True)
    return im_as_ten
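`preprocess_image` does not scale or normalise the pixels. If the model was trained with torchvision-style transforms (scale to [0, 1], then subtract the per-channel mean and divide by the standard deviation), the matching step might look like the sketch below. The ImageNet statistics are the values torchvision's pre-trained models use; they are an assumption here, so substitute your own training statistics. `normalise_bgr_image` is a hypothetical helper whose output can be passed to `torch.from_numpy` in place of the raw array.

```python
import numpy as np

# ImageNet statistics used by torchvision's pre-trained models (assumed here;
# replace them with whatever your own training pipeline used)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalise_bgr_image(cv2im, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Turn an OpenCV BGR uint8 image (H, W, 3) into a normalised float32
    RGB array in CHW order."""
    rgb = cv2im[..., ::-1].astype(np.float32) / 255.0  # BGR -> RGB, in [0, 1]
    rgb = (rgb - mean) / std                              # per-channel normalisation
    return np.ascontiguousarray(rgb.transpose(2, 0, 1))   # HWC -> CHW


white = np.full((224, 224, 3), 255, dtype=np.uint8)
arr = normalise_bgr_image(white)
print(arr.shape)      # (3, 224, 224)
print(arr[:, 0, 0])   # roughly [2.25, 2.43, 2.64]
```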

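Then we run the forward pass on the image and save the activations of the target layer only, registering a hook that captures that layer's gradients during the backward pass. The target layer is the layer we want to visualise; for Grad-CAM this is usually the last convolutional block, which is 'layer4' in a torchvision ResNet50.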

class CamExtractor():
    """
        Extracts the activations and gradients of the target layer.
    """
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer  # child name, e.g. 'layer4' in a ResNet50
        self.gradients = None

    def save_gradient(self, grad):
        # Called by autograd during the backward pass
        self.gradients = grad

    def forward_pass_on_convolutions(self, x):
        """
            Does a forward pass on convolutions, hooks the function at given layer
        """
        conv_output = None
        for module_name, module in self.model.named_children():
            if module_name == 'fc':
                return conv_output, x
            x = module(x)  # Forward
            if module_name == self.target_layer:
                x.register_hook(self.save_gradient)
                conv_output = x  # Save the convolution output on that layer
        return conv_output, x

    def forward_pass(self, x):
        """
            Does a full forward pass on the model
        """
        # Forward pass on the convolutions
        conv_output, x = self.forward_pass_on_convolutions(x)
        x = x.view(x.size(0), -1)  # Flatten
        # Forward pass on the classifier
        x = self.model.fc(x)
        return conv_output, x
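The gradient capture relies on `Tensor.register_hook`, which calls the given function with the gradient of that tensor during the backward pass. A small standalone illustration with toy values (not part of the script; `save_grad` and `captured` are hypothetical names):

```python
import torch

captured = {}


def save_grad(grad):
    # Called during backward with dL/dy for the hooked tensor
    captured["y"] = grad


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2                  # intermediate tensor, like a conv layer's output
y.register_hook(save_grad)
loss = (y ** 2).sum()
loss.backward()
print(captured["y"])       # dL/dy = 2 * y = [4, 8, 12]
print(x.grad)              # dL/dx = dL/dy * 2 = [8, 16, 24]
```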

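Now we put the pieces together. The function below runs the forward pass, then backpropagates a one-hot vector for the target class to obtain the gradients at the target layer. The gradients of each channel are averaged into a single weight, and the weighted sum of the activation maps, passed through a ReLU and rescaled, is the class activation map. The code is well commented, so you can follow it by reading through it.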

import cv2
import numpy as np
import torch


class GradCam():
    """
        Produces class activation maps using an extractor object that
        provides forward_pass() and stores the hooked gradients.
    """
    def __init__(self, model, extractor):
        self.model = model
        self.model.eval()
        self.extractor = extractor

    def generate_cam(self, input_image, target_index=None):
        """
            Full forward pass
            conv_output is the output of convolutions at specified layer
            model_output is the final output of the model
        """
        # Keep the input on the same device (CPU or GPU) as the model
        device = next(self.model.parameters()).device
        conv_output, model_output = self.extractor.forward_pass(
            input_image.to(device))
        if target_index is None:
            target_index = int(model_output.argmax(dim=1)[0])
        # Target for backprop: one-hot vector for the target class
        one_hot_output = torch.zeros_like(model_output)
        one_hot_output[0][target_index] = 1
        # Zero the gradients of the whole model
        self.model.zero_grad()
        # Backward pass with specified target
        model_output.backward(gradient=one_hot_output, retain_graph=True)
        # Get hooked gradients
        guided_gradients = self.extractor.gradients.detach().cpu().numpy()[0]
        # Get convolution outputs
        target = conv_output.detach().cpu().numpy()[0]
        # Weight of each channel = average of its gradients
        weights = np.mean(guided_gradients, axis=(1, 2))
        # Start from zeros so the map is exactly the weighted sum
        cam = np.zeros(target.shape[1:], dtype=np.float32)
        # Multiply each weight with its conv output and then, sum
        for i, w in enumerate(weights):
            cam += w * target[i, :, :]
        cam = cv2.resize(cam, (224, 224))
        cam = np.maximum(cam, 0)  # ReLU: keep regions that raise the score
        # Normalize between 0-1 (epsilon avoids dividing by zero)
        cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
        cam = np.uint8(cam * 255)  # Scale between 0-255 to visualize
        return cam
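The extractor walks the model's top-level children in order, which works for a ResNet but not for models whose forward pass is not a simple chain of those children. A sketch of an alternative that assumes nothing about the model's structure registers a forward hook on the target layer instead. `grad_cam_with_hooks` is a hypothetical helper and the toy network stands in for your model; for the ResNet50 here you would typically pass `model.layer4` as `layer`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam_with_hooks(model, layer, image, target_index=None):
    """Grad-CAM via a forward hook on `layer`.

    image: tensor of shape (1, 3, H, W). Returns an (H, W) map in [0, 1].
    """
    store = {}

    def forward_hook(module, inputs, output):
        store["act"] = output

        def save_grad(grad):
            store["grad"] = grad  # d(score)/d(output), filled during backward

        output.register_hook(save_grad)

    handle = layer.register_forward_hook(forward_hook)
    try:
        model.eval()
        scores = model(image)
        if target_index is None:
            target_index = int(scores.argmax(dim=1)[0])
        model.zero_grad()
        scores[0, target_index].backward()
    finally:
        handle.remove()  # leave the model unmodified

    act, grad = store["act"][0].detach(), store["grad"][0]
    weights = grad.mean(dim=(1, 2))  # one weight per channel
    cam = F.relu((weights[:, None, None] * act).sum(dim=0))
    cam = F.interpolate(cam[None, None], size=image.shape[-2:],
                        mode="bilinear", align_corners=False)[0, 0]
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)


# Tiny stand-in CNN; hook its first convolution
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))
cam = grad_cam_with_hooks(toy, toy[0], torch.rand(1, 3, 32, 32))
print(cam.shape)  # torch.Size([32, 32])
```

Removing the hook in a `finally` block keeps repeated calls from stacking hooks on the same layer.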

Now we save the CAM on the original image as a heat map to visualise the areas of concentration. We save the image in three different formats: the grayscale activation map, the heat map, and the heat map superimposed on top of the original image.

import os

import cv2
import numpy as np


def save_class_activation_on_image(org_img, activation_map, file_name):
    """
      Saves the activation map as a heatmap imposed on the original image.
      org_img: BGR image as loaded by cv2.imread
      activation_map: uint8 array of shape (224, 224), e.g. from generate_cam
    """
    os.makedirs('./results', exist_ok=True)
    # Grayscale activation map
    path_to_file = os.path.join('./results', file_name + '_Cam_Grayscale.jpg')
    cv2.imwrite(path_to_file, activation_map)
    # Heatmap of activation map
    activation_heatmap = cv2.applyColorMap(activation_map, cv2.COLORMAP_HSV)
    path_to_file = os.path.join('./results', file_name + '_Cam_Heatmap.jpg')
    cv2.imwrite(path_to_file, activation_heatmap)
    # Heatmap on picture: add both images, then rescale to 0-1
    org_img = cv2.resize(org_img, (224, 224))
    img_with_heatmap = np.float32(activation_heatmap) + np.float32(org_img)
    img_with_heatmap = img_with_heatmap / np.max(img_with_heatmap)
    path_to_file = os.path.join('./results', file_name + '_Cam_On_Image.jpg')
    cv2.imwrite(path_to_file, np.uint8(255 * img_with_heatmap))
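The overlay above adds the heat map to the image and divides by the maximum, so the overall brightness depends on the image content. A fixed-weight blend is an alternative; the sketch below (a hypothetical `blend_heatmap` helper in plain NumPy) computes the same weighted sum as OpenCV's `cv2.addWeighted(image, 1 - alpha, heatmap, alpha, 0)`.

```python
import numpy as np


def blend_heatmap(image, heatmap, alpha=0.4):
    """Alpha-blend a colour heat map over an image.

    image, heatmap: uint8 arrays of the same shape (H, W, 3).
    alpha: weight of the heat map; 1 - alpha is the weight of the image.
    """
    if image.shape != heatmap.shape:
        raise ValueError("image and heatmap must have the same shape")
    out = (1.0 - alpha) * image.astype(np.float32) + alpha * heatmap.astype(np.float32)
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


black = np.zeros((2, 2, 3), dtype=np.uint8)
hot = np.full((2, 2, 3), 255, dtype=np.uint8)
print(blend_heatmap(black, hot)[0, 0])  # [102 102 102]
```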
    
The entire script is available as a gist. To run the code, provide the input arguments:

python visualisation.py --img <path to the image> --target <target class> --model <path to the trained model> --export <name of the file to export>
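The script reads these arguments from the command line. A minimal sketch of a parser that accepts the same flags with the standard `argparse` module follows; `build_parser`, the defaults, and the help strings are assumptions, not taken from the original script.

```python
import argparse


def build_parser():
    """Hypothetical parser for the --img/--target/--model/--export flags."""
    parser = argparse.ArgumentParser(description="Grad-CAM visualisation")
    parser.add_argument("--img", required=True, help="path to the image")
    parser.add_argument("--target", type=int, default=None,
                        help="target class index (default: predicted class)")
    parser.add_argument("--model", required=True,
                        help="path to the trained model")
    parser.add_argument("--export", default="result",
                        help="name of the file to export")
    return parser


opt = build_parser().parse_args(
    ["--img", "lesion.jpg", "--target", "1", "--model", "resnet50.pth"])
print(opt.target, opt.export)  # 1 result
```

Leaving `--target` unset passes `None` through, which matches the "use the predicted class" behaviour of `generate_cam`.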

As I said earlier, this visualisation helped me understand my skin cancer detection model. Here are the results from that model after I tuned it.

From the images above, you can see that in the non-cancerous images the activations are on the left: the model was activating on that particular skin colour. This gave me the insight to normalise the entire dataset by its mean and standard deviation.
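Normalising by the dataset's mean and standard deviation first requires those statistics. The sketch below computes them per channel with running sums, so the images never have to be held in memory at once. `channel_mean_std` is a hypothetical helper, and the channel order of the result matches the images you pass in (BGR if they were loaded with OpenCV).

```python
import numpy as np


def channel_mean_std(images):
    """Per-channel mean and standard deviation over a dataset.

    images: iterable of uint8 arrays of shape (H, W, 3); sizes may differ.
    Values are scaled to [0, 1] before the statistics are accumulated.
    """
    total = np.zeros(3, dtype=np.float64)
    total_sq = np.zeros(3, dtype=np.float64)
    count = 0
    for img in images:
        px = img.reshape(-1, 3).astype(np.float64) / 255.0
        total += px.sum(axis=0)
        total_sq += (px ** 2).sum(axis=0)
        count += px.shape[0]
    mean = total / count
    # Var = E[x^2] - E[x]^2; clamp tiny negative rounding errors to zero
    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))
    return mean, std


a = np.zeros((4, 4, 3), dtype=np.uint8)
b = np.full((4, 4, 3), 255, dtype=np.uint8)
mean, std = channel_mean_std([a, b])
print(mean, std)  # [0.5 0.5 0.5] [0.5 0.5 0.5]
```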

The code here is adapted from the excellent repository by Utku Ozbulak. I hope this has been helpful. If you have any feedback or questions, I would love to answer them.
