Training Module

class Crowd_counting.train.CSRNet(*args: Any, **kwargs: Any)

Bases: Module

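This class implements the CSRNet crowd-counting network, which maps an input image to a density map whose sum is the estimated number of people.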

forward(x)

Computes the forward pass on the input batch x and returns the predicted density map.
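A minimal sketch of the network this class is assumed to implement, following the published CSRNet layout: the first ten convolutional layers of VGG-16 as a front end, followed by dilated 3x3 convolutions (dilation rate 2) as a back end and a 1x1 output convolution. It assumes Module is torch.nn.Module. CSRNetSketch and make_layers are hypothetical names, and the real class may differ (for example, it may load pretrained VGG-16 weights, which this sketch does not).

```python
import torch
from torch import nn


def make_layers(cfg, in_channels, dilation=1):
    """Stack of 3x3 convolutions + ReLU; the entry "M" inserts a 2x2 max-pool."""
    layers = []
    for v in cfg:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            # padding == dilation keeps the spatial size unchanged for a 3x3 kernel
            layers.append(nn.Conv2d(in_channels, v, kernel_size=3,
                                    padding=dilation, dilation=dilation))
            layers.append(nn.ReLU(inplace=True))
            in_channels = v
    return nn.Sequential(*layers)


class CSRNetSketch(nn.Module):
    """Hypothetical re-implementation of the CSRNet layout, not this module's code."""

    def __init__(self):
        super().__init__()
        # Front end: first 10 conv layers of VGG-16 (three pools -> 1/8 resolution)
        self.frontend = make_layers(
            [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512], 3)
        # Back end: dilated convolutions that enlarge the receptive field
        # without reducing the resolution further
        self.backend = make_layers([512, 512, 512, 256, 128, 64], 512, dilation=2)
        # 1x1 convolution producing a single-channel density map
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        x = self.frontend(x)
        x = self.backend(x)
        return self.output_layer(x)


if __name__ == "__main__":
    model = CSRNetSketch().eval()
    with torch.no_grad():
        density = model(torch.randn(1, 3, 64, 64))
    print(density.shape)  # prints torch.Size([1, 1, 8, 8])
    # density.sum() is the estimated count (meaningless with untrained weights)
```

Because the front end contains three 2x2 max-pooling layers, the predicted density map is 1/8 of the input height and width; the estimated count is the sum over that map.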

class Crowd_counting.train.Dataset(*args, **kwds)

Bases: Generic[T_co]

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should override __getitem__(), supporting fetching a data sample for a given key. Subclasses can also optionally override __len__(), which many Sampler implementations and the default options of DataLoader expect to return the size of the dataset.

Note

DataLoader by default constructs an index sampler that yields integral indices. To use it with a map-style dataset that has non-integral indices/keys, a custom sampler must be provided.
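As an illustration of the map-style contract described above, the sketch below subclasses torch.utils.data.Dataset (assumed to be what Crowd_counting.train.Dataset refers to) with integer keys, so the default DataLoader sampler works without a custom one. CrowdDataset is hypothetical and returns synthetic tensors in place of images and ground-truth density maps; the 1/8-resolution density map is an assumption chosen to match the CSRNet output size.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CrowdDataset(Dataset):
    """Hypothetical map-style dataset of (image, density map) pairs.

    Synthetic tensors stand in for images and density maps loaded from disk.
    """

    def __init__(self, num_samples, image_size=(64, 64)):
        self.num_samples = num_samples
        self.image_size = image_size

    def __len__(self):
        # The default DataLoader sampler draws integer indices in range(len(self))
        return self.num_samples

    def __getitem__(self, index):
        if not 0 <= index < self.num_samples:
            raise IndexError(index)
        h, w = self.image_size
        generator = torch.Generator().manual_seed(index)  # same sample for same key
        image = torch.rand(3, h, w, generator=generator)
        # Assumed target: density map at 1/8 of the image resolution
        density = torch.rand(1, h // 8, w // 8, generator=generator)
        return image, density


if __name__ == "__main__":
    loader = DataLoader(CrowdDataset(num_samples=4), batch_size=2, shuffle=True)
    for images, densities in loader:
        # prints torch.Size([2, 3, 64, 64]) torch.Size([2, 1, 8, 8]) for each batch
        print(images.shape, densities.shape)
```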

Crowd_counting.train.adjust_learning_rate(optimizer, epoch, args, best_prec1)

Sets the learning rate to the initial learning rate decayed by a factor of 10 every 30 epochs.
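A minimal sketch of the documented schedule, lr = init_lr * 0.1 ** (epoch // 30), applied to every parameter group of a PyTorch optimizer. step_decay_lr is a hypothetical helper: the real adjust_learning_rate also receives args and best_prec1, whose role is not documented here, so the sketch leaves them out.

```python
import torch


def step_decay_lr(optimizer, epoch, init_lr, decay=0.1, step=30):
    """Hypothetical helper: divide init_lr by 10 every `step` epochs."""
    lr = init_lr * (decay ** (epoch // step))
    # Update every parameter group so the new rate takes effect on the next step()
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr


if __name__ == "__main__":
    weights = torch.zeros(3, requires_grad=True)
    optimizer = torch.optim.SGD([weights], lr=1e-7)
    for epoch in (0, 29, 30, 65):
        print(epoch, step_decay_lr(optimizer, epoch, init_lr=1e-7))
```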

Crowd_counting.train.complete_train(datasetpath, modelpath=None, shuffle=True, gpu=True, init_lr=1e-07, batch_size=1, epochs=400, img_format='*.png', filename='./checkpoint.pth.tar', best_model='./model_best.pth.tar')

Trains a CSRNet model on the images found in datasetpath, evaluating it on a held-out test split.

Parameters
  • datasetpath – the path to the folder containing the images

  • modelpath – path to a .tar file containing an already trained model

  • shuffle – if True, the images of the dataset are shuffled before being split into train and test sets.

  • gpu – if True, the model is trained on the GPU

  • init_lr – initial learning rate

  • batch_size – batch size

  • epochs – number of epochs of the training

  • img_format – glob pattern selecting the image files; can only take the values ‘*.png’ and ‘*.jpg’

  • filename – path and name where the checkpoint file will be saved after each epoch (must be a .tar file)

  • best_model – path and name where the best model (the one with the lowest MAE on the test set) will be saved (must be a .tar file)
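The datasetpath, img_format and shuffle parameters suggest that the image paths are collected with a glob pattern and then optionally shuffled before the train/test split. The sketch below assumes exactly that; list_and_split is a hypothetical helper, and its 20% test fraction is an arbitrary placeholder, since the actual split ratio is not documented.

```python
import glob
import os
import random
import tempfile


def list_and_split(datasetpath, img_format="*.png", shuffle=True,
                   test_fraction=0.2, seed=None):
    """Hypothetical helper: glob the images, optionally shuffle, then split."""
    if img_format not in ("*.png", "*.jpg"):
        raise ValueError("img_format must be '*.png' or '*.jpg'")
    # Sort first so the unshuffled split is deterministic
    paths = sorted(glob.glob(os.path.join(datasetpath, img_format)))
    if shuffle:
        random.Random(seed).shuffle(paths)
    n_test = int(len(paths) * test_fraction)  # assumed split ratio
    return paths[n_test:], paths[:n_test]  # (train, test)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as folder:
        for i in range(10):
            # Empty placeholder files standing in for real images
            open(os.path.join(folder, f"img_{i}.png"), "w").close()
        train, test = list_and_split(folder, "*.png", shuffle=True, seed=0)
        print(len(train), len(test))  # prints 8 2
```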
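The filename and best_model parameters describe a common checkpointing pattern: save the training state after every epoch, and copy it to a second file whenever the test MAE improves. The sketch below assumes the MAE is computed on counts obtained by summing each density map; count_mae and save_checkpoint are hypothetical helpers, and the per-epoch MAE values in the demo are placeholders, not results.

```python
import os
import shutil
import tempfile

import torch


def count_mae(pred_maps, gt_maps):
    """Mean absolute error between counts, each count being the sum of a density map."""
    pred_counts = pred_maps.flatten(1).sum(dim=1)
    gt_counts = gt_maps.flatten(1).sum(dim=1)
    return (pred_counts - gt_counts).abs().mean().item()


def save_checkpoint(state, is_best, filename="./checkpoint.pth.tar",
                    best_model="./model_best.pth.tar"):
    """Hypothetical helper: always save `state`; also copy it to `best_model` if is_best."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_model)


if __name__ == "__main__":
    model = torch.nn.Conv2d(3, 1, kernel_size=1)  # stand-in for a CSRNet instance
    best_mae = float("inf")
    with tempfile.TemporaryDirectory() as out_dir:
        checkpoint = os.path.join(out_dir, "checkpoint.pth.tar")
        best = os.path.join(out_dir, "model_best.pth.tar")
        for epoch, mae in enumerate([12.5, 9.8, 10.4]):  # placeholder test MAEs
            is_best = mae < best_mae
            best_mae = min(mae, best_mae)
            save_checkpoint({"epoch": epoch + 1, "state_dict": model.state_dict(),
                             "best_mae": best_mae}, is_best, checkpoint, best)
        # model_best.pth.tar now holds the epoch-2 state (MAE 9.8)
        print(torch.load(best)["epoch"])  # prints 2
```

In this sketch the .pth.tar suffix is only a naming convention: torch.save writes its own serialization format (a zip archive in recent PyTorch versions), not a tar archive.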