Commit fd61cc9e authored by Pavel Belevich, committed by Facebook Github Bot

Moved at::assert_no_internal_overlap to TensorIterator

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22917

Differential Revision: D16521429

Pulled By: pbelevich

fbshipit-source-id: 80ae583c6486d6948431b79e1452902bdf2cfbc3
Parent 4b78ce1b
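For context (not part of the upstream diff): the check being moved into TensorIterator guards in-place/`out=` writes into tensors whose elements alias each other, e.g. the result of `expand()`. A minimal sketch of the user-visible effect, assuming an ATen C++ build; `add_out` is the same ATen function patched below:

```cpp
#include <ATen/ATen.h>

int main() {
  // An expanded view reports size {3, 3} but is backed by a single element
  // (stride 0 in both dims), so every output element aliases the same memory.
  at::Tensor result = at::randn({1}).expand({3, 3});
  at::Tensor other  = at::randn({3, 3});

  // After this change the overlap check runs inside TensorIterator::binary_op
  // (check_internal_overlap=true), so add_out throws the TORCH_CHECK error
  // "more than one element of the written-to tensor refers to a single memory
  // location" instead of writing through the aliased output.
  at::add_out(result, result, other);
  return 0;
}
```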
@@ -23,17 +23,15 @@ MemOverlap has_internal_overlap(TensorImpl* t) {
   return MemOverlap::TOO_HARD;
 }
 
-void assert_no_internal_overlap(const Tensor& t, const std::string& op) {
-  assert_no_internal_overlap(t.unsafeGetTensorImpl(), op);
+void assert_no_internal_overlap(const Tensor& t) {
+  assert_no_internal_overlap(t.unsafeGetTensorImpl());
 }
 
-void assert_no_internal_overlap(TensorImpl* t, const std::string& op) {
-  if (has_internal_overlap(t) == MemOverlap::YES) {
-    AT_ERROR(
-      op, ": unsupported operation: more than one element of the written-to "
-      "tensor refers to a single memory location. Please clone() the tensor "
-      "before calling ", op);
-  }
+void assert_no_internal_overlap(TensorImpl* t) {
+  TORCH_CHECK(has_internal_overlap(t) != MemOverlap::YES,
+    "unsupported operation: more than one element of the written-to tensor "
+    "refers to a single memory location. Please clone() the tensor before "
+    "performing the operation.");
 }
 
 }
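Illustrative sketch (not part of the diff) of how the two helpers behave after the refactor; the names come from MemoryOverlap.h below, the sample tensors are assumptions:

```cpp
#include <ATen/ATen.h>
#include <ATen/MemoryOverlap.h>
#include <iostream>

int main() {
  at::Tensor dense   = at::randn({3, 3});              // contiguous: no self-aliasing
  at::Tensor aliased = at::randn({1}).expand({3, 3});  // stride-0 view: elements alias

  std::cout << (at::has_internal_overlap(dense)   == at::MemOverlap::NO)  << "\n";  // 1
  std::cout << (at::has_internal_overlap(aliased) == at::MemOverlap::YES) << "\n";  // 1

  at::assert_no_internal_overlap(dense);    // passes silently
  at::assert_no_internal_overlap(aliased);  // throws, with the generic TORCH_CHECK message above
  return 0;
}
```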
@@ -16,7 +16,7 @@ enum class MemOverlap { NO, YES, TOO_HARD };
 CAFFE2_API MemOverlap has_internal_overlap(const Tensor& t);
 CAFFE2_API MemOverlap has_internal_overlap(TensorImpl* t);
 
-CAFFE2_API void assert_no_internal_overlap(const Tensor& t, const std::string& op);
-CAFFE2_API void assert_no_internal_overlap(TensorImpl* t, const std::string& op);
+CAFFE2_API void assert_no_internal_overlap(const Tensor& t);
+CAFFE2_API void assert_no_internal_overlap(TensorImpl* t);
 
 }
@@ -25,8 +25,8 @@ Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar
   } else if (self.is_sparse()) {
     AT_ERROR("add(sparse, dense) is not supported. Use add(dense, sparse) instead.");
   }
-  at::assert_no_internal_overlap(result, "add");
-  auto iter = TensorIterator::binary_op(result, self, other);
+  auto iter = TensorIterator::binary_op(result, self, other,
+    /*check_internal_overlap=*/true);
   add_stub(iter.device_type(), iter, alpha);
   return result;
 }
@@ -54,8 +54,8 @@ Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
     }
     return at::_sparse_div_zerodim_out(result, self, other);
   }
-  at::assert_no_internal_overlap(result, "div");
-  auto iter = TensorIterator::binary_op(result, self, other);
+  auto iter = TensorIterator::binary_op(result, self, other,
+    /*check_internal_overlap=*/true);
   div_stub(iter.device_type(), iter);
   return result;
 }
@@ -79,8 +79,8 @@ Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
   if (self.is_sparse() || other.is_sparse()) {
     return at::_sparse_mul_out(result, self, other);
   }
-  at::assert_no_internal_overlap(result, "mul");
-  auto iter = TensorIterator::binary_op(result, self, other);
+  auto iter = TensorIterator::binary_op(result, self, other,
+    /*check_internal_overlap=*/true);
   mul_stub(iter.device_type(), iter);
   return result;
 }
@@ -125,8 +125,8 @@ Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar
   } else if (self.is_sparse()) {
     AT_ERROR("sub(sparse, dense) is not supported. Use sub(dense, sparse) instead.");
   }
-  at::assert_no_internal_overlap(result, "sub");
-  auto iter = TensorIterator::binary_op(result, self, other);
+  auto iter = TensorIterator::binary_op(result, self, other,
+    /*check_internal_overlap=*/true);
   sub_stub(iter.device_type(), iter, alpha);
   return result;
 }
@@ -534,9 +534,14 @@ void TensorIterator::select_all_keeping_dim(int start_dim, IntArrayRef indices)
   }
 }
 
-TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, const Tensor& b) {
+TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a,
+    const Tensor& b, bool check_internal_overlap) {
   auto iter = TensorIterator();
-  iter.add_output(out);
+  if (check_internal_overlap) {
+    iter.check_and_add_output(out);
+  } else {
+    iter.add_output(out);
+  }
   iter.add_input(a);
   iter.add_input(b);
   iter.allow_cpu_scalars_ = true;
@@ -544,9 +549,14 @@ TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, const Ten
   return iter;
 }
 
-TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a) {
+TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a,
+    bool check_internal_overlap) {
   auto iter = TensorIterator();
-  iter.add_output(out);
+  if (check_internal_overlap) {
+    iter.check_and_add_output(out);
+  } else {
+    iter.add_output(out);
+  }
   iter.add_input(a);
   iter.num_outputs_ = 1;
   iter.build();
@@ -6,6 +6,7 @@
 #include <ATen/detail/ScalarTypeConversions.h>
 #include <bitset>
 #include <c10/util/Optional.h>
+#include <ATen/MemoryOverlap.h>
 #ifdef BUILD_NAMEDTENSOR
 #include <ATen/NamedTensorUtils.h>
 #endif
@@ -142,8 +143,10 @@ struct CAFFE2_API TensorIterator {
   void foreach_reduced_elt(const loop_subiter_t& loop, bool parallelize=true);
 
-  static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b);
-  static TensorIterator unary_op(Tensor& out, const Tensor& a);
+  static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b,
+    bool check_internal_overlap = false);
+  static TensorIterator unary_op(Tensor& out, const Tensor& a,
+    bool check_internal_overlap = false);
   static TensorIterator nullary_op(Tensor& out);
   static TensorIterator reduce_op(Tensor& out, const Tensor& a);
   static TensorIterator reduce_op(Tensor& out1, Tensor& out2, const Tensor& a);
 
@@ -261,6 +264,11 @@ struct CAFFE2_API TensorIterator {
     num_outputs_++;
   }
 
+  void check_and_add_output(const Tensor& output) {
+    assert_no_internal_overlap(output);
+    add_output(output);
+  }
+
   void add_output(const Tensor& input, Device device, ScalarType dtype) {
     operands_.emplace_back(input, device, dtype);
     num_outputs_++;
@@ -312,7 +320,6 @@ protected:
   bool promote_gpu_output_dtypes_ = false;
   bool final_output_ = true;
 };
-
 /// A container-like struct that acts as if it contains splits of a
 /// TensorIterator that can use 32-bit indexing. Taken together the splits cover
 /// the original TensorIterator.
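Not part of the diff: a sketch of the new TensorIterator-level API from a caller's point of view, assuming the usual `<ATen/native/TensorIterator.h>` include path. The flag defaults to false, so call sites not touched by this commit keep their previous behavior; opting in routes the output through `check_and_add_output()`, which runs `assert_no_internal_overlap()` before the iterator is built:

```cpp
#include <ATen/ATen.h>
#include <ATen/native/TensorIterator.h>

int main() {
  at::Tensor out = at::randn({1}).expand({3, 3});  // internally overlapping output
  at::Tensor a = at::randn({3, 3});
  at::Tensor b = at::randn({3, 3});

  // Default (check_internal_overlap = false): behaves exactly as before this commit.
  auto unchecked = at::TensorIterator::binary_op(out, a, b);

  // Opting in: check_and_add_output() asserts no internal overlap on `out`,
  // so construction throws before any kernel is dispatched.
  auto checked = at::TensorIterator::binary_op(out, a, b,
      /*check_internal_overlap=*/true);
  return 0;
}
```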
@@ -47,8 +47,8 @@ Tensor& bitwise_not_(Tensor& self) {
 
 Tensor& bitwise_not_out(Tensor& result, const Tensor& self) {
   checkBackend("bitwise_not", result, self.type().backend());
-  assert_no_internal_overlap(result, "bitwise_not");
-  auto iter = TensorIterator::unary_op(result, self);
+  auto iter = TensorIterator::unary_op(result, self,
+    /*check_internal_overlap=*/true);
   bitwise_not_stub(iter.device_type(), iter);
 #ifdef BUILD_NAMEDTENSOR
   at::namedinference::propagate_names(result, self);
@@ -161,8 +161,8 @@ static void propagate_names_if_namedtensor_enabled(Tensor& result, const Tensor&
   } \
   Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
     checkBackend(#op, result, Backend::CPU); \
-    assert_no_internal_overlap(result, #op); \
-    auto iter = TensorIterator::unary_op(result, self); \
+    auto iter = TensorIterator::unary_op(result, self, \
+      /*check_internal_overlap=*/true); \
     op##_stub(iter.device_type(), iter); \
     return result; \
   }
@@ -196,8 +196,8 @@ static void propagate_names_if_named_tensor_enabled(THCTensor* result, THCTensor
   }; \
   \
   void THCTensor_(NAME)(THCState* state, THCTensor* self_, THCTensor* src) { \
-    THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); \
-    at::assert_no_internal_overlap(self_, #NAME); \
+    THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); \
+    at::assert_no_internal_overlap(self_); \
     if (self_ == src) { \
       if (!THC_pointwiseApply1<scalar_t>(state, self_, Tensor_##NAME##_##REAL##_Op())) { \
         THArgCheck(false, 2, CUTORCH_DIM_WARNING); \
@@ -12197,6 +12197,14 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
     def test_cosh_unary_mem_overlap(self):
         self.unary_check_mem_overlap(lambda t: t.cosh_())
 
+    @unittest.expectedFailure
+    def test_lerp_mem_overlap(self):
+        start = torch.randn(1, device=device).expand(3, 3)
+        end = torch.randn(3, 3, device=device)
+        weight = torch.randn(3, 3, device=device)
+        with self.assertRaisesRegex(RuntimeError, 'single memory location'):
+            start.lerp_(end, weight)
+
     @unittest.skipIf(torch.cuda.device_count() < 2, 'only one GPU detected')
     def test_reverse_binary_ops_multiple_device(self):
         self.assertEqual(2 + torch.tensor(3), 2 + torch.tensor(3).to("cuda:1"))  # __radd__