import sys
import os
import warnings
from .model import CSRNet
from .utils import save_checkpoint
from .dataset import *
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as np
import argparse
import json
import cv2
import time
import glob
import random
# TODO: be careful with the paths used when saving the models.
def _build_parser():
    """Return the command-line parser consumed by ``main``."""
    cli = argparse.ArgumentParser(description='PyTorch CSRNet')
    # Positional inputs: JSON files holding the train / test sample lists.
    cli.add_argument('train_json', metavar='TRAIN', help='path to train json')
    cli.add_argument('test_json', metavar='TEST', help='path to test json')
    # Optional checkpoint to resume training from.
    cli.add_argument('--pre', '-p', metavar='PRETRAINED', default=None, type=str,
                     help='path to the pretrained model')
    # Remaining positionals: the GPU id (exported as CUDA_VISIBLE_DEVICES by
    # main) and the task id forwarded to save_checkpoint.
    cli.add_argument('gpu', metavar='GPU', type=str, help='GPU id to use.')
    cli.add_argument('task', metavar='TASK', type=str, help='task id to use.')
    return cli


parser = _build_parser()
def main():
    """Train CSRNet from the command line.

    Parses the arguments declared on the module-level ``parser`` and loads the
    train/validation sample lists from the given JSON files. If ``--pre``
    points to an existing file, it resumes from that checkpoint. It then runs
    the train/validate loop for ``args.epochs`` epochs, calling
    ``save_checkpoint`` after every epoch.

    Sets the module globals ``args`` and ``best_prec1``; ``validate`` reads
    ``args.batch_size`` from that global.
    """
    global args, best_prec1
    best_prec1 = 1e6  # sentinel "worst" MAE: any real MAE will be lower
    args = parser.parse_args()
    args.original_lr = 1e-7
    args.lr = 1e-7
    args.batch_size = 1
    args.momentum = 0.95
    args.decay = 5 * 1e-4
    args.start_epoch = 0
    args.epochs = 400
    args.steps = [-1, 1, 100, 150]
    args.scales = [1, 1, 1, 1]
    args.workers = 4
    args.seed = time.time()
    args.print_freq = 30

    with open(args.train_json, 'r') as f:
        train_list = json.load(f)
    with open(args.test_json, 'r') as f:
        val_list = json.load(f)

    # Only effective if CUDA has not been initialised yet in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.cuda.manual_seed(args.seed)

    model = CSRNet()
    model = model.cuda()
    # Summed (not averaged) squared error; equivalent to the deprecated
    # ``size_average=False``.
    criterion = nn.MSELoss(reduction='sum').cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.decay)

    if args.pre:
        if os.path.isfile(args.pre):
            print("=> loading checkpoint '{}'".format(args.pre))
            checkpoint = torch.load(args.pre)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.pre, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pre))

    for epoch in range(args.start_epoch, args.epochs):
        args = adjust_learning_rate(optimizer, epoch, args, best_prec1)
        train(train_list, model, criterion, optimizer, epoch, args, best_prec1)
        prec1 = validate(val_list, model, criterion)

        # Lower MAE is better.
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(' * best MAE {mae:.3f} '
              .format(mae=best_prec1))
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.pre,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.task)
def train(train_list, model, criterion, optimizer, epoch, args, best_prec1):
    """Run one training epoch over ``train_list``.

    :param train_list: list of sample paths handed to ``listDataset``
    :param model: the CSRNet model (already on CUDA); its ``seen`` attribute
        is forwarded to the dataset
    :param criterion: loss between the model output and the target map
    :param optimizer: optimizer stepping ``model``'s parameters
    :param epoch: current epoch index (used for logging)
    :param args: run configuration; reads ``batch_size``, ``workers``, ``lr``
        and ``print_freq``
    :param best_prec1: unused; kept for call compatibility
    """
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    train_loader = torch.utils.data.DataLoader(
        listDataset(train_list,
                    shuffle=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225]),
                    ]),
                    train=True,
                    seen=model.seen,
                    batch_size=args.batch_size,
                    num_workers=args.workers),
        batch_size=args.batch_size)

    print('epoch %d, processed %d samples, lr %.10f'
          % (epoch, epoch * len(train_loader.dataset), args.lr))

    model.train()
    end = time.time()
    for i, (img, target) in enumerate(train_loader):
        data_time.update(time.time() - end)

        img = img.cuda()
        output = model(img)

        # Insert a channel axis so each target lines up with its own output
        # map. Assumes listDataset yields 2-D density maps, i.e. a batched
        # target of shape (B, H, W) -> (B, 1, H, W), and an output presumably
        # of shape (B, 1, H, W) -- TODO confirm against dataset.py / model.py.
        # The former unsqueeze(0) gave (1, B, H, W), which only lines up with
        # such an output when B == 1.
        target = target.type(torch.FloatTensor).unsqueeze(1).cuda()

        loss = criterion(output, target)
        losses.update(loss.item(), img.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  .format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses))
