深度学习之图像识别|“朝闻道”知识分享大赛

hichilde 2024-12-31 17:26:02

这是我参加“朝闻道”知识分享大赛的第6篇文章

基于卷积神经网络的图像识别

收集狗、鼠、兔、鸡四类图像,并微调resnet18或者其他网络进行微调,达到对上述四种类别图像的分类,其他图像归并为第5类
模型代码:


import torch.nn as nn
from torchvision.models import resnet18
class CustomResNet18(nn.Module):
   def __init__(self, num_classes=5):
       super(CustomResNet18, self).__init__()
       self.model = resnet18(pretrained=True)
       self.model.fc = nn.Sequential(
           nn.Dropout(0.2),
           nn.Linear(512, num_classes),
       )
   def forward(self, x):
       return self.model(x)

resnet18网络拓扑图

img

ResNet18作为基础网络进行微调。微调是一种有效的迁移学习手段,可以节省训练时间并提高模型性能。通过清洗数据集,移除噪声和错误标签的样本,模型的性能有了显著提升。

训练集代码:

# 数据预处理和数据增强
transform = transforms.Compose([
   # transforms.RandomResizedCrop(224),  # 随机裁剪并缩放到224x224
   transforms.Resize(256),
   transforms.CenterCrop(224),
   transforms.ToTensor(),
   transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化
])
# 数据加载
train_dataset = datasets.ImageFolder('训练集', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataset = datasets.ImageFolder('测试集', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

num_classes = len(train_dataset.classes)
model = CustomResNet18(num_classes).to(device)
loss_fun = torch.nn.CrossEntropyLoss()  # 交叉熵损失函数,适用于多分类问题
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.05)
# 训练模型
num_epochs = 20
best_loss = float("inf")

for epoch in range(num_epochs):
   model.train()  # 设置模型为训练模式
   running_loss = 0.0
   correct = 0
   total = 0

   for images, labels in train_loader:
       images, labels = images.to(device), labels.to(device)
       outputs = model(images)
       loss = loss_fun(outputs, labels)
       optimizer.zero_grad()  # 清除梯度
       loss.backward()   # 反向传播
       optimizer.step()  # 更新参数
       running_loss += loss.item()
       predicted = torch.max(outputs, 1)[1]
       total += labels.size(0)
       correct += (predicted == labels).sum().item()
       print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")
   model.eval()
   test_loss = 0.0
   test_correct = 0
   test_total = 0

   for images_test, labels_test in test_loader:
       images1, labels1 = images_test.to(device), labels_test.to(device)
       outputs_test = model(images1)
       correct_loss = loss_fun(outputs_test, labels1)
       test_loss += correct_loss.item()
       predicted = torch.max(outputs_test, 1)[1]
       test_total += labels1.size(0)
       test_correct += (predicted == labels1).sum().item()
   if test_loss<best_loss:
       best_loss=test_loss
       print(f" Loss: {best_loss / len(test_loader):.4f}, Accuracy: {100 * test_correct / test_total: .2f}%")
       # 保存模型
       torch.save(model, r" \logs\imagefolder.pth")

...全文
46 回复 打赏 收藏 转发到动态 举报
AI 作业
写回复
用AI写文章
回复
切换为时间正序
请发表友善的回复…
发表回复

1,040

社区成员

发帖
与我相关
我的任务
社区描述
中南民族大学CSDN高校俱乐部聚焦校内IT技术爱好者,通过构建系统化的内容和运营体系,旨在将中南民族大学CSDN社区变成校内最大的技术交流沟通平台。
经验分享 高校 湖北省·武汉市
社区管理员
  • c_university_1575
  • WhiteGlint666
  • wzh_scuec
加入社区
  • 近7日
  • 近30日
  • 至今
社区公告

欢迎各位加入中南民族大学&&CSDN高校俱乐部社区(官方QQ群:908527260),成为CSDN高校俱乐部的成员具体步骤(必填),填写如下表单,表单链接如下:
人才储备数据库及线上礼品发放表单邀请人吴钟昊:https://ddz.red/CSDN
CSDN高校俱乐部是给大家提供技术分享交流的平台,会不定期的给大家分享CSDN方面的相关比赛以及活动或实习报名链接,希望大家一起努力加油!共同建设中南民族大学良好的技术知识分享社区。

注意:

1.社区成员不得在社区发布违反社会主义核心价值观的言论。

2.社区成员不得在社区内谈及政治敏感话题。

3.该社区为知识分享的平台,可以相互探讨、交流学习经验,尽量不在社区谈论其他无关话题。

试试用AI创作助手写篇文章吧