PyTorch 訓練 RNN 時,序列長度不固定怎麼辦?

當每個訓練數據為 sequence 的時候,我們第一反應是採用 RNN 以及其各種變體。這時新手們(我也是剛弄明白)往往會遇到這樣的問題:訓練數據 sequence 長度是變化的,難以採用 mini-batch 訓練,這時應該怎麼辦,難道只能一個 sequence 一個 sequence 地訓練嗎?針對這一問題,本文記錄 PyTorch 給出的解決方案。

需要用到的函數如下:

torch.nn.utils.rnn.pad_sequence()
torch.nn.utils.rnn.pack_padded_sequence()
torch.nn.utils.rnn.pad_packed_sequence()

pad_sequence

我們構造如下的訓練數據,其中每條訓練數據長度都不同。

import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils

train_x = [torch.tensor([1, 1, 1, 1, 1, 1, 1]),
torch.tensor([2, 2, 2, 2, 2, 2]),
torch.tensor([3, 3, 3, 3, 3]),
torch.tensor([4, 4, 4, 4]),
torch.tensor([5, 5, 5]),
torch.tensor([6, 6]),
torch.tensor([7])]

x = rnn_utils.pad_sequence(train_x, batch_first=True)

x 將變成:

tensor([[1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 0],
[3, 3, 3, 3, 3, 0, 0],
[4, 4, 4, 4, 0, 0, 0],
[5, 5, 5, 0, 0, 0, 0],
[6, 6, 0, 0, 0, 0, 0],
[7, 0, 0, 0, 0, 0, 0]])

我們發現,這個函數會把長度小於最大長度的 sequences 用 0 填充,並且把 list 中所有的元素拼成一個 tensor。這樣做的主要目的是為了讓 DataLoader 可以返回 batch,因為 batch 是一個高維的 tensor,其中每個元素的數據必須長度相同。

為了證明這一點,我們完整地寫一個數據類,用 dataloader 按 batch 的形式讀取數據,代碼如下:

import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader
import torch.utils.data as data

train_x = [torch.tensor([1, 1, 1, 1, 1, 1, 1]),
torch.tensor([2, 2, 2, 2, 2, 2]),
torch.tensor([3, 3, 3, 3, 3]),
torch.tensor([4, 4, 4, 4]),
torch.tensor([5, 5, 5]),
torch.tensor([6, 6]),
torch.tensor([7])]

x = rnn_utils.pad_sequence(train_x, batch_first=True)

class MyData(data.Dataset):
def __init__(self, data_seq):
self.data_seq = data_seq

def __len__(self):
return len(self.data_seq)

def __getitem__(self, idx):
return self.data_seq[idx]

if __name__==__main__:
data = MyData(train_x)
data_loader = DataLoader(data, batch_size=2, shuffle=True)
batch_x = iter(data_loader).next()
print(END)

我們將會收到如下報錯:

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension
0. Got 3 and 7 in dimension 1 at
/pytorch/aten/src/TH/generic/THTensorMoreMath.cpp:1333

報錯的原因是,不同的數據長度不同,無法組成一個 batch tensor。

DataLoader中有個參數 collate_fn,專門用來把 Dataset 類的返回值拼接成 tensor,我們不設置的時候,會調用 default 的函數,這次我們的訓練數據長度不一,default 函數就 hold 不住了,因此我們要自定義一個 collate_fn,並在 DataLoader 中設置這個參數,再運行就不會報錯了(注意代碼中對 data 先按照長度降序排列了一下,後面會講到原因)。

def collate_fn(data):
data.sort(key=lambda x: len(x), reverse=True)
data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0)
return data

if __name__==__main__:
data = MyData(train_x)
data_loader = DataLoader(data, batch_size=3, shuffle=True,
collate_fn=collate_fn)
batch_x = iter(data_loader).next()
print(END)

運行結果如下:

