MMSegmentation实战操作(保姆级过程)

抹茶味的糯米糍 2023-06-17 15:12:24

视频来源:MMSegmentation代码课

安装配置mmsegmentation

配置资源和环境

终端命令行下载最新源码:git clone https://github.com/open-mmlab/mmsegmentation.git -b dev-1.x
终端命令行进入目录:cd mmsegmetation
终端命令行安装 MMSegmentation:pip install -v -e .

##检查和测试

检查

# 检查 mmsegmentation
#没有报错,即证明安装成功。
import mmseg
from mmseg.utils import register_all_modules
from mmseg.apis import inference_model, init_model
print('mmsegmentation版本', mmseg.__version__)

测试

在mmsegmetation目录下创建三个子目录:

  1. 创建checkpoint 文件夹,用于存放预训练模型权重文件

cityscapes上的PSPNet

cityscapes上的SegFormer

cityscapes上的Mask2Former

  1. 创建 outputs 文件夹,用于存放预测结果
  1. 创建 data 文件夹,用于存放图片和视频素材

下载素材至data目录——

伦敦街景图片

上海驾车街景视频,视频来源:https://www.youtube.com/watch?v=ll8TgCZ0plk

街拍视频

预训练语义分割模型预测-单张图像(伦敦街景图片)-命令行

img

命令行转到mmsegmentation目录下,在命令行中键入下述命令,opacity是可视化透明度

在cityscapes数据集上预训练的PSPNet

python demo/image_demo.py data/street_uk.jpeg configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py checkpoint/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth --out-file outputs/B1_uk_pspnet.jpg --device cuda:0 --opacity 0.5

img

在cityscapes数据集上预训练的segformer

python demo/image_demo.py data/street_uk.jpeg configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth --out-file outputs/B1_uk_segformer.jpg --device cuda:0 --opacity 0.5

img

在cityscapes数据集上预训练的mask2fomer

python demo/image_demo.py data/street_uk.jpeg configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth --out-file outputs/B1_uk_Mask2Former.jpg --device cuda:0 --opacity 0.5

img

在ADE20K数据集上预训练的segformer

python demo/image_demo.py data/street_uk.jpeg configs/segformer/segformer_mit-b5_8xb2-160k_ade20k-640x640.py checkpoint/segformer_mit-b5_640x640_160k_ade20k_20220617_203542-940a6bd8.pth --out-file outputs/B1_Segformer_ade20k.jpg --device cuda:0 --opacity 0.5

img

预训练语义分割模型预测-视频

视频预测-命令行(不推荐,慢)

命令行转到mmsegmentation目录下,在命令行中键入下述命令,opacity是可视化透明度

python demo/video_demo.py data/street_20220330_174028.mp4 configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth --device cuda:0 --output-file outputs/B3_video.mp4 --opacity 0.5

视频预测-Python API(推荐,快)

基于Cityscapes 街景数据集预训练mask2former模型

# -*- coding = utf-8 -*-
# @Time : 2023/6/17 13:17
# @Author : Happiness
# @Software : PyCharm



####导入工具包

import os
import numpy as np
import time
import shutil

import torch

from PIL import Image
import cv2

import mmcv
import mmengine
from mmseg.apis import inference_model
from mmseg.utils import register_all_modules
register_all_modules()

from mmseg.datasets import CityscapesDataset




####载入模型

# 模型 config 配置文件
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'

# 模型 checkpoint 权重文件
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'
from mmseg.apis import init_model
model = init_model(config_file, checkpoint_file, device='cuda:0')

from mmengine.model.utils import revert_sync_batchnorm
if not torch.cuda.is_available():
    model = revert_sync_batchnorm(model)



#### 输入视频路径

input_video = 'data/traffic.mp4'

#input_video = 'data/street_20220330_174028.mp4'





#创建临时文件夹,存放每帧结果


temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('创建临时文件夹 {} 用于存放每帧预测结果'.format(temp_out_dir))




#####视频单帧预测

# 获取 Cityscapes 街景数据集 类别名和调色板
from mmseg.datasets import cityscapes

classes = cityscapes.CityscapesDataset.METAINFO['classes']
palette = cityscapes.CityscapesDataset.METAINFO['palette']


