提交 8ffcbfb7 编写于 作者: Richard Zou's avatar Richard Zou 提交者: Facebook Github Bot

Add unique_ptr<NamedTensorMeta> field to TensorImpl (#21341)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21341
ghimport-source-id: 06021b06

Differential Revision: D15717907

Pulled By: zou3519

fbshipit-source-id: 48ee76cf2f11a8b092be75ecac8d5faee68ca0d9
上级 f9c4d0d7
#ifdef NAMEDTENSOR_ENABLED
#include <ATen/NamedTensor.h>
namespace at {
bool NamedTensorMeta::has_names() const {
return !std::all_of(
names.begin(), names.end(), [](const Dimname& n) {
return n.type() == NameType::WILDCARD;
});
}
}
#endif
#pragma once
#ifdef NAMEDTENSOR_ENABLED
#include <ATen/Dimname.h>
#include <c10/core/TensorImpl.h>
namespace at {
// TensorImpl has a unique_ptr<NamedTensorMetaInterface> field. Ideally we would
// just put optional<vector<Dimname>> into TensorImpl, but the problem with that is
// that c10::Symbol isn't actually a part of the c10 lib (where TensorImpl is).
// In the long term, we should decouple c10::Symbol from aten and toss it into c10.
struct CAFFE2_API NamedTensorMeta : public c10::NamedTensorMetaInterface {
std::vector<Dimname> names;
explicit NamedTensorMeta(int64_t num_names)
: names(std::vector<Dimname>(num_names, Dimname::wildcard())) {}
explicit NamedTensorMeta(std::vector<Dimname> names)
: names(names) {}
bool has_names() const;
};
}
#endif
......@@ -14,6 +14,9 @@
#include <c10/util/Optional.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/core/DeprecatedTypePropertiesRegistry.h>
#ifdef NAMEDTENSOR_ENABLED
#include <ATen/NamedTensor.h>
#endif
namespace caffe2 {
class Tensor;
......@@ -248,6 +251,14 @@ class CAFFE2_API Tensor {
/// Returns if a `Tensor` has quantized backend.
bool is_quantized() const;
#ifdef NAMEDTENSOR_ENABLED
/// Returns if a `Tensor` has any dimension names
bool is_named() const;
/// Returns a `Tensor`'s dimension names data structure
NamedTensorMeta* get_named_tensor_meta() const;
#endif
/// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
/// TensorOptions.h.
TensorOptions options() const;
......
......@@ -5,6 +5,10 @@
#include <c10/macros/Macros.h>
#include <c10/core/TensorOptions.h>
#include <ATen/core/DeprecatedTypeProperties.h>
#ifdef NAMEDTENSOR_ENABLED
#include <ATen/NamedTensor.h>
#endif
namespace at {
......@@ -1329,6 +1333,17 @@ inline bool Tensor::is_cuda() const {
return impl_->is_cuda();
}
#ifdef NAMEDTENSOR_ENABLED
inline NamedTensorMeta* Tensor::get_named_tensor_meta() const {
return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
}
inline bool Tensor::is_named() const {
auto* named_tensor_meta = get_named_tensor_meta();
return named_tensor_meta != nullptr && named_tensor_meta->has_names();
}
#endif
inline bool is_cuda(Tensor self) {
return self.is_cuda();
}
......
......@@ -14,6 +14,9 @@
#include <c10/util/Optional.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/core/DeprecatedTypePropertiesRegistry.h>
#ifdef NAMEDTENSOR_ENABLED
#include <ATen/NamedTensor.h>
#endif
namespace caffe2 {
class Tensor;
......@@ -248,6 +251,14 @@ class CAFFE2_API Tensor {
/// Returns if a `Tensor` has quantized backend.
bool is_quantized() const;
#ifdef NAMEDTENSOR_ENABLED
/// Returns if a `Tensor` has any dimension names
bool is_named() const;
/// Returns a `Tensor`'s dimension names data structure
NamedTensorMeta* get_named_tensor_meta() const;
#endif
/// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
/// TensorOptions.h.
TensorOptions options() const;
......
......@@ -5,6 +5,10 @@
#include <c10/macros/Macros.h>
#include <c10/core/TensorOptions.h>
#include <ATen/core/DeprecatedTypeProperties.h>
#ifdef NAMEDTENSOR_ENABLED
#include <ATen/NamedTensor.h>
#endif
namespace at {
......@@ -88,6 +92,17 @@ inline bool Tensor::is_cuda() const {
return impl_->is_cuda();
}
#ifdef NAMEDTENSOR_ENABLED
inline NamedTensorMeta* Tensor::get_named_tensor_meta() const {
return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
}
inline bool Tensor::is_named() const {
auto* named_tensor_meta = get_named_tensor_meta();
return named_tensor_meta != nullptr && named_tensor_meta->has_names();
}
#endif
inline bool is_cuda(Tensor self) {
return self.is_cuda();
}
......
......@@ -10,6 +10,7 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/basic.cpp
${CMAKE_CURRENT_SOURCE_DIR}/atest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/Dimname_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/NamedTensor_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/half_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/broadcast_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/wrapdim_test.cpp
......
#ifdef NAMEDTENSOR_ENABLED
#include <gtest/gtest.h>
#include <ATen/ATen.h>
#include <ATen/NamedTensor.h>
#include <c10/util/Exception.h>
#include <torch/csrc/utils/memory.h>
using at::Dimname;
using at::NamedTensorMeta;
using at::Symbol;
using torch::make_unique;
TEST(NamedTensorTest, defaultMetadata) {
int num_names = 4;
const auto meta = NamedTensorMeta(num_names);
for (const auto name : meta.names) {
ASSERT_EQ(name.type(), at::NameType::WILDCARD);
}
}
TEST(NamedTensorTest, isNamed) {
auto tensor = at::zeros({3, 2, 5, 7});
ASSERT_FALSE(tensor.is_named());
tensor = at::zeros({3, 2, 5, 7});
tensor.unsafeGetTensorImpl()->set_named_tensor_meta(
make_unique<NamedTensorMeta>(tensor.dim()));
ASSERT_FALSE(tensor.is_named());
tensor = at::zeros({3, 2, 5, 7});
auto N = Dimname::fromSymbol(Symbol::dimname("N"));
auto C = Dimname::fromSymbol(Symbol::dimname("C"));
auto H = Dimname::fromSymbol(Symbol::dimname("H"));
auto W = Dimname::fromSymbol(Symbol::dimname("W"));
std::vector<Dimname> names = { N, C, H, W };
tensor.unsafeGetTensorImpl()->set_named_tensor_meta(
make_unique<NamedTensorMeta>(names));
ASSERT_TRUE(tensor.is_named());
}
TEST(NamedTensorTest, attachMetadata) {
auto tensor = at::zeros({3, 2, 5, 7});
auto N = Dimname::fromSymbol(Symbol::dimname("N"));
auto C = Dimname::fromSymbol(Symbol::dimname("C"));
auto H = Dimname::fromSymbol(Symbol::dimname("H"));
auto W = Dimname::fromSymbol(Symbol::dimname("W"));
std::vector<Dimname> names = { N, C, H, W };
tensor.unsafeGetTensorImpl()->set_named_tensor_meta(
make_unique<NamedTensorMeta>(names));
const auto retrieved_meta = tensor.get_named_tensor_meta();
for (int i = 0; i < tensor.dim(); ++i) {
const auto& retrieved_name = retrieved_meta->names[i];
const auto& expected_name = names[i];
ASSERT_EQ(retrieved_name.type(), expected_name.type());
ASSERT_EQ(retrieved_name.name(), expected_name.name());
}
// Test dropping metadata
tensor.unsafeGetTensorImpl()->set_named_tensor_meta(nullptr);
ASSERT_FALSE(tensor.is_named());
}
#endif
......@@ -21,6 +21,7 @@ VALGRIND=${VALGRIND:=ON}
./xla_tensor_test
./tensor_iterator_test
./Dimname_test
./NamedTensor_test
if [[ -x ./cudnn_test ]]; then
./cudnn_test
fi
......
......@@ -144,6 +144,10 @@ struct C10_API NonVariableTypeMode {
static void set_enabled(bool enabled);
};
#ifdef NAMEDTENSOR_ENABLED
struct C10_API NamedTensorMetaInterface {};
#endif
// NOTE [ Version Counter Sharing ]
//
// Every Tensor has a version counter. Version counters are incremented whenever the
......@@ -843,6 +847,28 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return std::move(autograd_meta_);
}
#ifdef NAMEDTENSOR_ENABLED
/**
* Set the pointer to named tensor metadata.
*/
void set_named_tensor_meta(std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta) {
#ifdef DEBUG
if (named_tensor_meta) {
TORCH_INTERNAL_ASSERT(dim() == named_tensor_meta->names.size());
}
#endif
named_tensor_meta_ = std::move(named_tensor_meta);
}
/**
* Return the pointer to named tensor metadata.
*/
c10::NamedTensorMetaInterface* named_tensor_meta() const {
return named_tensor_meta_.get();
}
#endif
// NOTE [ TensorImpl Shallow-Copying ]
//
// TensorImpl shallow-copying is used when we want to have two Variables share the same storage pointer
......@@ -1456,6 +1482,10 @@ protected:
// at a time).
std::unique_ptr<c10::AutogradMetaInterface> autograd_meta_ = nullptr;
#ifdef NAMEDTENSOR_ENABLED
std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta_ = nullptr;
#endif
c10::VariableVersion version_counter_;
PyObject* pyobj_ = nullptr; // weak reference
......@@ -1578,8 +1608,13 @@ protected:
// (optional) device
// miscellaneous bitfield
//
#ifdef NAMEDTENSOR_ENABLED
#define NWORDS 30
#else
#define NWORDS 29
#endif
static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit...
sizeof(TensorImpl) == sizeof(int64_t) * 29,
sizeof(TensorImpl) == sizeof(int64_t) * NWORDS,
"You changed the size of TensorImpl on 64-bit arch."
"See Note [TensorImpl size constraints] on how to proceed.");
......
Markdown 格式
0% or
您添加了 0 到此讨论。请谨慎行事。
先完成此消息的编辑!
想要评论请 注册