batch_x
Out[2]:
tensor([[1, 1, 1, 1, 1, 1, 1],
[3, 3, 3, 3, 3, 0, 0],
[6, 6, 0, 0, 0, 0, 0]])

正是我們想要的。

pack_padded_sequence

我們通過 pad_sequence 得到了 padded_sequence,那麼直接扔進 RNN 訓練不就完了嗎?為啥還要用 pack_padded_sequence?這個 pack 又是什麼意思呢?

我們回憶一下 RNN 是如何訓練的,首先考慮單個訓練數據,也就是batch_size=1 的情況:每次網路吃進一個 time step 的數據+該數據對應的 hidden state,然後輸出,再繼續吃進去第二個 time step 的數據 + hidden state,再輸出,以此類推;如果換成 mini-batch 的訓練模式則是:網路每次吃進去一組同樣 time step 的數據,也就是mini-batch 中所有 sequence 中相同下標的數據,加上它們對應的 hidden state,獲得一個 mini-batch 的輸出,然後再移到下一個 time step,再讀入 mini-batch 中所有該 time step 的數據,再輸出……

因此,以上面 pad_sequence的輸出為例,數據將會按照如圖所示的方式讀取:

網路讀取數據的順序是:[1, 3, 6],[1, 3, 6],[1, 3, 0],[1, 3, 0],[1, 3, 0],[1, 0, 0],[1, 0, 0]。而該 mini-batch 中的 0 是沒有意義的 padding,只是為了用來讓它和最長的數據對齊而已,顯然這種做法浪費了大量計算資源。因此,我們將用到 pack_padded_sequence 。即,不光要 padd,還要 pack。

pack_padded_sequence 有三個參數:input, lengths, batch_firstinput 是上一步加過 padding 的數據,lengths 是各個 sequence 的實際長度,batch_first是數據各個 dimension 按照 [batch_size, sequence_length, data_dim]順序排列。

上面例子中,batch_x 為:

batch_x
Out[2]:
tensor([[1, 1, 1, 1, 1, 1, 1],
[3, 3, 3, 3, 3, 0, 0],
[6, 6, 0, 0, 0, 0, 0]])

因此應該設置 lengths=[7, 5, 2]

rnn_utils.pack_padded_sequence(batch_x, [7,5,2], batch_first=True)
Out[3]: PackedSequence(
data=tensor([1., 3., 6., 1., 3., 6., 1., 3., 1., 3., 1., 3., 1., 1.]),
batch_sizes=tensor([3, 3, 2, 2, 2, 1, 1]))

我們發現,它的輸出有兩部分,分別是 databatch_sizes,第一部分為原來的數據按照 time step 重新排列,而 padding 的部分,直接空過了。batch_sizes則是每次實際讀入的數據量,也就是說,RNN 把一個 mini-batch sequence 又重新劃分為了很多小的 batch,每個小 batch 為所有 sequence 在當前 time step 對應的值,如果某 sequence 在當前 time step 已經沒有值了,那麼,就不再讀入填充的 0,而是降低 batch_sizebatch_size相當於是對訓練數據的重新劃分。這也是為什麼前面在 collate_fn中我們要對 mini-batch 中的 sequence 按照長度降序排列,是為了方便我們取每個 time step 的batch,防止中間夾雜著 padding。 而每個 mini-batch 中 sequence 的真實 length 又如何獲得呢?這就要重新修改 collate_fn了,我們在其中加入data_length=[len(sq) for sq in data] 修改後的代碼如下:

def collate_fn(data):
data.sort(key=lambda x: len(x), reverse=True)
data_length = [len(sq) for sq in data]
data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0)
return data, data_length

if __name__==__main__:
data = MyData(train_x)
data_loader = DataLoader(data, batch_size=3, shuffle=True,
collate_fn=collate_fn)
batch_x, batch_x_len = iter(data_loader).next()
batch_x_pack = rnn_utils.pack_padded_sequence(batch_x,
batch_x_len, batch_first=True)