def validate(val_list, model, criterion, batch_size=None):
    """Compute the mean absolute counting error (MAE) of ``model`` on ``val_list``.

    For each sample, the predicted count is the sum of the model output and
    the ground-truth count is the sum of the target map.

    :param val_list: list of sample paths handed to ``listDataset``
    :param model: CSRNet model (already on CUDA)
    :param criterion: unused; kept for call compatibility
    :param batch_size: validation batch size; when None, falls back to the
        module-level ``args.batch_size`` set by ``main``/``complete_train``
    :returns: the MAE over all validation samples, as a Python float
    :raises ValueError: if ``val_list`` yields no samples
    """
    print('begin test')
    if batch_size is None:
        batch_size = args.batch_size

    test_loader = torch.utils.data.DataLoader(
        listDataset(val_list,
                    shuffle=False,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225]),
                    ]), train=False),
        batch_size=batch_size)

    model.eval()
    abs_err_sum = 0.0
    n_samples = 0
    # Evaluation needs no gradients; this avoids building autograd graphs.
    with torch.no_grad():
        for img, target in test_loader:
            output = model(img.cuda())
            n = img.size(0)
            # Per-sample counts, so the MAE stays per image when batch_size > 1.
            pred_counts = output.reshape(n, -1).sum(dim=1)
            true_counts = target.reshape(n, -1).sum(dim=1).type(torch.FloatTensor).cuda()
            abs_err_sum += (pred_counts - true_counts).abs().sum().item()
            n_samples += n

    if n_samples == 0:
        raise ValueError('validation set is empty: cannot compute MAE')
    mae = abs_err_sum / n_samples
    print(' * MAE {mae:.3f} '
          .format(mae=mae))
    return mae
def adjust_learning_rate(optimizer, epoch, args, best_prec1):
    """Set the learning rate for ``epoch`` from a piecewise schedule.

    Starting from ``args.original_lr``, the milestones ``args.steps`` are
    scanned in order. For every milestone that ``epoch`` has reached
    (``epoch >= steps[i]``), the rate is multiplied by ``args.scales[i]``.
    Milestones without a matching entry in ``args.scales`` use a scale of 1.
    The scan stops at the first milestone not yet reached, or right after one
    equal to ``epoch``.

    The resulting rate is stored in ``args.lr`` and written to every
    parameter group of ``optimizer``.

    :param optimizer: object exposing ``param_groups`` (list of dicts)
    :param epoch: current epoch index
    :param args: configuration providing ``original_lr``, ``steps``, ``scales``
    :param best_prec1: unused; kept for call compatibility
    :returns: the same ``args`` object, updated in place
    """
    lr = args.original_lr
    for i, step in enumerate(args.steps):
        if epoch < step:
            break
        lr = lr * (args.scales[i] if i < len(args.scales) else 1)
        if epoch == step:
            break
    args.lr = lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return args
