Training Module
- class Crowd_counting.train.CSRNet(*args: Any, **kwargs: Any)
Bases: Module
This class represents a CSRNet model.
- class Crowd_counting.train.Dataset(*args, **kwds)
Bases: Generic[T_co]
An abstract class representing a Dataset.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should override __getitem__(), supporting fetching a data sample for a given key. Subclasses can also optionally override __len__(), which many Sampler implementations and the default options of DataLoader expect to return the size of the dataset.
Note
DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
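The Dataset documented here appears to be PyTorch's torch.utils.data.Dataset as imported into the train module (its bases and docstring match). The minimal sketch below, assuming PyTorch is installed, shows a map-style dataset keyed by image file name and how passing an explicit sampler lets DataLoader fetch samples by those non-integral keys. CountDataset and its data are hypothetical and not part of Crowd_counting.

```python
from typing import Dict

from torch.utils.data import DataLoader, Dataset


class CountDataset(Dataset):
    """Hypothetical map-style dataset keyed by image file name, not by integer index."""

    def __init__(self, counts: Dict[str, float]):
        self.counts = counts  # image file name -> ground-truth head count

    def __getitem__(self, key: str) -> float:
        # Fetch the sample for a (non-integral) key.
        return self.counts[key]

    def __len__(self) -> int:
        # Used by samplers and by len(DataLoader).
        return len(self.counts)


dataset = CountDataset({"img_001.png": 12.0, "img_002.png": 40.0})

# The default sampler would yield 0, 1, ..., which are not valid keys here,
# so an explicit iterable of keys is passed as the sampler.
# batch_size=None disables automatic batching, so each sample is returned as-is.
loader = DataLoader(dataset, sampler=list(dataset.counts), batch_size=None)
```

With the default sampler, DataLoader would request dataset[0], which raises KeyError for this dataset; the key list avoids that.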
- Crowd_counting.train.adjust_learning_rate(optimizer, epoch, args, best_prec1)
Sets the learning rate to the initial learning rate decayed by a factor of 10 every 30 epochs.
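Taken literally, this describes a step decay schedule, lr = init_lr * 0.1 ** (epoch // 30). The stdlib sketch below illustrates it; step_decay_lr and set_lr are hypothetical helpers, and the sketch does not model how adjust_learning_rate uses its args and best_prec1 parameters, which the documentation does not describe.

```python
def step_decay_lr(init_lr: float, epoch: int, step: int = 30, factor: float = 0.1) -> float:
    """Return init_lr multiplied by `factor` once for every `step` completed epochs."""
    return init_lr * (factor ** (epoch // step))


def set_lr(optimizer, lr: float) -> None:
    """Write `lr` into every parameter group (PyTorch optimizers expose `param_groups`)."""
    for group in optimizer.param_groups:
        group["lr"] = lr
```

PyTorch's built-in torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) produces the same schedule when stepped once per epoch.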
- Crowd_counting.train.complete_train(datasetpath, modelpath=None, shuffle=True, gpu=True, init_lr=1e-07, batch_size=1, epochs=400, img_format='*.png', filename='./checkpoint.pth.tar', best_model='./model_best.pth.tar')
Train a CSRNet model on the images found in datasetpath.
- Parameters
datasetpath – the path to the folder containing the images
modelpath – path to a .tar file containing an already trained model
shuffle – If set to True, the images of the dataset will be shuffled before splitting into train and test set.
gpu – If set to True, the GPU is used to train the model.
init_lr – initial learning rate
batch_size – batch size
epochs – number of epochs of the training
img_format – the file pattern of the images; can only take the values ‘*.png’ and ‘*.jpg’
filename – path and name where the checkpoint file will be saved after each epoch (must be a .tar file)
best_model – path and name where the best model (the best MAE on the test set) will be saved (must be a .tar file)
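The shuffle and img_format parameters suggest a data preparation step like the one sketched below: list the images in datasetpath that match img_format, optionally shuffle them, then split them into train and test sets. This is an assumption about the implementation, not the documented behavior; list_images, split_train_test, train_ratio, and seed are hypothetical, and the split ratio complete_train actually uses is not documented.

```python
import glob
import os
import random
from typing import List, Sequence, Tuple

ALLOWED_FORMATS = ("*.png", "*.jpg")  # the two values documented for img_format


def list_images(datasetpath: str, img_format: str = "*.png") -> List[str]:
    """Return the paths in `datasetpath` matching the glob pattern `img_format`."""
    if img_format not in ALLOWED_FORMATS:
        raise ValueError(f"img_format must be one of {ALLOWED_FORMATS}, got {img_format!r}")
    return sorted(glob.glob(os.path.join(datasetpath, img_format)))


def split_train_test(paths: Sequence[str], shuffle: bool = True,
                     train_ratio: float = 0.8, seed: int = 0) -> Tuple[List[str], List[str]]:
    """Split `paths` into (train, test) lists, shuffling first if requested.

    train_ratio and seed are hypothetical; the documented API exposes neither.
    """
    paths = list(paths)
    if shuffle:
        random.Random(seed).shuffle(paths)  # shuffle before splitting
    n_train = int(len(paths) * train_ratio)
    return paths[:n_train], paths[n_train:]
```

Seeding a private random.Random instance keeps the split reproducible without touching the global random state.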