Source code for Crowd_counting.model

"""
This module contains the functions and classes needed to create and use a CSRNet model for crowd counting.
"""

import torch.nn as nn
import torch
from torchvision import models
from .utils import save_net,load_net
import torch.nn.functional as F
import torchvision
import PIL.Image as Image
from matplotlib import pyplot as plt
from .image import *
from matplotlib import cm as CM
from torchvision import datasets, transforms
import pkg_resources
import io

class CSRNet(nn.Module):
    """CSRNet density-map network for crowd counting.

    The network chains three parts:

    * ``frontend``: ten 3x3 convolutions and three 2x2 max-pools laid out
      like the first layers of VGG-16,
    * ``backend``: six dilated (rate 2) 3x3 convolutions,
    * ``output_layer``: a 1x1 convolution producing a single-channel map.

    :param load_weights: when ``False`` (the default) the front-end is
        initialised from torchvision's pretrained VGG-16. Pass ``True`` to
        skip that step, for instance when a full checkpoint is loaded right
        after construction.
    """

    def __init__(self, load_weights=False):
        super(CSRNet, self).__init__()
        self.seen = 0
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M',
                              256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]
        self.frontend = _make_layers(self.frontend_feat)
        self.backend = _make_layers(self.backend_feat, in_channels=512,
                                    dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            self._load_vgg16_frontend()

    def _load_vgg16_frontend(self):
        """Copy the matching pretrained VGG-16 tensors into the front-end."""
        vgg_state = models.vgg16(pretrained=True).state_dict()
        front_state = self.frontend.state_dict()
        # The front-end is a bare nn.Sequential, so its keys ("0.weight", ...)
        # are the VGG-16 keys without the "features." prefix.
        prefix = "features."
        pretrained = {
            key: vgg_state[prefix + key]
            for key in front_state
            if prefix + key in vgg_state
        }
        front_state.update(pretrained)
        self.frontend.load_state_dict(front_state)

    def forward(self, x):
        """Run the front-end, back-end and output layer on ``x``.

        :param x: input batch fed to the first 3-channel convolution
        :returns: the single-channel output of ``output_layer``
        """
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x

    def _initialize_weights(self):
        """Re-initialise every Conv2d, BatchNorm2d and Linear sub-module.

        Convolution and linear weights are drawn from N(0, 0.01^2) and their
        biases set to 0; batch-norm weights are set to 1 and biases to 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # std must be passed by keyword: the second positional
                # argument of normal_ is the mean, not the std.
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
def _make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build an ``nn.Sequential`` from a layer configuration list.

    :param cfg: sequence whose entries are either ``'M'`` (a 2x2 max-pool
        with stride 2) or an int giving the output channels of a 3x3
        convolution followed by an in-place ReLU
    :param in_channels: number of channels fed to the first convolution
    :param batch_norm: insert a ``BatchNorm2d`` between each convolution
        and its ReLU
    :param dilation: use dilation (and padding) 2 instead of 1 for every
        convolution
    :returns: the assembled ``nn.Sequential``
    """
    rate = 2 if dilation else 1
    modules = []
    channels = in_channels
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        # padding == dilation keeps the spatial size for a 3x3 kernel
        modules.append(
            nn.Conv2d(channels, spec, kernel_size=3,
                      padding=rate, dilation=rate)
        )
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)
def load_model(model_path, use_gpu=True):
    """
    Load a stored, already trained CSRNet model.

    :param model_path: path to (or file-like object holding) the stored
        checkpoint; it must contain a ``'state_dict'`` entry
    :param use_gpu: load the model on the GPU if ``True``, on the CPU
        otherwise
    :returns: the loaded model
    :rtype: CSRNet
    """
    # Every parameter is overwritten by the checkpoint below (strict
    # load_state_dict), so skip the pretrained VGG-16 initialisation.
    model = CSRNet(load_weights=True)
    device = torch.device('cuda' if use_gpu else 'cpu')
    model = model.to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a
    # CPU-only machine when use_gpu is False.
    # NOTE: torch.load unpickles its input; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    return model
def load_pretrained(model_name='shanghaiA', use_gpu=True):
    """
    Download and load one of the pretrained models shipped with the library.

    :param model_name: name of the pretrained model to load. Can be
        shanghaiA, shanghaiB, A10
    :param use_gpu: forwarded to :func:`load_model`; load the model on the
        GPU if ``True``, on the CPU otherwise
    :returns: the loaded model
    :rtype: CSRNet
    :raises requests.HTTPError: if the server answers with an error status
    """
    import requests

    url = 'https://tfecounting.000webhostapp.com/' + model_name + '.pth.tar'
    response = requests.get(url)
    # Fail here with a clear HTTP error instead of handing an error page
    # to torch.load.
    response.raise_for_status()
    # NOTE: the downloaded bytes are unpickled by torch.load inside
    # load_model; this trusts the remote host.
    stream = io.BytesIO(response.content)
    return load_model(stream, use_gpu=use_gpu)
