159
社区成员
mindspore.nn
类是构建所有网络的基类,也是网络的基本单元。当用户需要自定义网络时,可以继承nn.Cell
类,并重写__init__
方法和construct
方法。__init__
包含所有网络层的定义,construct
中包含数据(Tensor)的变换过程(即计算图的构造过程)。
# Define model class Network(nn.Cell): def __init__(self): super().__init__() self.flatten = nn.Flatten() self.dense_relu_sequential = nn.SequentialCell( nn.Dense(28*28, 512), nn.ReLU(), nn.Dense(512, 512), nn.ReLU(), nn.Dense(512, 10) ) def construct(self, x): x = self.flatten(x) logits = self.dense_relu_sequential(x) return logits model = Network() print(model)
Network< (flatten): Flatten<> (dense_relu_sequential): SequentialCell< (0): Dense<input_channels=784, output_channels=512, has_bias=True> (1): ReLU<> (2): Dense<input_channels=512, output_channels=512, has_bias=True> (3): ReLU<> (4): Dense<input_channels=512, output_channels=10, has_bias=True> > >
更多细节详见网络构建。