Commit 7b081e5d authored by Will Feng, committed by Facebook Github Bot

Improve error message for changing tensor metadata after .data or .detach() (#23504)

Summary:
When a user tries to change the metadata of a tensor created from `.data` or `.detach()`, we currently show the error message "<function_name> is not allowed on Tensor created from .data or .detach()". However, this message doesn't suggest what the right fix should look like. This PR improves the error message.
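For illustration, here is a minimal runnable sketch of the before/after pattern the new message describes; the tensors `x` and `y` are placeholder names following the example embedded in the message itself, not code from this PR:

```python
import torch

x = torch.zeros(2, 3, requires_grad=True)  # leaf tensor tracked by autograd
y = torch.ones(4, 5)                       # source tensor whose storage/sizes we want to take

# Changing metadata through `.data` trips the metadata-change check and raises
# a RuntimeError that now carries the longer explanation added by this PR.
try:
    x.data.set_(y)
except RuntimeError as e:
    print(e)

# The rewrite the message suggests: drop `.data` and wrap the call in no_grad,
# so the metadata change happens without autograd tracking it.
with torch.no_grad():
    x.set_(y)
```

The C++ changes below simply replace each hard-coded string with the shared `err_msg_tensor_metadata_change_not_allowed` constant, so every call site picks up the longer explanation.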

Closes https://github.com/pytorch/pytorch/issues/23393.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23504

Differential Revision: D16547415

Pulled By: yf225

fbshipit-source-id: 37f4a0385442e2b0966386fb14d3d938ecf4230c
Parent db1e9b1d
@@ -90,7 +90,7 @@ int64_t SparseTensorImpl::storage_offset() const {
AT_ERROR("sparse tensors do not have storage");
}
void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, const Tensor& values) {
TORCH_CHECK(allow_tensor_metadata_change(), "set_indices_and_values_unsafe is not allowed on Tensor created from .data or .detach()");
TORCH_CHECK(allow_tensor_metadata_change(), "set_indices_and_values_unsafe ", err_msg_tensor_metadata_change_not_allowed);
AT_ASSERT(!indices.is_variable() && !values.is_variable()); // They should be plain tensors! // TODO: change this to check `.requires_grad()` and `GradMode::is_enabled()` when Variable and Tensor are merged
TORCH_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout ", indices.layout());
......
@@ -57,7 +57,7 @@ public:
// WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim with
// respect to indices and values
void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
TORCH_CHECK(allow_tensor_metadata_change(), "raw_resize_ is not allowed on Tensor created from .data or .detach()");
TORCH_CHECK(allow_tensor_metadata_change(), "raw_resize_ ", err_msg_tensor_metadata_change_not_allowed);
sizes_ = size.vec();
sparse_dim_ = sparse_dim;
dense_dim_ = dense_dim;
@@ -87,7 +87,7 @@ public:
// 4. When we attempt to shrink the size of any of the sparse dimensions on a non-empty sparse tensor
// (this could make some of the stored indices out-of-bound and thus unsafe).
void resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
TORCH_CHECK(allow_tensor_metadata_change(), "resize_ is not allowed on Tensor created from .data or .detach()");
TORCH_CHECK(allow_tensor_metadata_change(), "resize_ ", err_msg_tensor_metadata_change_not_allowed);
TORCH_CHECK(sparse_dim + dense_dim == static_cast<int64_t>(size.size()), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size());
if (nnz() > 0) {
auto alt_options_msg = "You could try the following options:\n\
@@ -145,7 +145,7 @@ public:
// NOTE: this function will resize the sparse tensor and also set `indices` and `values` to empty.
void resize_and_clear_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
TORCH_CHECK(allow_tensor_metadata_change(), "resize_and_clear_ is not allowed on Tensor created from .data or .detach()");
TORCH_CHECK(allow_tensor_metadata_change(), "resize_and_clear_ ", err_msg_tensor_metadata_change_not_allowed);
TORCH_CHECK(sparse_dim + dense_dim == static_cast<int64_t>(size.size()), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size());
sizes_ = size.vec();
@@ -162,13 +162,13 @@ public:
}
void set_coalesced(bool coalesced) {
TORCH_CHECK(allow_tensor_metadata_change(), "set_coalesced is not allowed on Tensor created from .data or .detach()");
TORCH_CHECK(allow_tensor_metadata_change(), "set_coalesced ", err_msg_tensor_metadata_change_not_allowed);
coalesced_ = coalesced;
}
// NOTE: this function is only used internally and not exposed to Python frontend
void set_nnz_and_narrow(int64_t new_nnz) {
TORCH_CHECK(allow_tensor_metadata_change(), "set_nnz_and_narrow is not allowed on Tensor created from .data or .detach()");
TORCH_CHECK(allow_tensor_metadata_change(), "set_nnz_and_narrow ", err_msg_tensor_metadata_change_not_allowed);
AT_ASSERT(new_nnz <= nnz());
indices_ = indices_.narrow(1, 0, new_nnz);
values_ = values_.narrow(0, 0, new_nnz);
......
@@ -17,6 +17,16 @@ C10_DEFINE_int64(
namespace c10 {
const char * const TensorImpl::err_msg_tensor_metadata_change_not_allowed =
"is not allowed on a Tensor created from .data or .detach().\n"
"If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)\n"
"without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.\n"
"For example, change:\n"
" x.data.set_(y)\n"
"to:\n"
" with torch.no_grad():\n"
" x.set_(y)";
at::Tensor& TensorImpl::grad() {
if (autograd_meta()) {
return autograd_meta()->grad();
......
@@ -692,7 +692,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* which is harder to misuse.
*/
virtual void resize_dim(int64_t ndim) {
TORCH_CHECK(allow_tensor_metadata_change(), "resize_dim is not allowed on Tensor created from .data or .detach()");
TORCH_CHECK(allow_tensor_metadata_change(), "resize_dim ", err_msg_tensor_metadata_change_not_allowed);
sizes_.resize(ndim, 0);
strides_.resize(ndim, 0);
refresh_numel();
@@ -708,7 +708,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* which is harder to misuse.
*/
virtual void set_size(int64_t dim, int64_t new_size) {
TORCH_CHECK(allow_tensor_metadata_change(), "set_size is not allowed on Tensor created from .data or .detach()");
TORCH_CHECK(allow_tensor_metadata_change(), "set_size ", err_msg_tensor_metadata_change_not_allowed);
sizes_.at(dim) = new_size;
refresh_numel();
refresh_contiguous();
@@ -721,7 +721,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* which is harder to misuse.
*/
virtual void set_stride(int64_t dim, int64_t new_stride) {
TORCH_CHECK(allow_tensor_metadata_change(), "set_stride is not allowed on Tensor created from .data or .detach()");
TORCH_CHECK(allow_tensor_metadata_change(), "set_stride ", err_msg_tensor_metadata_change_not_allowed);
strides_[dim] = new_stride;
refresh_numel();
refresh_contiguous();
@@ -735,7 +735,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* (and resizing if necessary.)
*/
virtual void set_storage_offset(int64_t storage_offset) {
TORCH_CHECK(allow_tensor_metadata_change(), "set_storage_offset is not allowed on Tensor created from .data or .detach()");
TORCH_CHECK(allow_tensor_metadata_change(), "set_storage_offset ", err_msg_tensor_metadata_change_not_allowed);
storage_offset_ = storage_offset;
}
@@ -747,7 +747,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* this is the responsibility of the caller
*/
void set_sizes_contiguous(IntArrayRef new_size) {
TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_contiguous is not allowed on Tensor created from .data or .detach()");
TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_contiguous ", err_msg_tensor_metadata_change_not_allowed);
auto new_dim = new_size.size();
sizes_.resize(new_dim);
@@ -767,7 +767,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* this is the responsibility of the caller
*/
void set_sizes_and_strides(IntArrayRef new_size, IntArrayRef new_stride) {
TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_and_strides is not allowed on Tensor created from .data or .detach()");
TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_and_strides ", err_msg_tensor_metadata_change_not_allowed);
TORCH_CHECK(
new_size.size() == new_stride.size(),
"dimensionality of sizes (",
@@ -1370,7 +1370,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
}
void set_storage(at::Storage storage) {
TORCH_CHECK(allow_tensor_metadata_change(), "set_storage is not allowed on Tensor created from .data or .detach()");
TORCH_CHECK(allow_tensor_metadata_change(), "set_storage ", err_msg_tensor_metadata_change_not_allowed);
storage_ = std::move(storage);
data_type_ = storage_.dtype();
device_opt_ = storage_.device();
@@ -1532,6 +1532,12 @@ protected:
}
protected:
// Error message to show when the user tries to change tensor metadata on
// Tensor created from .data or .detach().
//
// See NOTE [ Metadata Change for a Detached Tensor ] for details.
static const char * const err_msg_tensor_metadata_change_not_allowed;
Storage storage_;
// This pointer points to an AutogradMeta struct that stores autograd-specific fields
// (such as grad_ / grad_fn_ / grad_accumulator_).
......
@@ -1926,32 +1926,32 @@ class TestSparse(TestCase):
def do_test(t):
with self.assertRaisesRegex(
RuntimeError,
"raw_resize_ is not allowed on Tensor created from .data or .detach()"):
"raw_resize_ is not allowed on a Tensor created from .data or .detach()"):
t.transpose_(0, 1)
with self.assertRaisesRegex(
RuntimeError,
"resize_ is not allowed on Tensor created from .data or .detach()"):
"resize_ is not allowed on a Tensor created from .data or .detach()"):
t.resize_as_(self.sparse_empty(3, 3))
with self.assertRaisesRegex(
RuntimeError,
"resize_and_clear_ is not allowed on Tensor created from .data or .detach()"):
"resize_and_clear_ is not allowed on a Tensor created from .data or .detach()"):
t.mul_(t)
with self.assertRaisesRegex(
RuntimeError,
"set_coalesced is not allowed on Tensor created from .data or .detach()"):
"set_coalesced is not allowed on a Tensor created from .data or .detach()"):
t._coalesced_(True)
with self.assertRaisesRegex(
RuntimeError,
"set_indices_and_values_unsafe is not allowed on Tensor created from .data or .detach()"):
"set_indices_and_values_unsafe is not allowed on a Tensor created from .data or .detach()"):
a = self.sparse_tensor(torch.tensor([[0, 1, 1], [2, 0, 2]]), torch.tensor([3., 4., 5.])).data
a.add_(a)
with self.assertRaisesRegex(
RuntimeError,
"resize_and_clear_ is not allowed on Tensor created from .data or .detach()"):
"resize_and_clear_ is not allowed on a Tensor created from .data or .detach()"):
a.zero_()
with self.assertRaisesRegex(
RuntimeError,
"resize_ is not allowed on Tensor created from .data or .detach()"):
"resize_ is not allowed on a Tensor created from .data or .detach()"):
a.copy_(self.sparse_empty(3, 3))
do_test(self.sparse_empty(3, 0).data)
......
@@ -12225,15 +12225,15 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
def do_test(t):
with self.assertRaisesRegex(
RuntimeError,
"set_sizes_contiguous is not allowed on Tensor created from .data or .detach()"):
"set_sizes_contiguous is not allowed on a Tensor created from .data or .detach()"):
t.resize_((2, 1))
with self.assertRaisesRegex(
RuntimeError,
"set_storage is not allowed on Tensor created from .data or .detach()"):
"set_storage is not allowed on a Tensor created from .data or .detach()"):
t.set_()
with self.assertRaisesRegex(
RuntimeError,
"set_storage_offset is not allowed on Tensor created from .data or .detach()"):
"set_storage_offset is not allowed on a Tensor created from .data or .detach()"):
t.set_(t.storage(), 0, t.size(), list(t.stride()))
do_test(torch.tensor([[1, 2]]).data)
......