OpenMMlab第九课

xiaohui_82 2023-06-17 20:47:18

 

这节课学习了 MMSegmentation 的代码实战课程,包括 MMSegmentation 的安装、推理、训练和测试。

1.安装

# MMSegmentation 1.x runs on MMEngine and MMCV 2.x. `pip install -e .` only
# installs the runtime requirements, so install those two through MIM first,
# as in the official get-started guide. NOTE(review): mmseg/__init__.py also
# enforces an upper MMCV bound for each release; pick an mmcv version the
# checked-out branch accepts.
pip install -U openmim
mim install mmengine
mim install "mmcv>=2.0.0rc4"

# Clone the 1.x development branch and install it in editable mode.
git clone https://github.com/open-mmlab/mmsegmentation.git -b dev-1.x

cd mmsegmentation
pip install -e .

2.下载预训练权重到 checkpoint 目录

# Fetch the pretrained PSPNet checkpoint from the Model Zoo and store it in the
# ./checkpoint directory (wget creates the directory if needed).
wget --directory-prefix=checkpoint https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth

3.下载数据集到data目录

# Mask-format watermelon dataset (the img_dir/ + ann_dir/ layout used below).
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/watermelon/Watermelon87_Semantic_Seg_Mask.zip -P data

# Labelme version of the same dataset. The original line was missing the space
# before -P, which made wget request a nonexistent ".zip-P" URL.
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/watermelon/Watermelon87_Semantic_Seg_Labelme.zip -P data

# Later steps read data/Watermelon87_Semantic_Seg_Mask/... directly, so the
# archives must be unpacked. Assumes each zip holds a top-level folder of the
# same name; adjust -d if yours does not.
unzip -q -o data/Watermelon87_Semantic_Seg_Mask.zip -d data
unzip -q -o data/Watermelon87_Semantic_Seg_Labelme.zip -d data

# Chinese font for matplotlib. Ask matplotlib for its own data directory
# instead of hard-coding a conda env path. This is a plain shell command; the
# leading "!" only works inside a Jupyter cell.
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf -O "$(python -c 'import matplotlib; print(matplotlib.get_data_path())')/fonts/ttf/SimHei.ttf"
# matplotlib caches its font list; drop the cache so the new font is picked up.
rm -rf "$(python -c 'import matplotlib; print(matplotlib.get_cachedir())')"

测试环境(检查 PyTorch、MMCV、MMSegmentation 版本以及 matplotlib 中文字体)

# Check PyTorch
import torch
import torchvision  # imported only to confirm it is installed
print('Pytorch 版本', torch.__version__)
print('CUDA 是否可用', torch.cuda.is_available())

# Check MMCV
import mmcv
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
print('MMCV版本', mmcv.__version__)
print('CUDA版本', get_compiling_cuda_version())
print('编译器版本', get_compiler_version())

# Check MMSegmentation
import mmseg
from mmseg.utils import register_all_modules
from mmseg.apis import inference_model, init_model
print('mmsegmentation版本', mmseg.__version__)

# Check that matplotlib can render Chinese text.
import os

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import font_manager

# matplotlib resolves family names from a cached font list. A SimHei.ttf
# copied into mpl-data after that cache was built is not found by name alone,
# so register the file explicitly when it is present.
simhei_path = os.path.join(matplotlib.get_data_path(), 'fonts', 'ttf', 'SimHei.ttf')
if os.path.exists(simhei_path):
    font_manager.fontManager.addfont(simhei_path)
matplotlib.rc('font', family='SimHei')  # Chinese-capable font
# Draw axis minus signs as ASCII '-' so they do not depend on the CJK font
# providing the Unicode minus glyph.
matplotlib.rcParams['axes.unicode_minus'] = False

plt.plot([1, 2, 3], [100, 500, 300])
plt.title('matplotlib中文字体测试', fontsize=25)
plt.xlabel('X轴', fontsize=15)
plt.ylabel('Y轴', fontsize=15)
plt.show()