class AverageMeter:
    """Track the latest value and the running weighted mean of a metric.

    Attributes: ``val`` (last value passed to ``update``), ``sum`` (weighted
    total), ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
class arg:
    """Plain attribute container standing in for an ``argparse.Namespace``.

    ``complete_train`` fills an instance with its run configuration; only
    ``task`` receives a default here.
    """

    def __init__(self):
        # Task id later forwarded to save_checkpoint.
        self.task = '0'
# TODO: a loss other than MSELoss, or an optimizer other than SGD, could be supported.
def complete_train(datasetpath, modelpath=None, shuffle=True, gpu=True, init_lr=1e-7,
                   batch_size=1, epochs=400, img_format='*.png',
                   filename='./checkpoint.pth.tar', best_model='./model_best.pth.tar'):
    """
    Train a CSRNet model on the images found in a folder.

    The images matching ``img_format`` in ``datasetpath`` are split into a
    training set (the first 3/4) and a test set (the remaining 1/4).

    :param datasetpath: path to the folder containing the images
    :param modelpath: path to a .tar file containing an already trained model
        to resume from
    :param shuffle: if True, the images are shuffled before being split into
        train and test sets; otherwise the split follows the sorted file names
    :param gpu: training always runs on CUDA; passing False only emits a warning
    :param init_lr: initial learning rate
    :param batch_size: batch size
    :param epochs: number of epochs of the training
    :param img_format: glob pattern of the images, either '*.png' or '*.jpg'
    :param filename: path and name where the checkpoint file is saved after
        each epoch (must be a .tar file)
    :param best_model: path and name where the best model (the best MAE on
        the test set) is saved (must be a .tar file)
    :raises ValueError: if no image matching ``img_format`` exists in
        ``datasetpath``
    """
    global args, best_prec1

    # Collect and split the images first so that a bad folder fails fast,
    # instead of crashing in validate() with a division by zero.
    # glob order is arbitrary, so sort for a reproducible unshuffled split.
    img_paths = sorted(glob.glob(os.path.join(datasetpath, img_format)))
    if not img_paths:
        raise ValueError("No images matching '{}' found in '{}'"
                         .format(img_format, datasetpath))
    if shuffle:
        random.shuffle(img_paths)
    split = int(3 / 4 * len(img_paths))
    train_list = img_paths[:split]
    val_list = img_paths[split:]

    if not gpu:
        warnings.warn('gpu=False is not supported: training runs on CUDA.')

    best_prec1 = 1e6  # sentinel "worst" MAE: any real MAE will be lower
    args = arg()
    args.original_lr = init_lr
    args.lr = init_lr
    args.batch_size = batch_size
    args.gpu = gpu
    args.momentum = 0.95
    args.decay = 5 * 1e-4
    args.start_epoch = 0
    args.epochs = epochs
    args.steps = [-1, 1, 100, 150]
    args.scales = [1, 1, 1, 1]
    args.workers = 4
    args.seed = time.time()
    args.print_freq = 30
    args.pre = modelpath

    model = CSRNet()
    model = model.cuda()
    # Summed (not averaged) squared error; equivalent to the deprecated
    # ``size_average=False``.
    criterion = nn.MSELoss(reduction='sum').cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.decay)

    if modelpath is not None:
        if os.path.isfile(modelpath):
            print("=> loading checkpoint '{}'".format(args.pre))
            checkpoint = torch.load(args.pre)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.pre, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pre))

    for epoch in range(args.start_epoch, args.epochs):
        args = adjust_learning_rate(optimizer, epoch, args, best_prec1)
        train(train_list, model, criterion, optimizer, epoch, args, best_prec1)
        prec1 = validate(val_list, model, criterion)

        # Lower MAE is better.
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(' * best MAE {mae:.3f} '
              .format(mae=best_prec1))
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.pre,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.task, filename, best_model)
if __name__ == "__main__":
    # Command-line entry point: parse CLI arguments and run training.
    main()