fix learning rate update error
Author
Ting-Chun Wang

Committer
Ting-Chun Wang
6 years ago
Files changed: 3
+10
–9
2e6d137
models/base_model.py
+3
–2
@@ -1,4 +1,5 @@
 import os, sys
 import numpy as np
+import torch
 from .networks import get_grid
@@ -150,9 +151,9 @@
         edge[:,:,:,:-1,:] = edge[:,:,:,:-1,:] | (t[:,:,:,1:,:] != t[:,:,:,:-1,:])
         return edge.float()
 
-    def update_learning_rate(self, epoch):
+    def update_learning_rate(self, epoch, model):
         lr = self.opt.lr * (1 - (epoch - self.opt.niter) / self.opt.niter_decay)
-        for param_group in self.optimizer_D.param_groups:
+        for param_group in getattr(self, 'optimizer_' + model).param_groups:
             param_group['lr'] = lr
         print('update learning rate: %f -> %f' % (self.old_lr, lr))
         self.old_lr = lr
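
The effect of the new model argument can be seen in isolation with the sketch below. It is a minimal standalone mock, not the repository's BaseModel: OptStub, ModelStub, and the Adam optimizers over dummy parameters are invented for illustration. Before the change the method always reached for self.optimizer_D, which is the wrong (or a missing) attribute when the caller is the generator model that owns optimizer_G; the getattr lookup now picks the optimizer by name.

import torch

class OptStub:
    # invented stand-in for the training options object (opt)
    lr = 0.0002        # initial learning rate
    niter = 10         # epochs trained at the initial rate
    niter_decay = 10   # epochs over which the rate decays to zero

class ModelStub:
    # invented stand-in for a model that owns one optimizer per network
    def __init__(self):
        self.opt = OptStub()
        self.old_lr = self.opt.lr
        self.optimizer_G = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=self.opt.lr)
        self.optimizer_D = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=self.opt.lr)

    def update_learning_rate(self, epoch, model):
        # linear decay from opt.lr to 0 over opt.niter_decay epochs
        lr = self.opt.lr * (1 - (epoch - self.opt.niter) / self.opt.niter_decay)
        # select optimizer_G or optimizer_D by name instead of hard-coding optimizer_D
        for param_group in getattr(self, 'optimizer_' + model).param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr

m = ModelStub()
m.update_learning_rate(15, 'G')   # now actually touches optimizer_G
m.update_learning_rate(15, 'D')   # and optimizer_D when asked to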
models/models.py
+6
–6
@@ -101,7 +101,7 @@
             optimizer_D_T.append(getattr(modelD.module, 'optimizer_D_T'+str(s)))
     return modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T
 
-def init_params(opt, modelG, modelD, dataset_size):
+def init_params(opt, modelG, modelD, data_loader):
     iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
     start_epoch, epoch_iter = 1, 0
     ### if continue training, recover previous states
@@ -110,8 +110,8 @@
             start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
         print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
         if start_epoch > opt.niter:
-            modelG.module.update_learning_rate(start_epoch-1)
-            modelD.module.update_learning_rate(start_epoch-1)
+            modelG.module.update_learning_rate(start_epoch-1, 'G')
+            modelD.module.update_learning_rate(start_epoch-1, 'D')
         if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (start_epoch > opt.niter_fix_global):
             modelG.module.update_fixed_params()
         if start_epoch > opt.niter_step:
@@ -127,7 +127,7 @@
     output_nc = opt.output_nc
 
     print_freq = lcm(opt.print_freq, opt.batchSize)
-    total_steps = (start_epoch-1) * dataset_size + epoch_iter
+    total_steps = (start_epoch-1) * len(data_loader) + epoch_iter
     total_steps = total_steps // print_freq * print_freq
     return n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, start_epoch, epoch_iter, print_freq, total_steps, iter_path
@@ -151,8 +151,8 @@
 def update_models(opt, epoch, modelG, modelD, data_loader):
     ### linearly decay learning rate after certain iterations
     if epoch > opt.niter:
-        modelG.module.update_learning_rate(epoch)
-        modelD.module.update_learning_rate(epoch)
+        modelG.module.update_learning_rate(epoch, 'G')
+        modelD.module.update_learning_rate(epoch, 'D')
 
     ### gradually grow training sequence length
     if (epoch % opt.niter_step) == 0:
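
Two things change in this file: init_params now receives the data loader itself and derives the per-epoch step count from len(data_loader) rather than a separately passed dataset_size, and every update_learning_rate call names the optimizer it should touch. The rate written into that optimizer follows the linear decay defined in base_model.py; the loop below only evaluates that formula with assumed values (lr=0.0002, niter=10, niter_decay=10 are placeholders, not the project's defaults) to show the shape of the schedule that kicks in once epoch > opt.niter.

# Placeholder values standing in for opt.lr, opt.niter and opt.niter_decay.
base_lr, niter, niter_decay = 0.0002, 10, 10

# update_models only triggers the decay once the epoch passes opt.niter.
for epoch in range(niter + 1, niter + niter_decay + 1):
    lr = base_lr * (1 - (epoch - niter) / niter_decay)
    print('epoch %2d -> lr %.6f' % (epoch, lr))
# prints 0.000180 at epoch 11, falling linearly to 0.000000 at epoch 20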
train.py
+1
–1
@@ -30,7 +30,7 @@
 ### set parameters
 n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
-    start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(opt, modelG, modelD, dataset_size)
+    start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(opt, modelG, modelD, data_loader)
 visualizer = Visualizer(opt)
 
 ### real training starts here
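
The only change here is that train.py hands init_params the loader instead of a precomputed dataset_size. The snippet below reproduces, with made-up numbers, the total_steps bookkeeping init_params performs when resuming: start_epoch and epoch_iter would normally come from iter.txt, batches_per_epoch stands in for len(data_loader), and print_freq for lcm(opt.print_freq, opt.batchSize).

# Made-up numbers; the real values come from iter.txt, the data loader and the options.
start_epoch, epoch_iter = 5, 137   # resume point recovered from iter.txt
batches_per_epoch = 250            # stands in for len(data_loader)
print_freq = 40                    # stands in for lcm(opt.print_freq, opt.batchSize)

total_steps = (start_epoch - 1) * batches_per_epoch + epoch_iter
total_steps = total_steps // print_freq * print_freq   # round down to a print_freq boundary
print(total_steps)                 # 1137 rounded down to 1120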