How can I compute the Jacobian matrix of BertForMaskedLM with functorch.jacrev?

Why____not 2022-06-21 01:10:03

I tried the following approach:

import numpy as np
from transformers import BertTokenizer,BertForMaskedLM
import torch
import torch.nn as nn
from functorch import make_functional, make_functional_with_buffers, vmap, vjp, jvp, jacrev
device = 'cuda:2'
torch.cuda.empty_cache()


model_name = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertForMaskedLM.from_pretrained(model_name)

net = bert_model.to(device)
fnet, params, buffers = make_functional_with_buffers(net)

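def fnet_single(params, x, y):
    # x: token ids, y: segment ids, both 1-D with shape [seq_len].
    # BERT expects [batch, seq_len], so add a single batch dim, and pass y by
    # keyword: positionally it would be taken as attention_mask, not token_type_ids.
    result = fnet(params, buffers, input_ids=x.unsqueeze(0), token_type_ids=y.unsqueeze(0))['logits']
    return result.squeeze(0)  # [seq_len, vocab_size]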

# "E. coli is one of the most predominant bacteria in the intestines of humans and many animals"
text = u'大肠杆菌是人和许多动物肠道中最主要的一种细菌'
inputs = tokenizer.encode_plus(text)

segment_ids = inputs['token_type_ids']
token_ids = inputs['input_ids']
length = len(token_ids) - 2


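# 2*length-1 copies of the sentence. This is the line that raises the RuntimeError
# shown below: token ids are int64, and only float/complex tensors can require gradients.
batch_token_ids = torch.tensor([token_ids] * (2 * length - 1),requires_grad=True).to(device)
batch_segment_ids = torch.zeros_like(batch_token_ids).to(device)

# 103 is the [MASK] token id: row 2i masks position i+1 alone,
# row 2i-1 (for i > 0) masks positions i and i+1 together.
for i in range(length):
    if i > 0:
        batch_token_ids[2 * i - 1, i] = 103
        batch_token_ids[2 * i - 1, i + 1] = 103
    batch_token_ids[2 * i, i + 1] = 103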
threshold = 100
word_token_ids = [[token_ids[1]]]
for i in range(1, length):
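    x,y = batch_token_ids[2 * i],batch_segment_ids[2*i]
    # argnums=1: Jacobian of the logits w.r.t. x, the token ids of row 2i
    jacobian1 = jacrev(fnet_single,argnums=1)(params,x,y)
    x,y = batch_token_ids[2 * i - 1],batch_segment_ids[2*i-1]
    # same for row 2i-1, where positions i and i+1 are both masked
    jacobian2 = jacrev(fnet_single,argnums=1)(params,x,y)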
    print(jacobian1,end='-----------------jacobian1-----------------\n')  
    print(jacobian2,end='-----------------jacobian2-----------------\n') 
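The row-construction loop is easy to misread, so here is a pure-Python sketch (no model or GPU needed) that reproduces it on a hypothetical 3-token sentence. The function name build_masked_rows and the ids 11, 12, 13 are made up for illustration; 101 and 102 stand for [CLS] and [SEP], and 103 is the value the original code writes into masked positions.

```python
MASK_ID = 103  # the value the original loop writes into masked positions


def build_masked_rows(token_ids):
    """Hypothetical pure-Python re-implementation of the masking loop."""
    length = len(token_ids) - 2  # tokens between [CLS] and [SEP]
    rows = [list(token_ids) for _ in range(2 * length - 1)]
    for i in range(length):
        if i > 0:
            # odd row 2i-1: mask the adjacent pair (i, i+1)
            rows[2 * i - 1][i] = MASK_ID
            rows[2 * i - 1][i + 1] = MASK_ID
        # even row 2i: mask only position i+1
        rows[2 * i][i + 1] = MASK_ID
    return rows


for row in build_masked_rows([101, 11, 12, 13, 102]):
    print(row)
# [101, 103, 12, 13, 102]
# [101, 103, 103, 13, 102]
# [101, 11, 103, 13, 102]
# [101, 11, 103, 103, 102]
# [101, 11, 12, 103, 102]
```

In the Jacobian loop, row 2i (only position i+1 masked) is then paired with row 2i-1 (positions i and i+1 masked).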

But it raises the following error:

Traceback (most recent call last):
  File "study_jacrev.py", line 49, in <module>
    batch_token_ids = torch.tensor([token_ids] * (2 * length - 1),requires_grad=True).to(device)
RuntimeError: Only Tensors of floating point and complex dtype can require gradients
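The same error can be reproduced without BERT: PyTorch only allows floating-point and complex tensors to require gradients, so an integer id tensor is rejected as soon as it is created. The token ids below are hypothetical.

```python
import torch

token_ids = [101, 5, 102]  # hypothetical ids; Python ints give an int64 tensor

try:
    torch.tensor(token_ids, requires_grad=True)
    int_error = None
except RuntimeError as e:
    int_error = e  # same message as in the traceback above

# Casting to float makes requires_grad legal, but a float "id" cannot be used
# to index an embedding table, so this alone does not yield a useful gradient.
float_ids = torch.tensor(token_ids, dtype=torch.float32, requires_grad=True)

print(type(int_error).__name__)  # RuntimeError
print(float_ids.requires_grad)   # True
```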

I suspect this happens because BertForMaskedLM takes token ids as input, and these are all integers. How, then, should I compute the Jacobian with respect to the input?
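One common workaround, which the post itself does not confirm, is to take the Jacobian with respect to the input embeddings rather than the token ids: look up the float embedding vectors first, then run the rest of the model on them. In Hugging Face transformers, BertForMaskedLM's forward accepts inputs_embeds in place of input_ids for this purpose. The sketch below uses a tiny hypothetical ToyMLM (embedding, tanh, linear head) with the same input_ids/inputs_embeds switch, so it runs without downloading bert-base-chinese, and it uses torch.func.jacrev, the PyTorch 2.x successor of functorch.jacrev.

```python
import torch
import torch.nn as nn
from torch.func import jacrev

torch.manual_seed(0)
VOCAB, HIDDEN = 10, 4  # hypothetical toy sizes, far smaller than BERT's


class ToyMLM(nn.Module):
    """Hypothetical stand-in for BertForMaskedLM: embedding -> tanh -> linear."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, HIDDEN)
        self.head = nn.Linear(HIDDEN, VOCAB)

    def forward(self, input_ids=None, inputs_embeds=None):
        # Integer ids are looked up here; float embeddings are used directly.
        if inputs_embeds is None:
            inputs_embeds = self.embed(input_ids)
        return self.head(torch.tanh(inputs_embeds))


model = ToyMLM()
token_ids = torch.tensor([1, 5, 7])       # integer ids: not differentiable
embeds = model.embed(token_ids).detach()  # float [seq_len, hidden]: differentiable


def logits_from_embeds(e):
    # Add a batch dim for the model, then drop it from the [1, seq_len, vocab] output.
    return model(inputs_embeds=e.unsqueeze(0)).squeeze(0)


# jac[i, v, j, h] = d logits[i, v] / d embeds[j, h]
jac = jacrev(logits_from_embeds)(embeds)
print(jac.shape)  # torch.Size([3, 10, 3, 4])
```

With the real model, the same idea would roughly mean computing emb = net.get_input_embeddings()(x.unsqueeze(0)) and differentiating fnet(params, buffers, inputs_embeds=emb, token_type_ids=...) with respect to emb; BERT then adds position and segment embeddings itself. Unlike the position-wise toy model, attention makes every position depend on every other one, and the full Jacobian has seq_len × vocab_size × seq_len × hidden_size entries, which for an MLM head over the whole vocabulary is usually too large to materialize. Restricting the output, for example to the logits at a single masked position, keeps it manageable.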
