提交 437a8b3e 编写于 作者: Richard Zou's avatar Richard Zou 提交者: Facebook Github Bot

Named inference rule for copy_

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23229

Test Plan: Imported from OSS

Differential Revision: D16494413

Pulled By: zou3519

fbshipit-source-id: 4acb85e5a4ad09bf5f7cbb84cc8d4ceac0cd9967
上级 16da355b
...@@ -6,6 +6,9 @@ ...@@ -6,6 +6,9 @@
#include <ATen/native/TensorIterator.h> #include <ATen/native/TensorIterator.h>
#include <ATen/native/quantized/Copy.h> #include <ATen/native/quantized/Copy.h>
#include <ATen/quantized/Quantizer.h> #include <ATen/quantized/Quantizer.h>
#ifdef BUILD_NAMEDTENSOR
#include <ATen/NamedTensorUtils.h>
#endif
namespace { namespace {
...@@ -69,6 +72,14 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) { ...@@ -69,6 +72,14 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
} }
} }
}); });
#ifdef BUILD_NAMEDTENSOR
auto outnames = unify_from_right(self.names(), src.names());
if (outnames.has_value()) {
at::internal_set_names_inplace(self, *outnames);
} else {
at::internal_set_names_inplace(self, nullopt);
}
#endif
} }
// Devices directly supported by this copy implementation. Other device types // Devices directly supported by this copy implementation. Other device types
......
...@@ -508,6 +508,7 @@ ...@@ -508,6 +508,7 @@
- func: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) - func: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
variants: method variants: method
device_guard: False device_guard: False
named_guard: False
- func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor - func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor
dispatch: {} dispatch: {}
......
...@@ -74,6 +74,18 @@ class TestNamedTensor(TestCase): ...@@ -74,6 +74,18 @@ class TestNamedTensor(TestCase):
def test_empty(self): def test_empty(self):
self._test_factory(torch.empty, 'cpu') self._test_factory(torch.empty, 'cpu')
def test_copy_transpose(self):
# This type of copy is special-cased and therefore needs its own test
def _test(self_names, other_names, expected_names):
x = torch.empty(2, 5, names=self_names)
y = torch.empty(5, 2).t().set_names_(other_names)
x.copy_(y)
self.assertEqual(x.names, expected_names)
_test(('N', 'C'), ('N', 'C'), ('N', 'C'))
_test(('N', None), ('N', 'C'), ('N', 'C'))
_test(None, ('N', 'C'), ('N', 'C'))
def test_set_names_(self): def test_set_names_(self):
tensor = torch.empty(1, 1, names=('N', 'C')) tensor = torch.empty(1, 1, names=('N', 'C'))
self.assertEqual(tensor.set_names_(None).names, (None, None)) self.assertEqual(tensor.set_names_(None).names, (None, None))
...@@ -223,6 +235,7 @@ class TestNamedTensor(TestCase): ...@@ -223,6 +235,7 @@ class TestNamedTensor(TestCase):
tests = [ tests = [
fn_method_and_inplace('mul'), fn_method_and_inplace('mul'),
method('copy_'),
] ]
tests = flatten(tests) tests = flatten(tests)
......
Markdown 格式
0% or
您添加了 0 到此讨论。请谨慎行事。
先完成此消息的编辑!
想要评论请 注册