def pridict_single_frame(img, opacity=0.2):
    result = inference_model(model, img)

    # 将分割图按调色板染色
    seg_map = np.array(result.pred_sem_seg.data[0].detach().cpu().numpy()).astype('uint8')
    seg_img = Image.fromarray(seg_map).convert('P')
    seg_img.putpalette(np.array(palette, dtype=np.uint8))

    show_img = (np.array(seg_img.convert('RGB'))) * (1 - opacity) + img * opacity

    return show_img





##### 视频逐帧预测.

# 读入待预测视频
imgs = mmcv.VideoReader(input_video)

prog_bar = mmengine.ProgressBar(len(imgs))

# 对视频逐帧处理
for frame_id, img in enumerate(imgs):
    ## 处理单帧画面
    show_img = pridict_single_frame(img, opacity=0.15)
    temp_path = f'{temp_out_dir}/{frame_id:06d}.jpg'  # 保存语义分割预测结果图像至临时文件夹
    cv2.imwrite(temp_path, show_img)

    prog_bar.update()  # 更新进度条

# 把每一帧串成视频文件
#mmcv.frames2video(temp_out_dir, 'outputs/B3_video.mp4', fps=imgs.fps, fourcc='mp4v')
mmcv.frames2video(temp_out_dir, 'outputs/B4_video.mp4', fps=imgs.fps, fourcc='mp4v')

shutil.rmtree(temp_out_dir)  # 删除存放每帧画面的临时文件夹
print('删除临时文件夹', temp_out_dir)


Kaggle实战-小鼠肾小球组织病理切片语义分割

数据集
解压到数据集data目录下

数据集配置文件
保存路径为/mmsegmentation/mmseg/datasets/StanfordBackgroundDataset.py
如果操作了这一步,下一小节配置文件第一步就不用复制了

修改 ../mmsegmentation/mmseg/datasets/init.py,添加数据集
保存路径维/mmsegmentation/mmseg/datasets/init.py会替换掉原有文件

划分训练集和测试集

# -*- coding = utf-8 -*-
# @Time : 2023/6/17 17:57
# @Author : Happiness
# @File : split.py
# @Software : PyCharm




####导入工具包

import os
import random



#获取全部数据文件名列表
PATH_IMAGE = 'D:/0.dive into pytorch/openmmlab/mmsegmentation/data/Glomeruli-dataset/images'
all_file_list = os.listdir(PATH_IMAGE)
all_file_num = len(all_file_list)
random.shuffle(all_file_list) # 随机打乱全部数据文件名列表



###指定训练集和测试集比例
train_ratio = 0.8
test_ratio = 1 - train_ratio
train_file_list = all_file_list[:int(all_file_num*train_ratio)]
test_file_list = all_file_list[int(all_file_num*train_ratio):]
print('数据集图像总数', all_file_num)
print('训练集划分比例', train_ratio)
print('训练集图像个数', len(train_file_list))
print('测试集图像个数', len(test_file_list))




train_file_list[:5]
['SAS_21883_001_35.png',
 'VUHSK_1352_59.png',
 'SAS_21908_001_60.png',
 'SESCAM_9_0_25.png',
 'SAS_21896_001_26.png']
test_file_list[:5]
['VUHSK_1272_101.png',
 'SAS_21937_001_117.png',
 'VUHSK_1502_11.png',
 'SAS_21904_001_3.png',
 'VUHSK_1502_8.png']



###生成两个txt划分文件
with open('D:/0.dive into pytorch/openmmlab/mmsegmentation/data/Glomeruli-dataset/splits/train.txt', 'w') as f:
    f.writelines(line.split('.')[0] + '\n' for line in train_file_list)
with open('D:/0.dive into pytorch/openmmlab/mmsegmentation/data/Glomeruli-dataset/splits/val.txt', 'w') as f:
    f.writelines(line.split('.')[0] + '\n' for line in test_file_list)

配置文件

如果操作了上一节StanfordBackgroundDataset文件的下载,下一步就不用复制了,直接代码块

将下面这段代码原封不动复制到文件\mmsegmentation\mmseg\datasets\basesegdataset.py

@DATASETS.register_module()
class StanfordBackgroundDataset(BaseSegDataset):
METAINFO = dict(classes = ('background', 'glomeruili'), palette = [[128, 128, 128], [151, 189, 8]])
def init(self, **kwargs):
super().init(img_suffix='.png', seg_map_suffix='.png', **kwargs)

