提交 f4eb93f7 编写于 作者: Wanchao Liang's avatar Wanchao Liang 提交者: Facebook Github Bot

Support pack_padded_sequence and pad_packed_sequence

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23249

Test Plan: Imported from OSS

Differential Revision: D16466587

Pulled By: wanchaol

fbshipit-source-id: a721da01b2da0ef90cac80b77f1285102e3b1118
上级 c384fbf4
......@@ -7809,6 +7809,28 @@ a")
f = io.BytesIO()
torch.onnx._export(m, (x, seq_lens), f, verbose=False)
def test_script_pack_padded_sequence(self):
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
def pack_padded_pad_packed_script(x, seq_lens):
x = pack_padded_sequence(x, seq_lens)
x, lengths = pad_packed_sequence(x)
return x, lengths
T, B, C = 3, 5, 7
x = torch.ones((T, B, C))
seq_lens = torch.tensor([3, 3, 2, 2, 1])
# set padding value so we can test equivalence
for b in range(B):
if seq_lens[b] < T:
x[seq_lens[b]:, b, :] = 0
eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens)
scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens)
self.assertEqual(eager_seq, script_seq)
self.assertEqual(eager_lengths, script_lengths)
def test_script_get_tracing_state(self):
def test_if_tracing(x):
if torch._C._get_tracing_state():
......
......@@ -2,11 +2,17 @@ from collections import namedtuple
import warnings
import torch
from .. import _VF
from ..._jit_internal import Optional
PackedSequence_ = namedtuple('PackedSequence',
['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices'])
# type annotation for PackedSequence_ to make it compatible with TorchScript
PackedSequence_.__annotations__ = {'data': torch.Tensor, 'batch_sizes': torch.Tensor,
'sorted_indices': Optional[torch.Tensor],
'unsorted_indices': Optional[torch.Tensor]}
def bind(optional, fn):
if optional is None:
......@@ -219,6 +225,7 @@ def invert_permutation(permutation):
def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True):
# type: (Tensor, Tensor, bool, bool) -> PackedSequence
r"""Packs a Tensor containing padded sequences of variable length.
:attr:`input` can be of size ``T x B x *`` where `T` is the length of the
......@@ -254,7 +261,7 @@ def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)
'sequence lengths. The tracer cannot track the data flow of Python '
'values, and it will treat them as constants, likely rendering '
'the trace incorrect for any other combination of lengths.',
category=torch.jit.TracerWarning, stacklevel=2)
stacklevel=2)
lengths = torch.as_tensor(lengths, dtype=torch.int64)
if enforce_sorted:
sorted_indices = None
......@@ -265,11 +272,12 @@ def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)
input = input.index_select(batch_dim, sorted_indices)
data, batch_sizes = \
torch._C._VariableFunctions._pack_padded_sequence(input, lengths, batch_first)
return PackedSequence(data, batch_sizes, sorted_indices)
_VF._pack_padded_sequence(input, lengths, batch_first)
return PackedSequence(data, batch_sizes, sorted_indices, None)
def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):
# type: (PackedSequence, bool, float, Optional[int]) -> Tuple[Tensor, Tensor]
r"""Pads a packed batch of variable length sequences.
It is an inverse operation to :func:`pack_padded_sequence`.
......@@ -310,12 +318,12 @@ def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_le
"total_length={} and max sequence length being {}"
.format(total_length, max_seq_length))
max_seq_length = total_length
padded_output, lengths = torch._C._VariableFunctions._pad_packed_sequence(
padded_output, lengths = _VF._pad_packed_sequence(
sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length)
if sequence.unsorted_indices is not None:
unsorted_indices = sequence.unsorted_indices
if unsorted_indices is not None:
batch_dim = 0 if batch_first else 1
return padded_output.index_select(batch_dim, sequence.unsorted_indices), \
lengths[sequence.unsorted_indices]
return padded_output.index_select(batch_dim, unsorted_indices), lengths[unsorted_indices]
return padded_output, lengths
......
Markdown 格式
0% or
您添加了 0 到此讨论。请谨慎行事。
先完成此消息的编辑!
想要评论请 注册