U-Net 图像分割结果缺少多个类别

GISer_Lin 2020-08-10 08:55:34
用 PyTorch 实现的 U-Net 训练后,预测结果中缺少多个类别(这些类别从未被预测出来),请问是什么原因?



unet代码:
# -*- coding: utf-8 -*-
# @Time : 2020/6/15 19:47
# @Author : Zhao HL
# @File : unet.py
from my_utils.global_config import *
from .attention_module import *
from .rrm import *

# region 搭建网络

# region 编解码结构定义
# struct1:3*3conv,3*3conv
class Conv_33_a(Module):
    """Two stacked 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    The first stage maps ``input_chs`` to ``output_chs`` channels and the
    second keeps ``output_chs``.  ``padding=1`` preserves the spatial size.
    """

    def __init__(self, input_chs, output_chs):
        super().__init__()

        def conv_bn_relu(in_c, out_c):
            # One 3x3 conv stage; the Sequential indices (0: conv, 1: bn,
            # 2: relu) define the state_dict key names.
            return Sequential(
                Conv2d(in_c, out_c, 3, padding=1),
                BatchNorm2d(out_c),
                ReLU(),
            )

        self.conv1 = conv_bn_relu(input_chs, output_chs)
        self.conv2 = conv_bn_relu(output_chs, output_chs)

    def forward(self, x):
        return self.conv2(self.conv1(x))

class Conv_33_b(Module):
    """Two stacked 3x3 Conv2d -> BatchNorm2d -> ReLU stages (decoder flavour).

    Unlike ``Conv_33_a``, the first stage keeps ``input_chs`` channels and
    only the second stage changes the width to ``output_chs``.  ``padding=1``
    preserves the spatial size.
    """

    def __init__(self, input_chs, output_chs):
        super().__init__()

        def conv_bn_relu(in_c, out_c):
            # One 3x3 conv stage; the Sequential indices (0: conv, 1: bn,
            # 2: relu) define the state_dict key names.
            return Sequential(
                Conv2d(in_c, out_c, 3, padding=1),
                BatchNorm2d(out_c),
                ReLU(),
            )

        self.conv1 = conv_bn_relu(input_chs, input_chs)
        self.conv2 = conv_bn_relu(input_chs, output_chs)

    def forward(self, x):
        return self.conv2(self.conv1(x))





class Conv_33_mid(Module):
    """Bottleneck block: two 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    The hidden width is the module-level config constant ``Mid_chs`` (brought
    in by the star imports; its value is not visible in this file), so the
    channel flow is ``input_chs -> Mid_chs -> output_chs``.  ``padding=1``
    preserves the spatial size.
    """

    def __init__(self, input_chs, output_chs):
        super().__init__()

        def conv_bn_relu(in_c, out_c):
            # One 3x3 conv stage; the Sequential indices (0: conv, 1: bn,
            # 2: relu) define the state_dict key names.
            return Sequential(
                Conv2d(in_c, out_c, 3, padding=1),
                BatchNorm2d(out_c),
                ReLU(),
            )

        self.conv1 = conv_bn_relu(input_chs, Mid_chs)
        self.conv2 = conv_bn_relu(Mid_chs, output_chs)

    def forward(self, x):
        return self.conv2(self.conv1(x))
# endregion

# region 编码、解码、桥接结构定义
class Conv_Group_En(Conv_33_a):
    """Encoder stage block; currently exactly ``Conv_33_a``."""

    def __init__(self, input_chs, output_chs):
        super().__init__(input_chs, output_chs)


class Conv_Group_De(Conv_33_b):
    """Decoder stage block; currently exactly ``Conv_33_b``."""

    def __init__(self, input_chs, output_chs):
        super().__init__(input_chs, output_chs)


class Conv_Group_Mid(Conv_33_mid):
    """Bottleneck block; currently exactly ``Conv_33_mid``."""

    def __init__(self, input_chs, output_chs):
        super().__init__(input_chs, output_chs)