Pytorch 版本 2.0.1+cu117
CUDA 是否可用 True
MMCV版本 2.0.0rc4
CUDA版本 11.1
编译器版本 GCC 9.4
mmsegmentation版本 1.0.0

4.可视化数据集

单张图片推理(直接用预训练模型预测,不涉及训练)

# Single-image inference with the Cityscapes-pretrained PSPNet. Use the
# checkpoint already saved to ./checkpoint in step 2 rather than fetching the
# same file again from its URL.
python demo/image_demo.py \
        data/Watermelon87_Semantic_Seg_Mask/img_dir/train/04_35-2.jpg \
        configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \
        checkpoint/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth \
        --out-file outputs/B1_uk_pspnet.jpg \
        --device cuda:0 \
        --opacity 0.5

 

 

# Same image through SegFormer MiT-B5 (ADE20K weights, loaded from the URL), for
# comparison with the PSPNet result. Overlay opacity 0.5, rendered on GPU 0.
python demo/image_demo.py \
        data/Watermelon87_Semantic_Seg_Mask/img_dir/train/04_35-2.jpg \
        configs/segformer/segformer_mit-b5_8xb2-160k_ade20k-640x640.py \
        https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20220617_203542-940a6bd8.pth \
        --device cuda:0 \
        --opacity 0.5 \
        --out-file outputs/B1_Segformer_ade20k.jpg

 

import os

import cv2
import matplotlib.pyplot as plt

# Dataset root relative to the mmsegmentation checkout, the same convention
# the demo and config steps use (instead of a /home/<user>/ absolute path).
data_root = 'data/Watermelon87_Semantic_Seg_Mask'

# One training image and its annotation (same stem, .jpg vs .png).
img_path = os.path.join(data_root, 'img_dir', 'train', '04_35-2.jpg')
mask_path = os.path.join(data_root, 'ann_dir', 'train', '04_35-2.png')

img = cv2.imread(img_path)
# IMREAD_UNCHANGED keeps the mask's stored channels instead of forcing 3-channel
# BGR; if OpenCV still returns 3 channels, fall back to channel 0 as before.
mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
# cv2.imread signals a missing or unreadable file by returning None, not raising.
if img is None or mask is None:
    raise FileNotFoundError(f'Could not read {img_path} or {mask_path}')
if mask.ndim == 3:
    mask = mask[:, :, 0]

# Show the photo next to its semantic-segmentation labels.
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(12, 5))
ax_img.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # OpenCV loads BGR
ax_img.set_title('image')
ax_img.axis('off')
ax_mask.imshow(mask)
ax_mask.set_title('mask')
ax_mask.axis('off')
plt.show()

下载数据集类与配置文件(注意:这里直接复用了 Dubai 数据集教程中的类定义和配置,评估日志里显示的类别名称沿用 Dubai 数据集;后面把类别数设为 6,需与西瓜数据集的实际类别数保持一致)

# Dataset class definition (presumably defines the DubaiDataset class).
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/Dubai/DubaiDataset.py -P mmseg/datasets

# Replacement package init (presumably imports/registers DubaiDataset).
# mmseg/datasets/__init__.py already exists in the checkout. `wget -P` never
# overwrites an existing file: it saves the download as __init__.py.1 and
# leaves the original in place. Write the target explicitly with -O instead.
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/Dubai/__init__.py -O mmseg/datasets/__init__.py

# Dataset/pipeline base config and the PSPNet config that builds on it.
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/Dubai/DubaiDataset_pipeline.py -P configs/_base_/datasets

wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/Dubai/pspnet_r50-d8_4xb2-40k_DubaiDataset.py -P configs/pspnet 

 生成配置文件

from mmengine import Config
cfg = Config.fromfile('./configs/pspnet/pspnet_r50-d8_4xb2-40k_DubaiDataset.py')

# Values used more than once below, defined in one place.
data_root = 'data/Watermelon87_Semantic_Seg_Mask'
num_classes = 6  # must match the number of classes in the dataset's masks

