使用LSTM处理新闻标题分类任务(2)之模型训练 | “朝闻道”知识分享大赛

九酿丸子 2023-12-14 20:54:12

这是我参加朝闻道知识分享大赛的第 6 篇文章

目录

权重的初始化

模型训练

定义训练日志类

自定义训练函数

 自定义测试函数

自定义训练主函数

训练FastText模型

 训练TextCNN模型

 训练LSTM模型

 模型的预测

权重的初始化

权重初始化在深度学习中至关重要,它的目标是在训练开始时为神经网络提供合适的初始权重,以促使网络更容易学习。与简单的随机初始化相比,使用一些特殊的初始化方法(如Xavier、Kaiming)可以带来一些优点,这些优点有助于训练深度神经网络。Xavier和Kaiming初始化方法设计得更加智能,考虑了激活函数的非线性性质。它们有助于使每一层的输出的方差保持稳定,防止梯度在反向传播时过度增长或缩小,从而增加了整个模型的稳定性。

# 权重初始化,默认xavier
def init_network(model, method='xavier', exclude='embedding', seed=123):
    for name, w in model.named_parameters():
        # 不对词嵌入的层的参数进行初始化,因为我们使用的是预训练的模型
        if exclude not in name:
            if 'weight' in name:
                if method == 'xavier':
                    nn.init.xavier_normal_(w)
                elif method == 'kaiming':
                    nn.init.kaiming_normal_(w)
                else:
                    nn.init.normal_(w)
            elif 'bias' in name:
                nn.init.constant_(w, 0)
            else:
                pass

模型训练

定义训练日志类

模型在训练的过程中,我们为了监测模型每个batch的训练效果,我们一般会计算相应的指标【例如:损失函数的损失值,精度等等】,然后根据指标的变化情况灵活调整训练策略。一般的做法是直接将指标打印到控制台,但是这样无法保存模型的训练历史,对于后续要对多个模型之间进行对比时,无法提供直接数据支撑。在此,采用自定义类的方式,将模型训练过程输出到日志文件中。

自定义训练函数

def train(model, device, dataloader, criterion, optimizer, epoch):
    model.train()  # 声明模型要开始训练,需要计算参数梯度
    model.to(device)  # 将模型放入对应的计算设备
    for img, label in tqdm(dataloader, f"Epoch {epoch}"):
        pred = model(img.to(device))  # 使用模型去计算预测值
        loss = criterion(pred, label.to(device))  # 计算误差
        
        loss.backward()  # 反向传播计算误差梯度
        optimizer.step()  # 使用梯度更新参数值
        optimizer.zero_grad()  # 将参数的梯度置零

 自定义测试函数

def test(model, device, dataloader, epoch):
    model.eval()  # 设置模型模式为验证模式,禁止去计算参数的梯度
    model.to(device)
    total = 0  # 统计预测正确的数量
    for img, label in tqdm(dataloader, f"Epoch {epoch}"):
        pred = model(img.to(device))
        pred = F.softmax(pred, dim=1).argmax(dim=1)
        total += torch.sum(pred == label.to(device))
    acc = total / len(test_dataloader.dataset)
    print(f'Epoch {epoch}, acc: {acc}')

自定义训练主函数

def train_model(model, epochs=100):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(1, epochs+1):
        train(model, device, train_dataloader, criterion, optimizer, epoch)
        test(model, device, test_dataloader, epoch)

训练FastText模型

 训练TextCNN模型

 训练LSTM模型

 模型的预测

 经过前面的模型训练结果,我们发现针对新闻文本分类任务,LSTM是最佳模型,故使用该模型完成模型验证工作。

 

...全文
38 回复 打赏 收藏 转发到动态 举报
写回复
用AI写文章
回复
切换为时间正序
请发表友善的回复…
发表回复

1,040

社区成员

发帖
与我相关
我的任务
社区描述
中南民族大学CSDN高校俱乐部聚焦校内IT技术爱好者,通过构建系统化的内容和运营体系,旨在将中南民族大学CSDN社区变成校内最大的技术交流沟通平台。
经验分享 高校 湖北省·武汉市
社区管理员
  • c_university_1575
  • WhiteGlint666
  • wzh_scuec
加入社区
  • 近7日
  • 近30日
  • 至今
社区公告

欢迎各位加入中南民族大学&&CSDN高校俱乐部社区(官方QQ群:908527260),成为CSDN高校俱乐部的成员具体步骤(必填),填写如下表单,表单链接如下:
人才储备数据库及线上礼品发放表单邀请人吴钟昊:https://ddz.red/CSDN
CSDN高校俱乐部是给大家提供技术分享交流的平台,会不定期的给大家分享CSDN方面的相关比赛以及活动或实习报名链接,希望大家一起努力加油!共同建设中南民族大学良好的技术知识分享社区。

注意:

1.社区成员不得在社区发布违反社会主义核心价值观的言论。

2.社区成员不得在社区内谈及政治敏感话题。

3.该社区为知识分享的平台,可以相互探讨、交流学习经验,尽量不在社区谈论其他无关话题。

试试用AI创作助手写篇文章吧