4,446
社区成员
发帖
与我相关
我的任务
分享
class MNISTDataset(Dataset):
#初始化数据集
def __init__(self,root,is_train=True):
self.dataset = []
sub_dir = "Train" if is_train else "Test"
for tag in os.listdir((f"{root}/{sub_dir}")):
img_dir = f"{root}/{sub_dir}/{tag}"
for img_filename in os.listdir(img_dir):
img_path = f"{img_dir}/{img_filename}"
self.dataset.append(img_path,tag)
#数据集有多少数据
def __len__(self):
return len(self.dataset)
#每条数据的处理方式
def __getitem__(self, index):
data = self.dataset[index]
#处理图像
img_data = cv2.imread(data[0],cv2.IMREAD_GRAYSCALE)
#数据变为一维数据(数据展平),配套后面使用全连接使用
img_data = img_data.reshape(-1)
#归一化0~1
img_data = img_data/255
#对标签做one-hot处理
tag = torch.zero(10)
tag[int(data[1])]= 1
#pytorch默认数据是 float32
return np.float32(img_data),np.float32(tag)
重载我上面的函数。
另外,我自己训练MTCNN是用Cebela 。你网上搜下这个数据集怎么训练。 MTCNN,照理说O网络训练要比较久。我自己用Cebela ,损失基于:人脸位置,类别+五个关键点。实际是P网络要跑 20个迭代,R网络五六个,O网络不到一个迭代就过拟合。