需要用到的模块和函数
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from torch.nn.utils.rnn import pack_padded_sequence
有三句话,我想将这三句话作为一个batch,输入到RNN中进行训练
a = [[1], [2], [3]]
b = [[4], [5]]
c = [[6]]
我定义一个RNN (为了简单,我们假设输入RNN的词向量的维度是 1,输出维度是3, 为了符合我这只菜鸡的直觉,我把batch放在第一维,设置batch_size=True)
rnn = nn.RNN(1, 3, batch_first=True)
如果我只输入一个样本
input = torch.FloatTensor([a]) # 这个地方注意是FloatTensor而不是Tensor,如果采用Tensor会报错,类型不匹配
rnn(input)
但是一次只训练一个样本很慢,不能并行训练,所以我要将a, b, c这三句话封装进一个batch,传给RNN
input = torch.FloatTensor([a, b, c]) # 运行时就会发现,这一行代码报错了
rnn(input)
错误信息:ValueError: expected sequence of length 3 at dim 1 (got 2)
啥意思呢? 意思是:期待一个长度为3的序列,你却给了它长度为2的序列,说的就是 b 这句话!错误原因:一个二维的Tensor,就是一个矩阵,矩阵应该长这样:
correct_matrix=[
[[1], [2], [3]],
[[4], [5], [0]],
[[6], [0], [0]]
]
而不是长这样:
error_matrix=[
[[1], [2], [3]],
[[4], [5]]
[[6]]
]
简单来说,就是矩阵的每一行的长度应该是相同的!!!
而我们刚刚的input的每一行(也就是每一句话)长度是不相同的!!!解决方案: 可以通过填充的方式使所有序列长度都相同!即通过补0的方式将error_matrix -> correct_matrix
实现办法:torch.nn.utils.rnn.pad_sequence
padded_sequence = pad_sequence([torch.FloatTensor(a), torch.FloatTensor(b), torch.FloatTensor(c)], batch_first=True)
print(padded_sequence)
rnn(padded_sequence) # 搞定,现在就是并行计算了
但是现在还是有问题,因为补0了之后,一方面浪费了内存,另一方面使得原本训练一个长度的样本变成了训练三个长度的样本,浪费计算资源,而且可能给模型训练的结果造成影响
那有没有办法呢?
既然Tensor要求每一行的长度都相同,而我们的每个句子长度不相同,那我们就不传一个Tensor作为参数呗
不传Tensor传什么呢?传PackedSequence!
packed_sequence = pack_sequence([torch.FloatTensor(i) for i in [a, b, c]]) # packed_sequence是PackedSequence的实例
print(packed_sequence)
但是pack_sequence这玩意返回的PackedSequence是batch_first=False的,这就让我很不爽,这意味着我的RNN必须是batch_first=False,所以不得不重新定义网络
rnn = nn.RNN(1, 3, batch_first=False)
print(rnn(packed_sequence)) # 搞定!
最后,我们还可以将pad_sequence得到的结果与pack_sequence得到的结果进行相互转换。
packed_padded_sequence = pack_padded_sequence(padded_sequence, [3, 2, 1]) # 但是你传入一个list告诉它每个句子的长度
print(packed_padded_sequence)
padded_packed_sequence = pad_packed_sequence(packed_sequence)
print(padded_packed_sequence)
我的表达能力不好,建议参考我立波师兄(虽然他不认识我哈哈哈)的教程。
(
忆臻:pytorch中如何处理RNN输入变长序列paddingzhuanlan.zhihu.com
)
很少写博客,因为总觉得自己太菜,写的东西拿不出手。
好好努力吧!!!