pad_packed_sequence

一看名字就知道,這個函數和前面的函數是一對。有點像西遊記里的奔波兒灞和灞波兒奔。

上文的例子中,我們為了直觀,沒有考慮到 RNN 對數據維度的要求,因此在這裡我們要重新改寫 collate_fn使其返回的數據符合 [batch, sequence_len, input_size]的格式(我們設置網路為 batch_first的模式,更符合習慣)。在例子中,每個 sequence 的元素維度都是1,因此只需要在 tensor 末尾加一維就好了,即對返回的數據 unsqueeze(-1) 一下(也可以在資料庫的類中,對 _getitem_的返回值 unsqueeze)。

def collate_fn(data):
data.sort(key=lambda x: len(x), reverse=True)
data_length = [len(sq) for sq in data]
data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0)
return data.unsqueeze(-1), data_length

修改後,batch_xbatch_x_pack分別為:

batch_x
Out[2]:
tensor([[[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]],
[[3.],
[3.],
[3.],
[3.],
[3.],
[0.],
[0.]],
[[6.],
[6.],
[0.],
[0.],
[0.],
[0.],
[0.]]])
batch_x_pack
Out[3]:
PackedSequence(data=tensor([
[1.],
[3.],
[6.],
[1.],
[3.],
[6.],
[1.],
[3.],
[1.],
[3.],
[1.],
[3.],
[1.],
[1.]]), batch_sizes=tensor([3, 3, 2, 2, 2, 1, 1]))

符合我們的預期。

接下來,我們隨機初始化 hidden state 和 cell state (維度為:num_layers * num_directions, batch, hidden_size), 和batch_x_pack一起送入LSTM中。

if __name__==__main__:
data = MyData(train_x)
data_loader = DataLoader(data, batch_size=3, shuffle=True,
collate_fn=collate_fn)
batch_x, batch_x_len = iter(data_loader).next()
batch_x_pack = rnn_utils.pack_padded_sequence(batch_x,
batch_x_len, batch_first=True)

net = nn.LSTM(1, 10, 2, batch_first=True)
h0 = torch.rand(2, 3, 10)
c0 = torch.rand(2, 3, 10)
out, (h1, c1) = net(batch_x_pack, (h0, c0))
print(END)

其中 LSTM 輸入為 1 維,hidden size 為 10 ,總共兩層。經過一次前向傳播,我們得到 outoutbatch_x_pack一樣,分為兩部分: databatch_sizes。觀察一下它這兩部分:

out.data.shape
Out[5]: torch.Size([14, 10])
batch_x_pack.data.shape
Out[6]: torch.Size([14, 1])
out.batch_sizes
Out[7]: tensor([3, 3, 2, 2, 2, 1, 1])
batch_x_pack.batch_sizes
Out[8]: tensor([3, 3, 2, 2, 2, 1, 1])

輸入的 mini-batch 中,統計所有 time step 共有 14 個非零的數據,而 LSTM 的 hidden unit 有10維,故 out.data.shape torch.Size([14, 10])。而out.batch_sizes則和 batch_x_pack.batch_sizes相同,都是 tensor([3, 3, 2, 2, 2, 1, 1])

pad_packed_sequence 執行的是 pack_padded_sequence 的逆操作,執行下面的代碼,觀察輸出。

out, (h1, c1) = net(batch_x_pack, (h0, c0))
out_pad, out_len = rnn_utils.pad_packed_sequence(out, batch_first=True)
out_pad.shape
Out[2]: torch.Size([3, 7, 10])
out.data.shape
Out[3]: torch.Size([14, 10])
out_len
Out[4]: tensor([7, 5, 2])