# -*- coding = utf-8 -*-
# @Time : 2023/6/17 18:25
# @Author : Happiness
# @File : newconfig.py
# @Software : PyCharm




######导入工具包

import numpy as np
from PIL import Image

import os.path as osp
from tqdm import tqdm

import mmcv
import mmengine
import matplotlib.pyplot as plt



# 数据集图片和标注路径
data_root = 'D:/0.dive into pytorch/openmmlab/mmsegmentation/data/Glomeruli-dataset'
img_dir = 'images'
ann_dir = 'masks'

# 类别和对应的颜色
classes = ('background', 'glomeruili')
palette = [[128, 128, 128], [151, 189, 8]]



####修改数据集类(指定图像扩展名)

#After downloading the data, we need to implement load_annotations function in the new dataset class StanfordBackgroundDataset.

from mmseg.registry import DATASETS
from mmseg.datasets import BaseSegDataset

# @DATASETS.register_module()
# class StanfordBackgroundDataset(BaseSegDataset):
#   METAINFO = dict(classes = classes, palette = palette)
#   def __init__(self, **kwargs):
#     super().__init__(img_suffix='.png', seg_map_suffix='.png', **kwargs)

#文档:https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/tutorials/customize_datasets.md#customize-datasets-by-reorganizing-data




####修改config配置文件

# 下载 config 文件 和 预训练模型checkpoint权重文件


from mmengine import Config
cfg = Config.fromfile('D:/0.dive into pytorch/openmmlab/mmsegmentation/configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py')
cfg.norm_cfg = dict(type='BN', requires_grad=True) # 只使用GPU时,BN取代SyncBN
cfg.crop_size = (256, 256)
cfg.model.data_preprocessor.size = cfg.crop_size
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
# modify num classes of the model in decode/auxiliary head
cfg.model.decode_head.num_classes = 2
cfg.model.auxiliary_head.num_classes = 2

# 修改数据集的 type 和 root
cfg.dataset_type = 'StanfordBackgroundDataset'
cfg.data_root = data_root

cfg.train_dataloader.batch_size = 8

cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='RandomResize', scale=(320, 240), ratio_range=(0.5, 2.0), keep_ratio=True),
    dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackSegInputs')
]

cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(320, 240), keep_ratio=True),
    # add loading annotation after ``Resize`` because ground truth
    # does not need to do resize data transform
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]


###3修改成绝对路径

cfg.train_dataloader.dataset.data_root='D:/0.dive into pytorch/openmmlab/mmsegmentation/data/Glomeruli-dataset/'
cfg.val_dataloader.dataset.data_root='D:/0.dive into pytorch/openmmlab/mmsegmentation/data/Glomeruli-dataset/'
cfg.test_dataloader.dataset.data_root='D:/0.dive into pytorch/openmmlab/mmsegmentation/data/Glomeruli-dataset/'


cfg.train_dataloader.dataset.type = cfg.dataset_type
cfg.train_dataloader.dataset.data_root = cfg.data_root
cfg.train_dataloader.dataset.data_prefix = dict(img_path=img_dir, seg_map_path=ann_dir)
cfg.train_dataloader.dataset.pipeline = cfg.train_pipeline
cfg.train_dataloader.dataset.ann_file = 'splits/train.txt'

cfg.val_dataloader.dataset.type = cfg.dataset_type
cfg.val_dataloader.dataset.data_root = cfg.data_root
cfg.val_dataloader.dataset.data_prefix = dict(img_path=img_dir, seg_map_path=ann_dir)
cfg.val_dataloader.dataset.pipeline = cfg.test_pipeline
cfg.val_dataloader.dataset.ann_file = 'splits/val.txt'

cfg.test_dataloader = cfg.val_dataloader


# 载入预训练模型权重
cfg.load_from = 'D:/0.dive into pytorch/openmmlab/mmsegmentation/checkpoint/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'

# 工作目录
cfg.work_dir = './work_dirs/tutorial'

# 训练迭代次数
cfg.train_cfg.max_iters = 800
# 评估模型间隔
cfg.train_cfg.val_interval = 400
# 日志记录间隔
cfg.default_hooks.logger.interval = 100
# 模型权重保存间隔
cfg.default_hooks.checkpoint.interval = 400

# 随机数种子
cfg['randomness'] = dict(seed=0)