class Bridge_connection(Module):
    """Identity skip-connection module with no parameters.

    ``input_chs`` / ``output_chs`` are accepted so the signature matches the
    other blocks, but they are not used.  ``forward`` returns its input
    object unchanged.
    """

    def __init__(self, input_chs, output_chs):
        super().__init__()
        del input_chs, output_chs  # intentionally unused

    def forward(self, x):
        return x
# endregion

class Unet(Module):
    """U-Net style encoder/decoder that predicts a class score map per pixel.

    Layer widths and the bottleneck pooling factor come from module-level
    config constants (Img_chs, Layer1_chs..Layer4_chs, Mid_resize,
    Labels_nums) brought in by the star imports at the top of the file.

    Input must be a 4-D (N, Img_chs, H, W) batch, because BatchNorm2d rejects
    other ranks.  The encoder pools by 2, 2, 2 and Mid_resize, and the decoder
    upsamples by the same factors.  H and W should therefore be divisible by
    8 * Mid_resize.  Otherwise a skip concatenation fails, or the output comes
    back smaller than the input.

    Args:
        output_logits: when False (the default, matching the previous
            behaviour), ``forward`` returns softmax probabilities over the
            class axis (dim 1).  Pass True when training with
            ``nn.CrossEntropyLoss``.  That loss applies log-softmax
            internally, so feeding it probabilities applies softmax twice.
            This flattens the gradients and tends to let rare classes vanish
            from the predictions.

    Note:
        Previously the last decoder block produced the class scores itself
        and ended in BatchNorm + ReLU, which clamped every score to >= 0.
        A class channel whose pre-activation went negative everywhere output
        a constant 0 with zero gradient (dead ReLU).  It could not recover,
        so that class effectively stopped being predicted.  The scores now
        come from an unactivated 1x1 convolution (``classifier``).  This
        changes the state_dict (``de_conv1`` output width, new
        ``classifier.*`` keys), so checkpoints saved with the old layout
        must be retrained or converted.
    """

    def __init__(self, output_logits=False):
        super(Unet, self).__init__()
        self.output_logits = output_logits

        mid_in_chs, mid_out_chs = Layer4_chs, Layer4_chs

        # Encoder: a conv block per stage, each followed by a pooling step.
        self.en_conv1 = Conv_Group_En(Img_chs, Layer1_chs)
        self.en_conv2 = Conv_Group_En(Layer1_chs, Layer2_chs)
        self.en_conv3 = Conv_Group_En(Layer2_chs, Layer3_chs)
        self.en_conv4 = Conv_Group_En(Layer3_chs, Layer4_chs)

        self.en_down1 = MaxPool2d(2, stride=2)
        self.en_down2 = MaxPool2d(2, stride=2)
        self.en_down3 = MaxPool2d(2, stride=2)
        self.en_down4 = MaxPool2d(Mid_resize, stride=Mid_resize)

        self.mid_layer = Conv_Group_Mid(mid_in_chs, mid_out_chs)

        # Identity skip paths (Bridge_connection); kept as swap-in points.
        self.bc4 = Bridge_connection(Layer4_chs, Layer4_chs)
        self.bc3 = Bridge_connection(Layer3_chs, Layer3_chs)
        self.bc2 = Bridge_connection(Layer2_chs, Layer2_chs)
        self.bc1 = Bridge_connection(Layer1_chs, Layer1_chs)

        # Decoder: each stage upsamples the concatenation (skip ++ features)
        # with a transposed conv whose kernel == stride, then convolves.
        self.de_up4 = ConvTranspose2d(Layer4_chs + mid_out_chs, Layer4_chs, kernel_size=Mid_resize, stride=Mid_resize)
        self.de_up3 = ConvTranspose2d(Layer3_chs * 2, Layer3_chs, kernel_size=2, stride=2)
        self.de_up2 = ConvTranspose2d(Layer2_chs * 2, Layer2_chs, kernel_size=2, stride=2)
        self.de_up1 = ConvTranspose2d(Layer1_chs * 2, Layer1_chs, kernel_size=2, stride=2)

        self.de_conv4 = Conv_Group_De(Layer4_chs, Layer3_chs)
        self.de_conv3 = Conv_Group_De(Layer3_chs, Layer2_chs)
        self.de_conv2 = Conv_Group_De(Layer2_chs, Layer1_chs)
        # The last decoder block stays at Layer1_chs features (it ends in
        # BN + ReLU).  The class scores come from a plain 1x1 conv, so they
        # can be negative and every class keeps a gradient.
        self.de_conv1 = Conv_Group_De(Layer1_chs, Layer1_chs)
        self.classifier = Conv2d(Layer1_chs, Labels_nums, kernel_size=1)
        # Explicit class axis; implicit-dim Softmax() is deprecated.
        self.act_output = Softmax(dim=1)

    def forward(self, input):
        """Map (N, Img_chs, H, W) to (N, Labels_nums, H, W).

        Returns softmax probabilities, or raw logits if ``output_logits``
        was set at construction.  Shape comments below write r = Mid_resize
        and chK = LayerK_chs.
        """
        # Encoder.
        ec1 = self.en_conv1(input)  # (N, ch1, H, W)
        ed1 = self.en_down1(ec1)    # (N, ch1, H/2, W/2)
        ec2 = self.en_conv2(ed1)    # (N, ch2, H/2, W/2)
        ed2 = self.en_down2(ec2)    # (N, ch2, H/4, W/4)
        ec3 = self.en_conv3(ed2)    # (N, ch3, H/4, W/4)
        ed3 = self.en_down3(ec3)    # (N, ch3, H/8, W/8)
        ec4 = self.en_conv4(ed3)    # (N, ch4, H/8, W/8)
        ed4 = self.en_down4(ec4)    # (N, ch4, H/8/r, W/8/r)

        ml = self.mid_layer(ed4)    # (N, ch4, H/8/r, W/8/r)

        # Skips use the *pooled* encoder outputs.  Each decoder stage
        # concatenates before upsampling, so these match its resolution.
        cat4 = self.bc4(ed4)
        cat3 = self.bc3(ed3)
        cat2 = self.bc2(ed2)
        cat1 = self.bc1(ed1)

        # Decoder.
        dm4 = torch.cat([cat4, ml], dim=1)   # (N, 2*ch4, H/8/r, W/8/r)
        du4 = self.de_up4(dm4)               # (N, ch4, H/8, W/8)
        dc4 = self.de_conv4(du4)             # (N, ch3, H/8, W/8)

        dm3 = torch.cat([cat3, dc4], dim=1)  # (N, 2*ch3, H/8, W/8)
        du3 = self.de_up3(dm3)               # (N, ch3, H/4, W/4)
        dc3 = self.de_conv3(du3)             # (N, ch2, H/4, W/4)

        dm2 = torch.cat([cat2, dc3], dim=1)  # (N, 2*ch2, H/4, W/4)
        du2 = self.de_up2(dm2)               # (N, ch2, H/2, W/2)
        dc2 = self.de_conv2(du2)             # (N, ch1, H/2, W/2)

        dm1 = torch.cat([cat1, dc2], dim=1)  # (N, 2*ch1, H/2, W/2)
        du1 = self.de_up1(dm1)               # (N, ch1, H, W)
        dc1 = self.de_conv1(du1)             # (N, ch1, H, W)

        logits = self.classifier(dc1)        # (N, Labels_nums, H, W)
        if self.output_logits:
            return logits
        return self.act_output(logits)

# endregion
...全文
317 回复 打赏 收藏 转发到动态 举报
写回复
用AI写文章
回复
切换为时间正序
请发表友善的回复…
发表回复

37,719

社区成员

发帖
与我相关
我的任务
社区描述
JavaScript,VBScript,AngleScript,ActionScript,Shell,Perl,Ruby,Lua,Tcl,Scala,MaxScript 等脚本语言交流。
社区管理员
  • 脚本语言(Perl/Python)社区
  • IT.BOB
加入社区
  • 近7日
  • 近30日
  • 至今

试试用AI创作助手写篇文章吧