我們發現,經過這樣的操作後out_pad 形狀變成了[3, 7, 10],彷彿我們直接輸入加了padding 的 mini-batch ,mini-batch 中有 3 個 sequence,每個 sequence 有 7 個 time step,每個 time step 數據從輸入的 1 維,映射成 LSTM 的 10 維,此外它還輸出了 out_len,為 [7, 5, 2],即每個 sequence 的真實長度。 為了放心,我們再看一下out_pad[1]是什麼:

out_pad[1].shape
Out[11]: torch.Size([7, 10])
out_pad[1]
Out[12]:
tensor([[ 0.0027, -0.0135, 0.1366, -0.0420, 0.3269, 0.0726, -0.0872, -0.0409,
0.1267, 0.2546],
[-0.0365, -0.0574, 0.0436, -0.0346, 0.2652, -0.0088, -0.0881, -0.0700,
0.1753, 0.2102],
[-0.0557, -0.0865, -0.0048, -0.0317, 0.1738, -0.0366, -0.0858, -0.0805,
0.1873, 0.1898],
[-0.0684, -0.1015, -0.0261, -0.0301, 0.1045, -0.0485, -0.0828, -0.0843,
0.1961, 0.1764],
[-0.0769, -0.1085, -0.0354, -0.0294, 0.0567, -0.0542, -0.0807, -0.0857,
0.2019, 0.1671],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000]], grad_fn=<SelectBackward>)

下標為 1 的 sequence 真實長度是 5 ,第 6 、7 個 time step 是填充的 0,因此它對應的輸出第 6 、7 行都是 0,符合我們的預期。

總結

torch.nn.utils.rnn.pad_sequence()
torch.nn.utils.rnn.pack_padded_sequence()
torch.nn.utils.rnn.pad_packed_sequence()

上面三個函數相互配合,可以在 sequence 長度變化時,成批讀入數據,訓練 RNN。第一個函數用於給 mini-batch 中的數據加 padding,讓 mini-batch 中所有 sequence 的長度等於該 mini-batch 中最長的那個 sequence 的長度。

第二、三個函數,用於提高效率,避免 LSTM 前向傳播時,把加入在訓練數據中的 padding 考慮進去。因此第二、三個函數理論上可以不用,但為了提高效率最好還是用。

除此之外,本文還介紹了 DataLoadercollate_fn參數,用於把 Dataset類的 __getitem__ 方法的返回的 batchsize 個值拼接成一個 tensor。

全部代碼如下:

import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader
import torch.utils.data as data

train_x = [torch.Tensor([1, 1, 1, 1, 1, 1, 1]),
torch.Tensor([2, 2, 2, 2, 2, 2]),
torch.Tensor([3, 3, 3, 3, 3]),
torch.Tensor([4, 4, 4, 4]),
torch.Tensor([5, 5, 5]),
torch.Tensor([6, 6]),
torch.Tensor([7])
]

x = rnn_utils.pad_sequence(train_x, batch_first=True)

class MyData(data.Dataset):
def __init__(self, data_seq):
self.data_seq = data_seq

def __len__(self):
return len(self.data_seq)

def __getitem__(self, idx):
return self.data_seq[idx]

def collate_fn(data):
data.sort(key=lambda x: len(x), reverse=True)
data_length = [len(sq) for sq in data]
data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0)
return data.unsqueeze(-1), data_length

if __name__==__main__:
data = MyData(train_x)
data_loader = DataLoader(data, batch_size=3, shuffle=True,
collate_fn=collate_fn)
batch_x, batch_x_len = iter(data_loader).next()
batch_x_pack = rnn_utils.pack_padded_sequence(batch_x,
batch_x_len, batch_first=True)

net = nn.LSTM(1, 10, 2, batch_first=True)
h0 = torch.rand(2, 3, 10)
c0 = torch.rand(2, 3, 10)
out, (h1, c1) = net(batch_x_pack, (h0, c0))
out_pad, out_len = rnn_utils.pad_packed_sequence(out, batch_first=True)
print(END)

推薦閱讀:

TAG:PyTorch | RNN | 深度學習(DeepLearning) |