SegFormer Learning Notes (4): train.py, Continued (2)
创始人
2024-05-11 23:44:16

This time, let's focus on the most critical questions: which network is used, which data, and where do the pretrained weights come from?

For convenience, here is train.py again:

import torch 
import argparse
import yaml
import time
import multiprocessing as mp
from tabulate import tabulate
from tqdm import tqdm
from torch.utils.data import DataLoader
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler, RandomSampler
from torch import distributed as dist
from semseg.models import *
from semseg.datasets import * 
from semseg.augmentations import get_train_augmentation, get_val_augmentation
from semseg.losses import get_loss
from semseg.schedulers import get_scheduler
from semseg.optimizers import get_optimizer
from semseg.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp
from val import evaluate


def main(cfg, gpu, save_dir):
    start = time.time()
    best_mIoU = 0.0
    num_workers = mp.cpu_count()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device(cfg['DEVICE'])
    train_cfg, eval_cfg = cfg['TRAIN'], cfg['EVAL']
    dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL']
    loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER']
    epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR']

    traintransform = get_train_augmentation(train_cfg['IMAGE_SIZE'], seg_fill=dataset_cfg['IGNORE_LABEL'])
    valtransform = get_val_augmentation(eval_cfg['IMAGE_SIZE'])

    trainset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'train', traintransform)
    valset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'val', valtransform)

    model = eval(model_cfg['NAME'])(model_cfg['BACKBONE'], trainset.n_classes)
    model.init_pretrained(model_cfg['PRETRAINED'])
    model = model.to(device)

    if train_cfg['DDP']:
        sampler = DistributedSampler(trainset, dist.get_world_size(), dist.get_rank(), shuffle=True)
        model = DDP(model, device_ids=[gpu])
    else:
        sampler = RandomSampler(trainset)

    trainloader = DataLoader(trainset, batch_size=train_cfg['BATCH_SIZE'], num_workers=num_workers, drop_last=True, pin_memory=True, sampler=sampler)
    valloader = DataLoader(valset, batch_size=1, num_workers=1, pin_memory=True)

    iters_per_epoch = len(trainset) // train_cfg['BATCH_SIZE']
    # class_weights = trainset.class_weights.to(device)
    loss_fn = get_loss(loss_cfg['NAME'], trainset.ignore_label, None)
    optimizer = get_optimizer(model, optim_cfg['NAME'], lr, optim_cfg['WEIGHT_DECAY'])
    scheduler = get_scheduler(sched_cfg['NAME'], optimizer, epochs * iters_per_epoch, sched_cfg['POWER'], iters_per_epoch * sched_cfg['WARMUP'], sched_cfg['WARMUP_RATIO'])
    scaler = GradScaler(enabled=train_cfg['AMP'])
    writer = SummaryWriter(str(save_dir / 'logs'))

    for epoch in range(epochs):
        model.train()
        if train_cfg['DDP']:
            sampler.set_epoch(epoch)

        train_loss = 0.0
        pbar = tqdm(enumerate(trainloader), total=iters_per_epoch,
                    desc=f"Epoch: [{epoch+1}/{epochs}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.8f}")

        for iter, (img, lbl) in pbar:
            optimizer.zero_grad(set_to_none=True)

            img = img.to(device)
            lbl = lbl.to(device)

            with autocast(enabled=train_cfg['AMP']):
                logits = model(img)
                loss = loss_fn(logits, lbl)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            torch.cuda.synchronize()

            lr = scheduler.get_lr()
            lr = sum(lr) / len(lr)
            train_loss += loss.item()

            pbar.set_description(f"Epoch: [{epoch+1}/{epochs}] Iter: [{iter+1}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss / (iter+1):.8f}")

        train_loss /= iter + 1
        writer.add_scalar('train/loss', train_loss, epoch)
        torch.cuda.empty_cache()

        if (epoch+1) % train_cfg['EVAL_INTERVAL'] == 0 or (epoch+1) == epochs:
            miou = evaluate(model, valloader, device)[-1]
            writer.add_scalar('val/mIoU', miou, epoch)

            if miou > best_mIoU:
                best_mIoU = miou
                torch.save(model.module.state_dict() if train_cfg['DDP'] else model.state_dict(),
                           save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}.pth")
            print(f"Current mIoU: {miou} Best mIoU: {best_mIoU}")

    writer.close()
    pbar.close()
    end = time.gmtime(time.time() - start)

    table = [
        ['Best mIoU', f"{best_mIoU:.2f}"],
        ['Total Training Time', time.strftime("%H:%M:%S", end)],
    ]
    print(tabulate(table, numalign='right'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='configs/custom.yaml', help='Configuration file to use')
    args = parser.parse_args()

    with open(args.cfg) as f:
        cfg = yaml.load(f, Loader=yaml.SafeLoader)

    fix_seeds(3407)
    setup_cudnn()
    gpu = setup_ddp()
    save_dir = Path(cfg['SAVE_DIR'])
    save_dir.mkdir(exist_ok=True)
    main(cfg, gpu, save_dir)
    cleanup_ddp()

1. model_cfg

Buried in the train.py above is this line, which simply means:

model_cfg = cfg['MODEL']

Now look at custom.yaml:

MODEL:
  NAME       : SegFormer                                # name of the model you are using
  BACKBONE   : MiT-B2                                   # model variant
  PRETRAINED : 'checkpoints/backbones/mit/mit_b2.pth'   # backbone model's weight

Now comes the heavyweight line:

model = eval(model_cfg['NAME'])(model_cfg['BACKBONE'], trainset.n_classes)

model_cfg['NAME'] is just the string 'SegFormer'.

This is where you need to be careful. Note these two wildcard imports at the top of train.py:

from semseg.models import *
from semseg.datasets import * 

So that heavyweight line instantiates the SegFormer class (the name resolves precisely because those wildcard imports brought it into scope), with BACKBONE set to MiT-B2.
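To make the mechanism concrete, here is a minimal, self-contained sketch of the eval()-based lookup. The SegFormer class below is a stand-in, not the real semseg implementation:

```python
# Stand-in for the real semseg SegFormer class, just to show the lookup.
class SegFormer:
    def __init__(self, backbone: str, num_classes: int):
        self.backbone = backbone
        self.num_classes = num_classes

model_cfg = {'NAME': 'SegFormer', 'BACKBONE': 'MiT-B2'}

# eval() resolves the string against the current namespace. In train.py this
# only works because `from semseg.models import *` made the class visible.
model = eval(model_cfg['NAME'])(model_cfg['BACKBONE'], 11)
print(type(model).__name__, model.backbone)   # SegFormer MiT-B2

# An equivalent lookup that avoids eval():
model2 = globals()[model_cfg['NAME']]('MiT-B0', 19)
```

globals()[...] (or a small name-to-class registry) does the same job without the code-execution risks of eval(), which is worth keeping in mind if the config file ever comes from an untrusted source.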

The next line concerns the pretrained weights:

model.init_pretrained(model_cfg['PRETRAINED'])

Notice that init_pretrained is resolved polymorphically: model is a SegFormer instance, SegFormer inherits from BaseModel, and since SegFormer does not override the method, it is BaseModel's init_pretrained that runs.

So what does this call actually do? It loads the pretrained weights from the path in model_cfg['PRETRAINED'].
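In the repo, BaseModel.init_pretrained essentially boils down to loading the checkpoint and copying it into the backbone with strict=False. A minimal sketch under that assumption, with a toy nn.Linear standing in for the real MiT backbone:

```python
import os
import tempfile

import torch
from torch import nn


class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)   # stand-in for the real MiT backbone

    def init_pretrained(self, pretrained: str = None) -> None:
        if pretrained:
            # strict=False tolerates checkpoint keys the backbone lacks,
            # e.g. the classifier head left over from ImageNet pretraining
            self.backbone.load_state_dict(torch.load(pretrained, map_location='cpu'), strict=False)


# Save fake "pretrained" weights, then load them the same way train.py does.
ckpt = os.path.join(tempfile.mkdtemp(), 'mit_b2_fake.pth')
torch.save(nn.Linear(4, 4).state_dict(), ckpt)
model = BaseModel()
model.init_pretrained(ckpt)
```

The key point is that only the backbone is initialized from disk; the decode head keeps its random initialization and is trained from scratch.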

In my case,

PRETRAINED : 'checkpoints/backbones/mit/mit_b2.pth' # backbone model's weight

The careful reader will notice that

BACKBONE : MiT-B2 # model variant

has not been used yet.

Let's look at segformer.py one more time:

import torch
from torch import Tensor
from torch.nn import functional as F
from semseg.models.base import BaseModel
from semseg.models.heads import SegFormerHead


class SegFormer(BaseModel):
    def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19) -> None:
        super().__init__(backbone, num_classes)
        self.decode_head = SegFormerHead(self.backbone.channels, 256 if 'B0' in backbone or 'B1' in backbone else 768, num_classes)
        self.apply(self._init_weights)

    def forward(self, x: Tensor) -> Tensor:
        y = self.backbone(x)
        y = self.decode_head(y)   # 4x reduction in image size
        y = F.interpolate(y, size=x.shape[2:], mode='bilinear', align_corners=False)    # to original image shape
        return y


if __name__ == '__main__':
    model = SegFormer('MiT-B0')
    # model.load_state_dict(torch.load('checkpoints/pretrained/segformer/segformer.b0.ade.pth', map_location='cpu'))
    x = torch.zeros(1, 3, 512, 512)
    y = model(x)
    print(y.shape)

It is in __init__ above that the backbone finally gets used: super().__init__(backbone, num_classes) has BaseModel build the MiT backbone from the name string, and the decode head also picks its width (256 vs. 768) based on that name.
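The width selection is worth calling out on its own: the lighter MiT-B0/B1 backbones get a 256-channel decoder, while B2 through B5 get 768, via the conditional expression in __init__:

```python
def decoder_dim(backbone: str) -> int:
    # Same expression SegFormer.__init__ uses to size SegFormerHead.
    return 256 if 'B0' in backbone or 'B1' in backbone else 768

print(decoder_dim('MiT-B0'))   # 256
print(decoder_dim('MiT-B2'))   # 768
```

So for our MiT-B2 config, the decode head is 768 channels wide.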

2. model_cfg summary

MODEL:
  NAME       : SegFormer                                # name of the model you are using
  BACKBONE   : MiT-B2                                   # model variant
  PRETRAINED : 'checkpoints/backbones/mit/mit_b2.pth'   # backbone model's weight

NAME determines which model class is used.

BACKBONE determines which backbone variant is built.

PRETRAINED points to the pretrained weight file.

The three are mutually constrained; you cannot mix and match them arbitrarily.
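Nothing in train.py enforces that constraint, so a mismatch (say, MiT-B2 with mit_b5.pth) only fails later when the state dict doesn't line up. A hypothetical sanity check one could add (check_model_cfg is my own helper, not part of the repo):

```python
def check_model_cfg(model_cfg: dict) -> None:
    """Hypothetical guard: the PRETRAINED filename should match BACKBONE."""
    backbone = model_cfg['BACKBONE']                 # e.g. 'MiT-B2'
    expected = backbone.lower().replace('-', '_')    # -> 'mit_b2'
    pretrained = model_cfg.get('PRETRAINED') or ''
    if expected not in pretrained:
        raise ValueError(f"PRETRAINED '{pretrained}' does not look like {backbone} weights")

# Matches the custom.yaml above, so this passes silently.
check_model_cfg({'NAME': 'SegFormer', 'BACKBONE': 'MiT-B2',
                 'PRETRAINED': 'checkpoints/backbones/mit/mit_b2.pth'})
```

It relies on the repo's naming convention (MiT-B2 checkpoints named mit_b2.pth), which is an assumption, but it turns a cryptic load failure into an immediate, readable error.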

3. train_cfg

TRAIN:
  IMAGE_SIZE    : [512, 512]   # training image size in (h, w)
  BATCH_SIZE    : 2            # batch size used to train
  EPOCHS        : 6            # number of epochs to train
  EVAL_INTERVAL : 2            # evaluation interval during training
  AMP           : false        # use AMP in training
  DDP           : false        # use DDP training
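These values feed directly into the loop above: BATCH_SIZE sets iters_per_epoch, AMP toggles GradScaler/autocast, and DDP picks the sampler. A quick sketch of parsing this block with PyYAML (the 2000-image dataset size is made up for illustration):

```python
import yaml

cfg = yaml.safe_load("""
TRAIN:
  IMAGE_SIZE    : [512, 512]
  BATCH_SIZE    : 2
  EPOCHS        : 6
  EVAL_INTERVAL : 2
  AMP           : false
  DDP           : false
""")

train_cfg = cfg['TRAIN']
n_images = 2000                                    # hypothetical dataset size
iters_per_epoch = n_images // train_cfg['BATCH_SIZE']
print(train_cfg['AMP'], iters_per_epoch)           # False 1000
```

Note that YAML's lowercase false parses to Python's False, which is exactly what GradScaler(enabled=...) and the DDP branch expect.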

4. dataset_cfg

DATASET:
  NAME         : HELEN                                  # dataset name to be trained with (camvid, cityscapes, ade20k)
  ROOT         : 'data/SmithCVPR2013_dataset_resized'   # dataset root path
  IGNORE_LABEL : 255

Interesting, isn't it?

NAME : HELEN

How does a plain string become a dataset? The same way the model did: eval(dataset_cfg['NAME']) looks up a class named HELEN that `from semseg.datasets import *` brought into scope, and train.py constructs it with the root path and split.
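A minimal sketch of that dataset lookup, with a stand-in HELEN class (the n_classes value is assumed here for illustration; the real class lives in semseg/datasets):

```python
# Stand-in for the repo's HELEN dataset class, just to show the lookup.
class HELEN:
    n_classes = 11        # assumed class count, for illustration only
    ignore_label = 255

    def __init__(self, root: str, split: str, transform=None):
        self.root, self.split, self.transform = root, split, transform

dataset_cfg = {'NAME': 'HELEN', 'ROOT': 'data/SmithCVPR2013_dataset_resized'}

# Exactly what train.py does: resolve the string, then construct the dataset.
trainset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'train', None)
print(trainset.split, trainset.n_classes)   # train 11
```

This is also why trainset.n_classes and trainset.ignore_label can be read back in train.py to size the model and configure the loss.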

5. eval_cfg

EVAL:
  MODEL_PATH : 'checkpoints/pretrained/ddrnet/ddrnet_23slim_city.pth'   # trained model file path
  IMAGE_SIZE : [1024, 1024]                        # evaluation image size in (h, w)
  MSF:
    ENABLE : false                                 # multi-scale and flip evaluation
    FLIP   : true                                  # use flip in evaluation
    SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]     # scales used in MSF evaluation
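When MSF.ENABLE is true, evaluation runs the model at each scale (optionally with a horizontal flip) and averages the logits. The real logic lives in val.py and may differ in details; here is a sketch of the idea with a toy 1x1-conv "model":

```python
import torch
import torch.nn.functional as F

def msf_inference(model, img, scales=(0.5, 0.75, 1.0, 1.25, 1.5, 1.75), flip=True):
    """Sum logits over scales (and flips), resized back to the input size."""
    H, W = img.shape[2:]
    logits_sum = torch.zeros(1)
    for s in scales:
        scaled = F.interpolate(img, scale_factor=s, mode='bilinear', align_corners=False)
        logits = model(scaled)
        logits_sum = logits_sum + F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=False)
        if flip:
            # flip the input, run the model, then flip the logits back
            flipped = model(torch.flip(scaled, dims=[3]))
            flipped = F.interpolate(flipped, size=(H, W), mode='bilinear', align_corners=False)
            logits_sum = logits_sum + torch.flip(flipped, dims=[3])
    return logits_sum

# Toy "model": 3 input channels -> 11 class logits, spatial size preserved.
toy = torch.nn.Conv2d(3, 11, kernel_size=1)
with torch.no_grad():
    out = msf_inference(toy, torch.zeros(1, 3, 64, 64))
print(out.shape)   # torch.Size([1, 11, 64, 64])
```

Summing and taking argmax gives the same prediction as averaging, which is why such implementations often skip dividing by the number of passes.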
