Commit c779eff5 authored by Wanchao Liang, committed by Facebook Github Bot

support torch.as_tensor in script

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23247

Test Plan: Imported from OSS

Differential Revision: D16466590

Pulled By: wanchaol

fbshipit-source-id: cf52721eacd177d9040564790382db13a9fcc2fe
Parent 3a568c9a
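What this change enables, as a minimal sketch (the function name and values below are illustrative, not taken from the PR): after this commit, torch.as_tensor can be called inside a scripted function with the same data/dtype/device arguments it accepts in eager mode.

    import torch

    @torch.jit.script
    def rescale(x):
        # type: (Tensor) -> Tensor
        # before this commit, torch.as_tensor was not recognized by the script compiler
        y = torch.as_tensor(x, dtype=torch.float)
        return y * 2.0

    print(rescale(torch.arange(4)))  # tensor([0., 2., 4., 6.])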
@@ -199,6 +199,7 @@ _(aten, arange) \
 _(aten, argmax) \
 _(aten, argmin) \
 _(aten, as_strided) \
+_(aten, as_tensor) \
 _(aten, asin) \
 _(aten, atan) \
 _(aten, atan2) \
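The new interned string is what lets C++ passes refer to the operator as aten::as_tensor (for example, the case aten::as_tensor: added to the shape propagator below). A quick way to see the node from Python, as a sketch with illustrative names:

    import torch

    @torch.jit.script
    def to_float(x):
        # type: (Tensor) -> Tensor
        return torch.as_tensor(x, dtype=torch.float)

    print(to_float.graph)  # the printed IR contains an aten::as_tensor node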
@@ -5681,58 +5681,49 @@ a")
         foo()

     @suppress_warnings
-    def test_torch_tensor_empty_list(self):
-        def func():
-            return torch.tensor(torch.jit.annotate(List[int], []))
-        cu = torch.jit.script(func)
-        t1 = cu()
-        t2 = func()
-        # torchscript returns int tensor, python returns float tensor
-        self.assertNotEqual(t1.dtype, t2.dtype)
-
-        def func():
-            li = torch.jit.annotate(List[int], [])
-            return torch.tensor([li, li])
-
-        self.checkScript(func, ())
-
-        def func():
-            li = torch.jit.annotate(List[int], [])
-            return torch.tensor([[[li]]])
-
-        self.checkScript(func, ())
-
-    def test_tensor_shape_prop(self):
-        def func1():
-            return torch.tensor([1])
-
-        def func2():
-            return torch.tensor([False])
-
-        def func3():
-            return torch.tensor([2.5])
-
-        def func4():
-            return torch.tensor(0.5)
-
-        def func5():
-            return torch.tensor(1)
-
-        def func6():
-            return torch.tensor(False)
-
-        def func7():
-            return torch.tensor([[1]])
-
-        list_input = [func1, func2, func3, func4, func5, func6, func7]
+    def test_torch_tensor_as_tensor_empty_list(self):
+        tensor_template = dedent('''
+        def func():
+            empty_list = torch.jit.annotate(List[int], [])
+            ten1 = torch.{tensor_op}({input})
+            return ten1
+        ''')
+        ops = ['tensor', 'as_tensor']
+        inputs = ['empty_list', '[empty_list, empty_list]', '[[[empty_list]]]']
+
+        for op in ops:
+            for inp in inputs:
+                code = tensor_template.format(tensor_op=op, input=inp)
+                scope = {}
+                exec(code, globals(), scope)
+                cu = torch.jit.CompilationUnit(code)
+                t1 = cu.func()
+                t2 = scope['func']()
+                if inp == 'empty_list':
+                    # torchscript returns int tensor, python returns float tensor
+                    self.assertNotEqual(t1.dtype, t2.dtype)
+
+                self.assertEqual(t1, t2)
+                self.assertEqual(t1.device, t2.device)
+
+    def test_tensor_as_tensor_shape_prop(self):
+        tensor_template = dedent('''
+        def func():
+            return torch.{tensor_op}({input})
+        ''')
+        ops = ['tensor', 'as_tensor']
+        inputs = ['[1]', '[False]', '[2.5]', '0.5', '1', 'False', '[[1]]']
         expected_shape = ["Long(*)", ("Bool(*)"), "Double(*)", "Double()", "Long()", "Bool()", "Long(*, *)"]

-        for fn, expect in zip(list_input, expected_shape):
-            self.checkScript(fn, ())
-            g = torch.jit.script(fn)
-            torch._C._jit_pass_complete_shape_analysis(g.graph, (), False)
-            FileCheck().check(expect).check("aten::tensor").run(g.graph)
+        for op in ops:
+            for inp, expect in zip(inputs, expected_shape):
+                code = tensor_template.format(tensor_op=op, input=inp)
+                scope = {}
+                exec(code, globals(), scope)
+                self.checkScript(code, ())
+                cu = torch.jit.CompilationUnit(code)
+                torch._C._jit_pass_complete_shape_analysis(cu.func.graph, (), False)
+                FileCheck().check(expect).check("aten::{tensor_op}".format(tensor_op=op)).run(cu.func.graph)

         @torch.jit.script
         def test_dtype(inp_dtype):
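The expected_shape strings in the test above encode the dtype/dimension rules the analyzer should infer for both ops; they follow TorchScript's literal typing (a bare Python float is a double), which is why '0.5' maps to Double() rather than Float(). A small standalone sketch of the same behavior, with illustrative function names:

    import torch

    @torch.jit.script
    def int_list():
        return torch.as_tensor([1])   # annotated Long(*) by shape analysis: 1-D, int64

    @torch.jit.script
    def float_scalar():
        return torch.as_tensor(0.5)   # annotated Double(): 0-dim, float64

    print(int_list().dtype, int_list().shape)           # torch.int64 torch.Size([1])
    print(float_scalar().dtype, float_scalar().dim())   # torch.float64 0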
@@ -5740,11 +5731,19 @@ a")
             a = torch.tensor(1.0, dtype=torch.float, requires_grad=True)
             return a, torch.tensor(1.0, dtype=inp_dtype)  # noqa T484

         test_dtype(5)
+        g = test_dtype.graph_for(5)
+        # first should have type set second should not
+        FileCheck().check("Float() = aten::tensor").check("Tensor = aten::tensor").run(g)
+
+        @torch.jit.script
+        def test_as_tensor_tensor_input(input):
+            a = torch.as_tensor(input, dtype=input.dtype)
+            return a, torch.as_tensor(input, dtype=torch.float)
+
+        g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4))
+        FileCheck().check("Tensor = aten::as_tensor").check("Float(*, *) = aten::as_tensor").run(g)

     def test_tensor_requires_grad(self):
         @torch.jit.script
         def test(b):
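The FileCheck patterns in the hunk above hinge on whether dtype is a compile-time constant: torch.as_tensor(input, dtype=torch.float) lets the analyzer annotate the output as Float(*, *), while dtype=input.dtype is only known at runtime, so that output stays a plain Tensor in the graph. The runtime effect, as a hedged sketch with illustrative names and values:

    import torch
    from typing import Tuple

    @torch.jit.script
    def convert(x):
        # type: (Tensor) -> Tuple[Tensor, Tensor]
        # dtype=x.dtype is dynamic; dtype=torch.float is a constant the compiler can see
        return torch.as_tensor(x, dtype=x.dtype), torch.as_tensor(x, dtype=torch.float)

    a, b = convert(torch.ones(3, 4, dtype=torch.double))
    print(a.dtype, b.dtype)  # torch.float64 torch.float32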
@@ -5775,11 +5774,12 @@ a")
         b_script.backward()
         self.assertEqual(a.grad, a_script.grad)

-    def test_torch_tensor(self):
-        template = dedent('''
+    def test_torch_tensor_as_tensor(self):
+        tensor_template = dedent('''
         def func():
             li = {list_create}
-            return torch.tensor(li {options})
+            ten1 = torch.{tensor_op}(li {options})
+            return ten1
         ''')

         lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]",
@@ -5789,28 +5789,37 @@ a")
                   ", dtype=torch.uint8", ", dtype=torch.int8", ", dtype=torch.short",
                   ", dtype=torch.int", ", dtype=torch.long"]

+        ops = ['tensor', 'as_tensor']
         devices = ['', ", device='cpu'"]
         if RUN_CUDA:
             devices.append(", device='cuda'")

         option_pairs = [dtype + device for dtype in dtypes for device in devices]
-        for li in lists:
-            for option in option_pairs:
-                # tensor from empty list is type float in python and annotated type in torchscript
-                if "annotate" in li and "dtype" not in option:
-                    continue
-                code = template.format(list_create=li, options=option)
-                scope = {}
-                exec(code, globals(), scope)
-                cu = torch.jit.CompilationUnit(code)
-                t1 = cu.func()
-                t2 = scope['func']()
-                if t1.dtype == torch.float16:  # equality NYI for half tensor
-                    self.assertTrue(str(t1) == str(t2))
-                else:
-                    self.assertEqual(t1, t2)
-                self.assertEqual(t1.dtype, t2.dtype)
-                self.assertEqual(t1.device, t2.device)
+        for op in ops:
+            for li in lists:
+                for option in option_pairs:
+                    # tensor from empty list is type float in python and annotated type in torchscript
+                    if "annotate" in li and "dtype" not in option:
+                        continue
+                    code = tensor_template.format(list_create=li, tensor_op=op, options=option)
+                    scope = {}
+                    exec(code, globals(), scope)
+                    cu = torch.jit.CompilationUnit(code)
+                    t1 = cu.func()
+                    t2 = scope['func']()
+                    if t1.dtype == torch.float16:  # equality NYI for half tensor
+                        self.assertTrue(str(t1) == str(t2))
+                    else:
+                        self.assertEqual(t1, t2)
+                    self.assertEqual(t1.dtype, t2.dtype)
+                    self.assertEqual(t1.device, t2.device)
+
+        def test_as_tensor_tensor_input(input):
+            # type: (Tensor) -> Tuple[Tensor, Tensor]
+            return torch.as_tensor(input, dtype=torch.float), torch.as_tensor(input, dtype=torch.int32)
+
+        inp = torch.randn(3, 4)
+        self.checkScript(test_as_tensor_tensor_input, (inp,))

     # adapted from test in test_torch
     def test_tensor_to(self):
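For context on why the tests exercise torch.tensor and torch.as_tensor side by side: in eager mode the two differ mainly in copying behavior. torch.tensor always copies its input, while torch.as_tensor reuses the input's storage when the dtype and device already match. A small eager-mode sketch of that aliasing (values are illustrative):

    import torch

    t = torch.ones(3)
    a = torch.as_tensor(t)   # dtype/device already match: no copy, a aliases t
    b = t.clone()            # explicit copy, for comparison
    a[0] = 5.0
    print(t[0].item(), b[0].item())  # 5.0 1.0 -- a write through a is visible in t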
@@ -527,7 +527,14 @@ class ShapePropagator {
         }
         return;
       }
-      case aten::tensor: {
+      case aten::tensor:
+      case aten::as_tensor: {
+        // as_tensor has an overloaded schema and can take either a tensor or
+        // a list as its first input; if the input is a tensor, we delegate
+        // shape propagation to PropagateTensorShapeOnNode
+        if (node->inputs().at(0)->type()->isSubtypeOf(TensorType::get())) {
+          break;
+        }
         return propagateTorchTensorShape(node);
       }
       case prim::TupleConstruct: {
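The switch above treats the two overloads differently: when the first input is a tensor, the case breaks out so shape propagation falls through to the schema-matching path (the aten::as_tensor branch added further down); when it is a list or scalar, propagateTorchTensorShape infers dtype and dimension from the literal, just as for aten::tensor. Seen from the script side, a sketch with illustrative names:

    import torch
    from typing import List

    @torch.jit.script
    def from_list(li):
        # type: (List[int]) -> Tensor
        return torch.as_tensor(li)        # list overload

    @torch.jit.script
    def from_tensor(x):
        # type: (Tensor) -> Tensor
        return torch.as_tensor(x)         # tensor overload

    print(from_list([1, 2, 3]).dtype)         # torch.int64
    print(from_tensor(torch.ones(2)).shape)   # torch.Size([2])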
@@ -1544,6 +1551,33 @@ class ShapePropagator {
           node->matches(
               "aten::as_strided(Tensor self, int[] size, int[] stride, int? storage_offset) -> Tensor")) {
         return reshape_prop(node, attr::size, tensor_types);
+      } else if (node->matches("aten::as_tensor(Tensor data, *, ScalarType? dtype, Device? device) -> Tensor")) {
+        TypePtr input_type = node->inputs().at(0)->type();
+        if (auto type = input_type->cast<DimensionedTensorType>()) {
+          at::ScalarType default_type = type->scalarType();
+          c10::Device default_device = type->device();
+          if (auto dtype_index = node->schema().argumentIndexWithName("dtype")) {
+            auto inp = toIValue(node->inputs().at(*dtype_index));
+            if (inp == c10::nullopt) {
+              return nullptr;
+            }
+            if (!inp->isNone()) {
+              default_type = inp->toScalarType();
+            }
+          }
+          if (auto device_index = node->schema().argumentIndexWithName("device")) {
+            auto inp = toIValue(node->inputs().at(*device_index));
+            if (inp == c10::nullopt) {
+              return nullptr;
+            }
+            if (!inp->isNone()) {
+              default_device = inp->toDevice();
+            }
+          }
+          node->output()->setType(
+              DimensionedTensorType::create(default_type, default_device, type->dim()));
+        }
+        return nullptr;
       } else if (node->matches(
                      "aten::reshape(Tensor self, int[] shape) -> Tensor")) {
         return reshape_prop(node, attr::shape, tensor_types);
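The branch above computes the output type for as_tensor on a tensor input: start from the input's scalar type and device, override each one when the corresponding dtype/device argument is a known constant (and give up with nullptr when it is not constant), and keep the input's dimension count. The graph_for check in the test hunk observes exactly this; a hedged re-run of it outside the test harness (assuming the pre-profiling executor of this era, where graph_for returns a shape-annotated graph):

    import torch

    @torch.jit.script
    def f(x):
        # type: (Tensor) -> Tensor
        return torch.as_tensor(x, dtype=torch.float)

    g = f.graph_for(torch.ones(3, 4))
    print(g)  # expect the output annotated roughly as: Float(*, *) = aten::as_tensor(...)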