####保存config配置文件
cfg.dump('Glomeruli_pspnet_cityscapes.py')




训练

# -*- coding = utf-8 -*-
# @Time : 2023/6/18 12:27
# @Author : Happiness
# @File : train.py
# @Software : PyCharm



####载入config配置文件

from mmengine import Config
cfg = Config.fromfile('Glomeruli_pspnet_cityscapes.py')



####准备训练

from mmengine.runner import Runner
from mmseg.utils import register_all_modules

# register all modules in mmseg into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
runner = Runner.from_cfg(cfg)



###开始训练

if __name__=='__main__':#不加这个会报错,多线程错误
    runner.train()

测试

终端命令行先转到当前项目文件Glomeruli下

测试集精度指标

python ../../tools/test.py Glomeruli_pspnet_cityscapes.py ./work_dirs/tutorial/iter_800.pth

img

测试集速度指标

python ../../tools/analysis_tools/benchmark.py Glomeruli_pspnet_cityscapes.py ./work_dirs/tutorial/iter_800.pth

img

推理图片



# -*- coding = utf-8 -*-
# @Time : 2023/6/18 16:29
# @Author : Happiness
# @File : tuili.py
# @Software : PyCharm


#####用训练得到的模型预测


####3导入工具包

import numpy as np
import matplotlib.pyplot as plt

from mmseg.apis import init_model, inference_model, show_result_pyplot
import mmcv
import cv2



####载入模型

# 载入 config 配置文件
from mmengine import Config
cfg = Config.fromfile('Glomeruli_pspnet_cityscapes.py')
from mmengine.runner import Runner
from mmseg.utils import register_all_modules

# register all modules in mmseg into the registries
# do not init the default scope here because it will be init in the runner

register_all_modules(init_default_scope=False)
runner = Runner.from_cfg(cfg)

# 初始化模型
checkpoint_path = './work_dirs/tutorial/iter_800.pth'
model = init_model(cfg, checkpoint_path, 'cuda:0')

####载入测试集图像,或新图像

img = mmcv.imread('D:/0.dive into pytorch/openmmlab/mmsegmentation/data/Glomeruli-dataset/images/SAS_21883_001_6.png')



####语义分割预测

result = inference_model(model, img)
result.keys()
['seg_logits', 'pred_sem_seg']
pred_mask = result.pred_sem_seg.data[0].cpu().numpy()



#####可视化语义分割预测结果

plt.imshow(pred_mask)
plt.show()



# 可视化预测结果

visualization = show_result_pyplot(model, img, result, opacity=0.7, out_file='pred.jpg')
plt.imshow(mmcv.bgr2rgb(visualization))
plt.show()



###语义分割预测结果-连通域分析

plt.imshow(np.uint8(pred_mask))
plt.show()


connected = cv2.connectedComponentsWithStats(np.uint8(pred_mask), connectivity=4)

plt.imshow(connected[1])
plt.show()


####3获取测试集标注

label = mmcv.imread('Glomeruli-dataset/masks/VUHSK_1702_39.png')
label_mask = label[:,:,0]

plt.imshow(label_mask)
plt.show()


####对比测试集标注和语义分割预测结果


# 真实为前景,预测为前景
TP = (label_mask == 1) & (pred_mask==1)
# 真实为背景,预测为背景
TN = (label_mask == 0) & (pred_mask==0)
# 真实为前景,预测为背景
FN = (label_mask == 1) & (pred_mask==0)
# 真实为背景,预测为前景
FP = (label_mask == 0) & (pred_mask==1)
plt.imshow(TP)
plt.show()

confusion_map = TP * 255 + FP * 150 + FN * 80 + TN * 10
plt.imshow(confusion_map)
plt.show()



####混淆矩阵

from sklearn.metrics import confusion_matrix

confusion_matrix_model = confusion_matrix(label_map.flatten(), pred_mask.flatten())
import itertools


def cnf_matrix_plotter(cm, classes, cmap=plt.cm.Blues):
    """
    传入混淆矩阵和标签名称列表,绘制混淆矩阵
    """
    plt.figure(figsize=(10, 10))

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    # plt.colorbar() # 色条
    tick_marks = np.arange(len(classes))

    plt.title('Confusion Matrix', fontsize=30)
    plt.xlabel('Pred', fontsize=25, c='r')
    plt.ylabel('True', fontsize=25, c='r')
    plt.tick_params(labelsize=16)  # 设置类别文字大小
    plt.xticks(tick_marks, classes, rotation=90)  # 横轴文字旋转
    plt.yticks(tick_marks, classes)

    # 写数字
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > threshold else "black",
                 fontsize=12)

    plt.tight_layout()

    plt.savefig('混淆矩阵.pdf', dpi=300)  # 保存图像
    plt.show()