# Plain BN instead of SyncBN when training on a single GPU (SyncBN is meant
# for multi-GPU distributed training).
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (256, 256)
cfg.data_root = data_root

# NOTE(review): the train pipeline's crop size was copied from the base
# config's crop_size when the file was parsed. Reassigning cfg.crop_size here
# only changes the data preprocessor's pad size below; confirm that is intended.
cfg.model.data_preprocessor.size = cfg.crop_size
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg

# Output channels of the decode / auxiliary heads = number of classes.
cfg.model.decode_head.num_classes = num_classes
cfg.model.auxiliary_head.num_classes = num_classes

cfg.train_dataloader.batch_size = 8
# Dataloader fields were filled from data_root when the file was parsed, so
# changing cfg.data_root alone does not reach them; set each explicitly.
cfg.train_dataloader.dataset.data_root = data_root
cfg.val_dataloader.dataset.data_root = data_root

# Evaluate the test stage on the validation split.
cfg.test_dataloader = cfg.val_dataloader

# Output directory for logs and checkpoints.
cfg.work_dir = './work_dirs/DubaiDataset'

# Total training iterations.
cfg.train_cfg.max_iters = 3000
# Validate every N iterations.
cfg.train_cfg.val_interval = 400
# Log every N iterations.
cfg.default_hooks.logger.interval = 100
# Save a checkpoint every N iterations (3000 / 1500 -> iter_1500, iter_3000).
cfg.default_hooks.checkpoint.interval = 1500

# Fixed random seed for reproducibility.
cfg['randomness'] = dict(seed=0)

cfg.dump('pspnet-DubaiDataset_20230612.py')

 5.训练

from mmengine import Config
from mmengine.runner import Runner
from mmseg.utils import register_all_modules

# Config written by the config-generation step.
cfg = Config.fromfile('pspnet-DubaiDataset_20230612.py')

# Register all mmseg modules into the registries so the config's `type=`
# entries resolve. The default scope is not initialised here because the
# Runner initialises it.
register_all_modules(init_default_scope=False)
runner = Runner.from_cfg(cfg)

# Trains for cfg.train_cfg.max_iters iterations; logs and checkpoints go to
# cfg.work_dir.
runner.train()

 6. 用训练好的模型预测

import matplotlib.pyplot as plt
import mmcv
from mmengine import Config
from mmseg.apis import inference_model, init_model
from mmseg.utils import register_all_modules

# Load the config written by the config-generation step.
cfg = Config.fromfile('pspnet-DubaiDataset_20230612.py')

# Register mmseg modules. No Runner is built for inference (constructing one
# would build and place a second copy of the model and create a fresh log
# directory under work_dir), so let register_all_modules also set 'mmseg' as
# the default scope, which the Runner would otherwise have done.
register_all_modules()

# Last checkpoint of the run: max_iters=3000 with a checkpoint every 1500 iters.
checkpoint_path = './work_dirs/DubaiDataset/iter_3000.pth'
model = init_model(cfg, checkpoint_path, 'cuda:0')

# Read the image from the Mask dataset, the one downloaded and used for
# training (the Labelme copy is not needed here). NOTE: this image is from
# the train split, so the prediction shows fit, not generalization.
img = mmcv.imread('data/Watermelon87_Semantic_Seg_Mask/img_dir/train/04_35-2.jpg')
result = inference_model(model, img)
# pred_sem_seg.data presumably has shape (1, H, W) of class indices; [0] drops
# the leading dimension.
pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
plt.imshow(pred_mask)
plt.show()

 

...全文
27 回复 打赏 收藏 转发到动态 举报
AI 作业
写回复
用AI写文章
回复
切换为时间正序
请发表友善的回复…
发表回复

535

社区成员

发帖
与我相关
我的任务
社区描述
构建国际领先的计算机视觉开源算法平台
社区管理员
  • OpenMMLab
  • jason_0615
加入社区
  • 近7日
  • 近30日
  • 至今
社区公告
暂无公告

试试用AI创作助手写篇文章吧