def predict(model, image_path, use_gpu=True):
    """
    Predict the estimated number of people and the density map for an image.

    :param model: an already trained CSRNet model
    :param image_path: path to the image to be predicted
    :param use_gpu: move the image to the GPU before the forward pass if
        ``True``; keep it on the CPU otherwise. Must match where ``model``
        lives.
    :returns: the estimated number of people (sum of the density map,
        truncated to an int)
    :returns: the estimated density map, as returned by the model
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # convert() returns a new, fully loaded image, so the file can be closed
    # as soon as the tensor has been built.
    with Image.open(image_path) as raw_image:
        img = transform(raw_image.convert('RGB'))
    if use_gpu:
        img = img.cuda()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(img.unsqueeze(0))
    people_nbr = int(output.detach().cpu().sum().item())
    return people_nbr, output
def visualize(image, ground_truth=None, model=None, figsize=(100, 100)):
    """
    Show an image and, optionally, its ground truth and a model prediction
    side by side using matplotlib.

    :param image: the path to the base image
    :param ground_truth: path to a .h5 file, or an already opened object
        indexable by ``'density'``, holding the ground-truth density map
    :param model: a model used to make a prediction on the base image
    :param figsize: size of the matplotlib figure
    """
    gt = None
    if ground_truth is not None:
        if isinstance(ground_truth, str):
            # Open read-only and close right after the data is copied out.
            with h5py.File(ground_truth, 'r') as gt_file:
                gt = np.asarray(gt_file['density'])
        else:
            # Caller-owned object: read from it but leave it open.
            gt = np.asarray(ground_truth['density'])

    # Number of panels to plot: base image, plus ground truth and/or
    # prediction when provided.
    count = 1
    if gt is not None:
        count += 1
    if model is not None:
        count += 1
        people_nbr, output = predict(model, image)

    plt.figure(figsize=figsize)

    plt.subplot(1, count, 1)
    plt.axis('off')
    with Image.open(image) as base_image:
        plt.imshow(base_image)
    plt.title("Base Image", fontsize=75)

    if gt is not None:
        plt.subplot(1, count, 2)
        plt.axis('off')
        plt.imshow(gt, cmap=CM.jet)
        plt.title("Groundtruth : " + str(int(np.sum(gt))), fontsize=75)

    if model is not None:
        # The prediction always takes the last panel.
        plt.subplot(1, count, count)
        plt.axis('off')
        plt.imshow(np.squeeze(output.detach().cpu().numpy(), (0, 1)),
                   cmap=CM.jet)
        plt.title("Model prediction : " + str(people_nbr), fontsize=75)

    plt.show()
def evaluate(model, img_paths, MAE=True, MSE=True, use_gpu=True):
    """
    Compute the MAE and MSE of the model's people count on a set of images.

    For each image the ground truth is read from the ``'density'`` dataset
    of the .h5 file whose path is obtained by replacing the ``.png`` /
    ``.jpg`` extension with ``.h5`` and ``images`` with ``ground_truth``.
    The per-image error is the difference between the sums of the predicted
    and ground-truth density maps.

    :param model: an already trained CSRNet model
    :param list img_paths: paths to the images used for evaluation
    :param bool MAE: if set to False the MAE is not computed
    :param bool MSE: if set to False the MSE is not computed
    :param bool use_gpu: move each image to the GPU before the forward pass
        if ``True``; keep it on the CPU otherwise
    :returns: the mean absolute error (0 if the MAE parameter is set to False)
    :returns: the mean squared error (0 if the MSE parameter is set to False)
    """
    mae = 0
    mse = 0
    length = len(img_paths)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    for img_path in img_paths:
        with Image.open(img_path) as raw_image:
            img = transform(raw_image.convert('RGB'))
        if use_gpu:
            img = img.cuda()

        gt_path = (img_path.replace('.png', '.h5')
                           .replace('.jpg', '.h5')
                           .replace('images', 'ground_truth'))
        # Close each ground-truth file as soon as its data is copied out so
        # handles do not pile up over a large evaluation set.
        with h5py.File(gt_path, 'r') as gt_file:
            groundtruth = np.asarray(gt_file['density'])

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            output = model(img.unsqueeze(0))

        error = output.detach().cpu().sum().numpy() - np.sum(groundtruth)
        if MAE:
            mae += abs(error)
        if MSE:
            mse += error ** 2
    return mae / length, mse / length