classes = ('background', 'glomeruili')
cnf_matrix_plotter(confusion_matrix_model, classes, cmap='Blues')

Kaggle实战-迪拜卫星航拍多类别语义分割

命令行进入到mmsegmentation目录下:cd mmsegmentation

准备config配置文件

Kaggle原数据

整理好的数据

定义数据集类(各类别名称及配色)
定义数据集类(各类别名称及配色)存放在mmseg/datasets/DubaiDataset.py目录中

注册数据集类
注册数据集类存放在mmseg/datasets/init.py目录下(会替代掉原有文件)

定义训练及测试pipeline
定义训练及测试pipeline存放在configs/base/datasets/DubaiDataset_pipeline.py目录下

下载模型config配置文件

保存 下载模型config配置文件的路径“configs/pspnet/pspnet_r50-d8_4xb2-40k_DubaiDataset.py”

# -*- coding = utf-8 -*-
# @Time : 2023/6/17 23:24
# @Author : Happiness
# @File : newconfig.py
# @Software : PyCharm




#####导入工具包

import numpy as np
from PIL import Image

import os.path as osp
from tqdm import tqdm

import mmcv
import mmengine
import matplotlib.pyplot as plt




####载入config配置文件

from mmengine import Config
cfg = Config.fromfile('../../configs/pspnet/pspnet_r50-d8_4xb2-40k_DubaiDataset.py')






#####修改config配置文件
cfg.norm_cfg = dict(type='BN', requires_grad=True) # 只使用GPU时,BN取代SyncBN
cfg.crop_size = (256, 256)
cfg.model.data_preprocessor.size = cfg.crop_size
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
# modify num classes of the model in decode/auxiliary head

# 模型 decode/auxiliary 输出头,指定为类别个数
cfg.model.decode_head.num_classes = 6
cfg.model.auxiliary_head.num_classes = 6

cfg.train_dataloader.batch_size = 8
cfg.train_dataloader.dataset.data_root='D:/0.dive into pytorch/openmmlab/mmsegmentation/data/Dubai-dataset/'####换成你自己的绝对路径
cfg.val_dataloader.dataset.data_root='D:/0.dive into pytorch/openmmlab/mmsegmentation/data/Dubai-dataset/'####换成你自己的绝对路径
cfg.test_dataloader.dataset.data_root='D:/0.dive into pytorch/openmmlab/mmsegmentation/data/Dubai-dataset/'####换成你自己的绝对路径
cfg.test_dataloader = cfg.val_dataloader

# 结果保存目录
cfg.work_dir = './work_dirs/DubaiDataset'

# 训练迭代次数
cfg.train_cfg.max_iters = 3000
# 评估模型间隔
cfg.train_cfg.val_interval = 400
# 日志记录间隔
cfg.default_hooks.logger.interval = 100
# 模型权重保存间隔
cfg.default_hooks.checkpoint.interval = 1500

# 随机数种子
cfg['randomness'] = dict(seed=0)



####查看完整config配置文件

# print(cfg.pretty_text)



#3####保存config配置文件



cfg.dump('pspnet-DubaiDataset.py')

记得换路径

训练

# -*- coding = utf-8 -*-
# @Time : 2023/6/17 23:42
# @Author : Happiness
# @File : train.py
# @Software : PyCharm




####导入工具包

import numpy as np

import os.path as osp
from tqdm import tqdm

import mmcv
import mmengine



####载入config配置文件

from mmengine import Config
cfg = Config.fromfile('pspnet-DubaiDataset.py')



####准备训练

from mmengine.runner import Runner
from mmseg.utils import register_all_modules

# register all modules in mmseg into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
runner = Runner.from_cfg(cfg)



###开始训练
if __name__=='__main__':
    runner.train()


测试集评估指标

终端cmd跳转到mmsegmentation目录文件下——cd mmsegmentation

测试集精度指标

