fix learning rate update error

fix learning rate update error

АвторTing-Chun Wang
КоммитерTing-Chun Wang
6 лет назад
Файлов изменено: 3
+10
–9
2e6d137
Родители a422816 master
models/base_model.py
@@ -1,4 +1,5 @@
import os, sys
import numpy as np
import torch
from .networks import get_grid
@@ -150,9 +151,9 @@
        edge[:,:,:,:-1,:] = edge[:,:,:,:-1,:] | (t[:,:,:,1:,:] != t[:,:,:,:-1,:])
        return edge.float()       
        
    def update_learning_rate(self, epoch):        
    def update_learning_rate(self, epoch, model):        
        lr = self.opt.lr * (- (epoch - self.opt.niter) / self.opt.niter_decay)
        for param_group in self.optimizer_D.param_groups:
        for param_group in getattr(self, 'optimizer_' + model).param_groups:
            param_group['lr'= lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
models/models.py
@@ -101,7 +101,7 @@
            optimizer_D_T.append(getattr(modelD.module, 'optimizer_D_T'+str(s)))
    return modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T
def init_params(opt, modelG, modelD, dataset_size):
def init_params(opt, modelG, modelD, data_loader):
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    start_epoch, epoch_iter = 10
    ### if continue training, recover previous states
@@ -110,8 +110,8 @@
            start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)        
        print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))   
        if start_epoch opt.niter:
            modelG.module.update_learning_rate(start_epoch-1)
            modelD.module.update_learning_rate(start_epoch-1)
            modelG.module.update_learning_rate(start_epoch-1'G')
            modelD.module.update_learning_rate(start_epoch-1'D')
        if (opt.n_scales_spatial 1and (opt.niter_fix_global != 0and (start_epoch opt.niter_fix_global):
            modelG.module.update_fixed_params()
        if start_epoch opt.niter_step:
@@ -127,7 +127,7 @@
    output_nc = opt.output_nc         
    print_freq = lcm(opt.print_freq, opt.batchSize)
    total_steps = (start_epoch-1* dataset_size + epoch_iter
    total_steps = (start_epoch-1* len(data_loader) + epoch_iter
    total_steps = total_steps // print_freq * print_freq  
    return n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, start_epoch, epoch_iter, print_freq, total_steps, iter_path
@@ -151,8 +151,8 @@
def update_models(opt, epoch, modelG, modelD, data_loader):
    ### linearly decay learning rate after certain iterations
    if epoch opt.niter:
        modelG.module.update_learning_rate(epoch)
        modelD.module.update_learning_rate(epoch)
        modelG.module.update_learning_rate(epoch, 'G')
        modelD.module.update_learning_rate(epoch, 'D')
    ### gradually grow training sequence length
    if (epoch % opt.niter_step) == 0:
@@ -30,7 +30,7 @@
    ### set parameters    
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
        start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(opt, modelG, modelD, dataset_size)
        start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(opt, modelG, modelD, data_loader)
    visualizer = Visualizer(opt)    
    ### real training starts here  

Cherry-pick

Команда cherry-pick позволяет выбрать отдельные коммиты из одной ветки и применить их к другой.