SegFormer Study Notes (4): train.py, Continued (Part 2)

This time let's focus on the most important questions: which network is used, which dataset is used, and where do the pretrained weights come from?

For convenience, here is train.py again:

import torch 
import argparse
import yaml
import time
import multiprocessing as mp
from tabulate import tabulate
from tqdm import tqdm
from torch.utils.data import DataLoader
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler, RandomSampler
from torch import distributed as dist
from semseg.models import *
from semseg.datasets import * 
from semseg.augmentations import get_train_augmentation, get_val_augmentation
from semseg.losses import get_loss
from semseg.schedulers import get_scheduler
from semseg.optimizers import get_optimizer
from semseg.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp
from val import evaluate


def main(cfg, gpu, save_dir):
    start = time.time()
    best_mIoU = 0.0
    num_workers = mp.cpu_count()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device(cfg['DEVICE'])
    train_cfg, eval_cfg = cfg['TRAIN'], cfg['EVAL']
    dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL']
    loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER']
    epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR']

    traintransform = get_train_augmentation(train_cfg['IMAGE_SIZE'], seg_fill=dataset_cfg['IGNORE_LABEL'])
    valtransform = get_val_augmentation(eval_cfg['IMAGE_SIZE'])

    trainset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'train', traintransform)
    valset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'val', valtransform)

    model = eval(model_cfg['NAME'])(model_cfg['BACKBONE'], trainset.n_classes)
    model.init_pretrained(model_cfg['PRETRAINED'])
    model = model.to(device)

    if train_cfg['DDP']:
        sampler = DistributedSampler(trainset, dist.get_world_size(), dist.get_rank(), shuffle=True)
        model = DDP(model, device_ids=[gpu])
    else:
        sampler = RandomSampler(trainset)

    trainloader = DataLoader(trainset, batch_size=train_cfg['BATCH_SIZE'], num_workers=num_workers, drop_last=True, pin_memory=True, sampler=sampler)
    valloader = DataLoader(valset, batch_size=1, num_workers=1, pin_memory=True)

    iters_per_epoch = len(trainset) // train_cfg['BATCH_SIZE']
    # class_weights = trainset.class_weights.to(device)
    loss_fn = get_loss(loss_cfg['NAME'], trainset.ignore_label, None)
    optimizer = get_optimizer(model, optim_cfg['NAME'], lr, optim_cfg['WEIGHT_DECAY'])
    scheduler = get_scheduler(sched_cfg['NAME'], optimizer, epochs * iters_per_epoch, sched_cfg['POWER'], iters_per_epoch * sched_cfg['WARMUP'], sched_cfg['WARMUP_RATIO'])
    scaler = GradScaler(enabled=train_cfg['AMP'])
    writer = SummaryWriter(str(save_dir / 'logs'))

    for epoch in range(epochs):
        model.train()
        if train_cfg['DDP']:
            sampler.set_epoch(epoch)

        train_loss = 0.0
        pbar = tqdm(enumerate(trainloader), total=iters_per_epoch, desc=f"Epoch: [{epoch+1}/{epochs}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.8f}")

        for iter, (img, lbl) in pbar:
            optimizer.zero_grad(set_to_none=True)

            img = img.to(device)
            lbl = lbl.to(device)

            with autocast(enabled=train_cfg['AMP']):
                logits = model(img)
                loss = loss_fn(logits, lbl)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            torch.cuda.synchronize()

            lr = scheduler.get_lr()
            lr = sum(lr) / len(lr)
            train_loss += loss.item()

            pbar.set_description(f"Epoch: [{epoch+1}/{epochs}] Iter: [{iter+1}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss / (iter+1):.8f}")

        train_loss /= iter + 1
        writer.add_scalar('train/loss', train_loss, epoch)
        torch.cuda.empty_cache()

        if (epoch+1) % train_cfg['EVAL_INTERVAL'] == 0 or (epoch+1) == epochs:
            miou = evaluate(model, valloader, device)[-1]
            writer.add_scalar('val/mIoU', miou, epoch)

            if miou > best_mIoU:
                best_mIoU = miou
                torch.save(model.module.state_dict() if train_cfg['DDP'] else model.state_dict(), save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}.pth")
            print(f"Current mIoU: {miou} Best mIoU: {best_mIoU}")

    writer.close()
    pbar.close()

    end = time.gmtime(time.time() - start)
    table = [
        ['Best mIoU', f"{best_mIoU:.2f}"],
        ['Total Training Time', time.strftime("%H:%M:%S", end)]
    ]
    print(tabulate(table, numalign='right'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='configs/custom.yaml', help='Configuration file to use')
    args = parser.parse_args()

    with open(args.cfg) as f:
        cfg = yaml.load(f, Loader=yaml.SafeLoader)

    fix_seeds(3407)
    setup_cudnn()
    gpu = setup_ddp()
    save_dir = Path(cfg['SAVE_DIR'])
    save_dir.mkdir(exist_ok=True)
    main(cfg, gpu, save_dir)
    cleanup_ddp()

1. model_cfg

In main() above, the config-unpacking line essentially boils down to

model_cfg = cfg['MODEL']

Now look at custom.yaml:

MODEL:
  NAME          : SegFormer                                      # name of the model you are using
  BACKBONE      : MiT-B2                                         # model variant
  PRETRAINED    : 'checkpoints/backbones/mit/mit_b2.pth'         # backbone model's weight

Here comes the heavyweight line:

model = eval(model_cfg['NAME'])(model_cfg['BACKBONE'], trainset.n_classes)

model_cfg['NAME'] is simply the string 'SegFormer'.

This is where the careful reader pays attention: recall the two wildcard imports at the top of train.py:

from semseg.models import *
from semseg.datasets import * 

So this line looks up the SegFormer class by name and instantiates it, with BACKBONE set to MiT-B2. See the sketch below.
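
Why does a plain string from the YAML become a model? Because of those wildcard imports, a class named SegFormer is already in the module's namespace, and eval() simply looks it up by name. Here is a minimal, self-contained sketch of the same trick; DummyNet is a hypothetical stand-in for SegFormer:

# DummyNet stands in for the SegFormer class that the wildcard import
# would have brought into scope.
class DummyNet:
    def __init__(self, backbone: str, num_classes: int) -> None:
        self.backbone = backbone
        self.num_classes = num_classes

model_name = 'DummyNet'                     # plays the role of model_cfg['NAME']
model = eval(model_name)('MiT-B2', 11)      # the same pattern as train.py
print(model.backbone, model.num_classes)    # -> MiT-B2 11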

The next line is about the pretrained weights:

model.init_pretrained(model_cfg['PRETRAINED'])

Note that init_pretrained is resolved polymorphically: since model is already a SegFormer instance, and SegFormer inherits from BaseModel, the call executes BaseModel's init_pretrained.

So what does this call actually execute?

model.init_pretrained(model_cfg['PRETRAINED'])

The pretrained weights come from model_cfg['PRETRAINED'].
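
What does init_pretrained do with that path? A sketch based on my reading of semseg/models/base.py (details may differ): the checkpoint is loaded into the backbone only, with strict=False, since the decode head has no pretrained weights:

import torch
from torch import nn

# Sketch of BaseModel.init_pretrained; the real method is a bound method
# on the model, shown here as a free function for clarity.
def init_pretrained(backbone: nn.Module, pretrained: str = None) -> None:
    if pretrained:
        state = torch.load(pretrained, map_location='cpu')
        # strict=False: the checkpoint covers the backbone, not the decode head
        backbone.load_state_dict(state, strict=False)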

In my case:

PRETRAINED : 'checkpoints/backbones/mit/mit_b2.pth' # backbone model's weight

The careful reader will also notice that

BACKBONE : MiT-B2 # model variant

has not been used anywhere yet.

Let's read segformer.py once more:

import torch
from torch import Tensor
from torch.nn import functional as F
from semseg.models.base import BaseModel
from semseg.models.heads import SegFormerHead


class SegFormer(BaseModel):
    def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19) -> None:
        super().__init__(backbone, num_classes)
        self.decode_head = SegFormerHead(self.backbone.channels, 256 if 'B0' in backbone or 'B1' in backbone else 768, num_classes)
        self.apply(self._init_weights)

    def forward(self, x: Tensor) -> Tensor:
        y = self.backbone(x)
        y = self.decode_head(y)   # 4x reduction in image size
        y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False)    # to original image shape
        return y


if __name__ == '__main__':
    model = SegFormer('MiT-B0')
    # model.load_state_dict(torch.load('checkpoints/pretrained/segformer/segformer.b0.ade.pth', map_location='cpu'))
    x = torch.zeros(1, 3, 512, 512)
    y = model(x)
    print(y.shape)

In __init__ above, backbone finally gets used: super().__init__(backbone, num_classes) hands the string to BaseModel, and the decode head's embedding width (256 for B0/B1, 768 otherwise) is also chosen from the variant name.
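
And how does BaseModel turn the string 'MiT-B2' into an actual network? A self-contained sketch of my reading of semseg/models/base.py; the tiny MiT class here is a stand-in so the sketch runs on its own (the real one lives in semseg/models/backbones):

from torch import nn

class MiT(nn.Module):                         # stand-in for the real MiT backbone
    def __init__(self, variant: str) -> None:
        super().__init__()
        self.variant = variant
        self.channels = [64, 128, 320, 512]   # stage widths of the B2-B5 variants

class BaseModelSketch(nn.Module):
    def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19) -> None:
        super().__init__()
        family, variant = backbone.split('-')   # 'MiT-B2' -> ('MiT', 'B2')
        self.backbone = eval(family)(variant)   # the eval() trick once more

m = BaseModelSketch('MiT-B2')
print(m.backbone.variant, m.backbone.channels)  # -> B2 [64, 128, 320, 512]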

2. model_cfg summary

MODEL:
  NAME          : SegFormer                                      # name of the model you are using
  BACKBONE      : MiT-B2                                         # model variant
  PRETRAINED    : 'checkpoints/backbones/mit/mit_b2.pth'         # backbone model's weight

NAME determines which model class is used.

BACKBONE determines which backbone variant is built.

PRETRAINED determines the pretrained weight file.

The three are mutually constrained; you cannot mix and match them arbitrarily, as the sketch below makes concrete.
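
For example, a MiT-B2 backbone only makes sense with a MiT-B2 checkpoint. A hypothetical sanity check (not part of the repo; the repo itself would only fail later, inside load_state_dict) illustrates the constraint:

# check_model_cfg is hypothetical, for illustration only.
def check_model_cfg(model_cfg: dict) -> None:
    variant = model_cfg['BACKBONE'].split('-')[-1].lower()   # 'MiT-B2' -> 'b2'
    if variant not in model_cfg['PRETRAINED'].lower():
        raise ValueError(f"{model_cfg['PRETRAINED']} does not look like a "
                         f"{model_cfg['BACKBONE']} checkpoint")

check_model_cfg({
    'NAME': 'SegFormer',
    'BACKBONE': 'MiT-B2',
    'PRETRAINED': 'checkpoints/backbones/mit/mit_b2.pth',
})   # passes: 'b2' appears in the checkpoint filename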

3. train_cfg

TRAIN:
  IMAGE_SIZE    : [512, 512]      # training image size in (h, w)
  BATCH_SIZE    : 2               # batch size used to train
  EPOCHS        : 6               # number of epochs to train
  EVAL_INTERVAL : 2               # evaluation interval during training
  AMP           : false           # use AMP in training
  DDP           : false           # use DDP training
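
How main() consumes these numbers, assuming a hypothetical train split of 2000 images (the real size depends on your copy of the dataset):

batch_size, epochs, eval_interval = 2, 6, 2       # values from TRAIN above
n_train = 2000                                    # hypothetical dataset size
iters_per_epoch = n_train // batch_size           # 1000, as computed in main()
total_iters = epochs * iters_per_epoch            # 6000 scheduler steps overall
eval_epochs = [e + 1 for e in range(epochs) if (e + 1) % eval_interval == 0]
print(iters_per_epoch, total_iters, eval_epochs)  # -> 1000 6000 [2, 4, 6]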

4. dataset_cfg

DATASET:
  NAME          : HELEN                                    # dataset name to be trained with (camvid, cityscapes, ade20k)
  ROOT          : 'data/SmithCVPR2013_dataset_resized'     # dataset root path
  IGNORE_LABEL  : 255

Now here is an interesting one:

NAME : HELEN

How is that string resolved? Exactly the same eval() trick as the model: eval(dataset_cfg['NAME']) picks up the HELEN class exported by the wildcard import from semseg.datasets.
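
For that lookup to work, the HELEN class only needs the interface main() relies on. A stand-in sketch, with attribute names taken from train.py (the real class lives in semseg/datasets and actually reads images from ROOT; the n_classes value is my guess at HELEN's face-parsing labels, so check the real class):

from torch.utils.data import Dataset

class HELENSketch(Dataset):
    n_classes = 11         # read by main() as trainset.n_classes (a guess here)
    ignore_label = 255     # read by main(); matches IGNORE_LABEL above

    def __init__(self, root: str, split: str = 'train', transform=None) -> None:
        self.root, self.split, self.transform = root, split, transform
        self.files = []    # the real class collects image paths under root

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, index):
        # the real class returns an (image, label) pair after transform
        raise IndexError(index)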

5. eval_cfg

EVAL:
  MODEL_PATH    : 'checkpoints/pretrained/ddrnet/ddrnet_23slim_city.pth'    # trained model file path
  IMAGE_SIZE    : [1024, 1024]    # evaluation image size in (h, w)
  MSF:
    ENABLE      : false           # multi-scale and flip evaluation
    FLIP        : true            # use flip in evaluation
    SCALES      : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]    # scales used in MSF evaluation
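
The MSF block configures multi-scale and flip evaluation. A sketch of what that typically means (the repo's own implementation lives in val.py and may differ in detail): run the model at every scale, optionally also on the horizontally flipped image, resize all predictions back to the input size, and average the probabilities:

import torch
from torch.nn import functional as F

@torch.no_grad()
def msf_inference(model, img, scales=(0.5, 0.75, 1.0, 1.25, 1.5, 1.75), flip=True):
    h, w = img.shape[2:]
    prob_sum = 0.0
    for s in scales:
        scaled = F.interpolate(img, scale_factor=s, mode='bilinear', align_corners=False)
        out = F.interpolate(model(scaled), size=(h, w), mode='bilinear', align_corners=False)
        prob_sum = prob_sum + out.softmax(dim=1)
        if flip:
            out_f = model(torch.flip(scaled, dims=[3]))   # flip along width
            out_f = torch.flip(out_f, dims=[3])           # flip the logits back
            out_f = F.interpolate(out_f, size=(h, w), mode='bilinear', align_corners=False)
            prob_sum = prob_sum + out_f.softmax(dim=1)
    return prob_sum.argmax(dim=1)                         # per-pixel class map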