python tools/test.py projects/dubai/pspnet-DubaiDataset.py projects/dubai/work_dirs/DubaiDataset/iter_3000.pth

img

用训练得到的模型预测


```python
# -*- coding = utf-8 -*-
# @Time : 2023/6/18 11:49
# @Author : Happiness
# @File : tuili.py
# @Software : PyCharm


####导入工具包

import numpy as np
import matplotlib.pyplot as plt



from mmseg.apis import init_model, inference_model, show_result_pyplot
import mmcv
import cv2




####载入配置文件

# 载入 config 配置文件
from mmengine import Config
cfg = Config.fromfile('pspnet-DubaiDataset.py')
from mmengine.runner import Runner
from mmseg.utils import register_all_modules

# register all modules in mmseg into the registries
# do not init the default scope here because it will be init in the runner

register_all_modules(init_default_scope=False)
runner = Runner.from_cfg(cfg)


###载入模型

checkpoint_path = './work_dirs/DubaiDataset/iter_3000.pth'
model = init_model(cfg, checkpoint_path, 'cuda:0')




###载入测试集图像,或新图像

#使用绝对路径不怕报错
img = mmcv.imread('D:/0.dive into pytorch/openmmlab/mmsegmentation/data/Dubai-dataset/img_dir/val/3.jpg')




###语义分割预测

result = inference_model(model, img)
['seg_logits', 'pred_sem_seg']
pred_mask = result.pred_sem_seg.data[0].cpu().numpy()



###可视化语义分割预测结果

plt.imshow(pred_mask)
plt.show()


# 可视化预测结果
visualization = show_result_pyplot(model, img, result, opacity=0.7, out_file='pred.jpg')
plt.imshow(mmcv.bgr2rgb(visualization))
plt.show()



#####获取测试集标注

label = mmcv.imread('D:/0.dive into pytorch/openmmlab/mmsegmentation/data/Dubai-dataset/img_dir/val/3.jpg')

label_mask = label[:,:,0]
plt.imshow(label_mask)
plt.show()



####对比测试集标注和语义分割预测结果
# 测试集标注


# 语义分割预测结果

# 真实为前景,预测为前景
TP = (label_mask == 1) & (pred_mask==1)
# 真实为背景,预测为背景
TN = (label_mask == 0) & (pred_mask==0)
# 真实为前景,预测为背景
FN = (label_mask == 1) & (pred_mask==0)
# 真实为背景,预测为前景
FP = (label_mask == 0) & (pred_mask==1)
plt.imshow(TP)
plt.show()

confusion_map = TP * 255 + FP * 150 + FN * 80 + TN * 30
plt.imshow(confusion_map)
plt.show()



####混淆矩阵
from sklearn.metrics import confusion_matrix
confusion_matrix_model = confusion_matrix(label_mask.flatten(), pred_mask.flatten())

import itertools


def cnf_matrix_plotter(cm, classes, cmap=plt.cm.Blues):
    """
    传入混淆矩阵和标签名称列表,绘制混淆矩阵
    """
    plt.figure(figsize=(10, 10))

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    # plt.colorbar() # 色条
    tick_marks = np.arange(len(classes))

    plt.title('Confusion Matrix', fontsize=30)
    plt.xlabel('Pred', fontsize=25, c='r')
    plt.ylabel('True', fontsize=25, c='r')
    plt.tick_params(labelsize=16)  # 设置类别文字大小
    plt.xticks(tick_marks, classes, rotation=90)  # 横轴文字旋转
    plt.yticks(tick_marks, classes)

    # 写数字
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > threshold else "black",
                 fontsize=12)

    plt.tight_layout()

    plt.savefig('混淆矩阵.pdf', dpi=300)  # 保存图像
    plt.show()


classes = ['Land', 'Road', 'Building', 'Vegetation', 'Water', 'Unlabeled']
cnf_matrix_plotter(confusion_matrix_model, classes, cmap='Blues')



...全文
1130 回复 打赏 收藏 转发到动态 举报
写回复
用AI写文章
回复
切换为时间正序
请发表友善的回复…
发表回复

532

社区成员

发帖
与我相关
我的任务
社区描述
构建国际领先的计算机视觉开源算法平台
社区管理员
  • OpenMMLab
  • jason_0615
加入社区
  • 近7日
  • 近30日
  • 至今
社区公告
暂无公告

试试用AI创作助手写篇文章吧