12,894
社区成员
发帖
与我相关
我的任务
分享在 MindSpore 框架中,nn.Cell是构建神经网络的基本单元,其作用类似于 PyTorch 中的 nn.Module或 TensorFlow 中的 tf.keras.Model。只要掌握“继承 Cell + 实现 construct”这一核心范式,即使是初学者也能高效搭建从简单到复杂的模型。
本文将通过清晰的步骤与两个实战案例(全连接网络 + 卷积网络),带你 10 分钟内快速入门!
nn.Cell的三大构建原则要使用 nn.Cell定义模型,只需牢记以下三点:
nn.Cell类nn.Cell的子类,这是 MindSpore 的基本规范。__init__中声明网络层construct中定义前向流程construct方法替代其他框架中的 forward或 call,用于描述数据如何流经各层。✨口诀:继承 Cell,init 定结构,construct 走数据。
多层感知机(MLP)是理解神经网络的起点。下面是一个适用于手写数字识别任务的简易 MLP 实现:
import mindspore.nn as nn
class SimpleMLP(nn.Cell):
def __init__(self):
super().__init__()
self.fc1 = nn.Dense(784, 256) # 输入784维 → 隐藏层256维
self.fc2 = nn.Dense(256, 128) # 隐藏层256维 → 128维
self.fc3 = nn.Dense(128, 10) # 输出10类(0-9)
self.relu = nn.ReLU()
def construct(self, x):
x = x.view(-1, 784) # 展平图像:(B, 28, 28) → (B, 784)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
# 创建并查看模型
model = SimpleMLP()
print("MLP 模型结构:", model)
nn.Dense(in, out):实现全连接映射;view(-1, 784):自动推断 batch 维度,将二维图像展平为一维向量;ReLU引入非线性,使网络具备拟合复杂函数的能力。对于图像任务,CNN 更具优势。下面是一个轻量级 LeNet-5 变体,适用于 MNIST 数据集:
import mindspore.nn as nn
class SimpleLeNet(nn.Cell):
def __init__(self):
super().__init__()
# 卷积部分
self.conv1 = nn.Conv2d(1, 6, 5) # 1通道 → 6通道,5×5卷积核
self.conv2 = nn.Conv2d(6, 16, 5) # 6通道 → 16通道
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# 全连接部分
self.fc1 = nn.Dense(16 * 4 * 4, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
def construct(self, x):
x = self.pool(self.relu(self.conv1(x)))
x = self.pool(self.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4) # 展平特征图
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
cnn_model = SimpleLeNet()
print("CNN 模型结构:", cnn_model)
Conv2d参数顺序与 PyTorch 一致,降低迁移成本;MaxPool2d有效压缩空间维度,减少参数量;模型搭建完成后,可用随机张量快速验证是否运行正常:
import numpy as np
from mindspore import Tensor
import mindspore
# 模拟输入:32 张 1 通道 28×28 图像
x = Tensor(np.random.randn(32, 1, 28, 28), mindspore.float32)
# 前向推理(自动调用 construct)
output = cnn_model(x)
print("输入形状:", x.shape) # (32, 1, 28, 28)
print("输出形状:", output.shape) # (32, 10)
若输出形状符合预期且无报错,说明模型结构正确!
__init__中调用 super().__init__(),否则会引发错误。construct,写成 forward或 call会导致模型无法执行。16*4*4),建议中间插入 print(x.shape)调试。