提交 4727685e 编写于 作者: Richard Zou's avatar Richard Zou 提交者: Facebook Github Bot

Added at::Dimname (#21280)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21280
ghimport-source-id: 92184832

Differential Revision: D15698516

Pulled By: zou3519

fbshipit-source-id: 502b9b019d51dd46327e6caf2af69aa520c70cb6
上级 e27c2f14
#ifdef NAMEDTENSOR_ENABLED
#include <ATen/Dimname.h>
#include <c10/util/Exception.h>
namespace at {
bool is_valid_identifier(const std::string& name) {
std::locale loc;
if (name.length() == 0) {
return false;
}
for (auto it = name.begin(); it != name.end(); ++it) {
if (std::isalpha(*it, loc) || *it == '_') {
continue;
}
return false;
}
return true;
}
static void check_valid_identifier(const std::string& name) {
TORCH_CHECK(
is_valid_identifier(name),
"A valid identifier must contain alphabetical characters and/or underscore, got: '",
name, "'.");
}
Dimname Dimname::fromSymbol(Symbol name) {
TORCH_INTERNAL_ASSERT(name.is_dimname());
if (name == kWildcard) {
return Dimname::wildcard();
}
const std::string delimiter = ".";
const std::string str(name.toUnqualString());
auto it = str.find(delimiter);
// Check for normal name
if (it == std::string::npos) {
check_valid_identifier(str);
return Dimname(name);
}
// Check for tagged name
auto second_dot = str.find(delimiter, it + 1);
TORCH_CHECK(
second_dot == std::string::npos,
"Invalid name '", str, "': A tagged name can only contain one '.'");
auto untagged_name = str.substr(0, it);
auto tag = str.substr(it + 1);
check_valid_identifier(untagged_name);
check_valid_identifier(tag);
return Dimname(NameType::TAGGED, name, Symbol::dimname(untagged_name));
}
Dimname Dimname::wildcard() {
static Dimname result(NameType::WILDCARD, kWildcard, kWildcard);
return result;
}
} // namespace at
#endif
#pragma once
#ifdef NAMEDTENSOR_ENABLED
#include <ATen/core/interned_strings.h>
namespace at {
enum class NameType: uint8_t { NORMAL, WILDCARD, TAGGED };
struct CAFFE2_API Dimname {
static Dimname fromSymbol(Symbol name);
static Dimname wildcard();
NameType type() const { return type_; }
Symbol name() const { return name_; }
Symbol untagged_name() const { return untagged_name_; }
private:
Dimname(Symbol name)
: untagged_name_(name), name_(name), type_(NameType::NORMAL) {}
Dimname(NameType type, Symbol name, Symbol untagged_name)
: untagged_name_(untagged_name), name_(name), type_(type) {}
Symbol untagged_name_;
Symbol name_;
NameType type_;
// Will need more fields for other special name types.
};
static Symbol kWildcard = Symbol::dimname("*");
bool CAFFE2_API is_valid_identifier(const std::string& name);
} // namespace at
#endif
......@@ -22,6 +22,7 @@ namespace c10 {
_(namespaces, scope) \
_(namespaces, user) \
_(namespaces, _caffe2) \
_(namespaces, dimname) \
_(namespaces, namespaces) \
_(prim, Assign) \
_(prim, BroadcastingChunk) \
......@@ -204,6 +205,7 @@ namespace c10 {
_(namespaces, scope) \
_(namespaces, user) \
_(namespaces, _caffe2) \
_(namespaces, dimname) \
_(namespaces, namespaces)
#endif
......@@ -272,6 +274,9 @@ struct CAFFE2_API Symbol {
static Symbol prim(const std::string & s);
static Symbol user(const std::string & s);
static Symbol caffe2(const std::string & s);
#ifdef NAMEDTENSOR_ENABLED
static Symbol dimname(const std::string & s);
#endif
// TODO: eliminate me
static Symbol scope(const std::string & s);
......@@ -281,6 +286,9 @@ struct CAFFE2_API Symbol {
bool is_onnx() const;
bool is_user() const;
bool is_caffe2() const;
#ifdef NAMEDTENSOR_ENABLED
bool is_dimname() const;
#endif
// So we can switch on this
constexpr operator unique_t() const {
......@@ -341,12 +349,18 @@ inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualStri
inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
inline Symbol Symbol::user(const std::string & s) { return Symbol::fromQualString("user::" + s); }
inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualString("_caffe2::" + s); }
#ifdef NAMEDTENSOR_ENABLED
inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); }
#endif
inline bool Symbol::is_attr() const { return ns() == namespaces::attr; }
inline bool Symbol::is_aten() const { return ns() == namespaces::aten; }
inline bool Symbol::is_prim() const { return ns() == namespaces::prim; }
inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
inline bool Symbol::is_user() const { return ns() == namespaces::user; }
inline bool Symbol::is_caffe2() const { return ns() == namespaces::_caffe2; }
#ifdef NAMEDTENSOR_ENABLED
inline bool Symbol::is_dimname() const { return ns() == namespaces::dimname; }
#endif
} // namespace c10
......
......@@ -9,6 +9,7 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/apply_utils_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/basic.cpp
${CMAKE_CURRENT_SOURCE_DIR}/atest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/Dimname_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/half_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/broadcast_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/wrapdim_test.cpp
......
#ifdef NAMEDTENSOR_ENABLED
#include <gtest/gtest.h>
#include <ATen/Dimname.h>
#include <c10/util/Exception.h>
using at::is_valid_identifier;
using at::NameType;
using at::Symbol;
using at::Dimname;
TEST(DimnameTest, isValidIdentifier) {
ASSERT_TRUE(is_valid_identifier("a"));
ASSERT_TRUE(is_valid_identifier("batch"));
ASSERT_TRUE(is_valid_identifier("N"));
ASSERT_TRUE(is_valid_identifier("CHANNELS"));
ASSERT_TRUE(is_valid_identifier("foo_bar_baz"));
ASSERT_FALSE(is_valid_identifier(""));
ASSERT_FALSE(is_valid_identifier(" "));
ASSERT_FALSE(is_valid_identifier(" a "));
ASSERT_FALSE(is_valid_identifier("batch1"));
ASSERT_FALSE(is_valid_identifier("foo_bar_1"));
ASSERT_FALSE(is_valid_identifier("?"));
ASSERT_FALSE(is_valid_identifier("-"));
}
TEST(DimnameTest, wildcardName) {
Dimname wildcard = Dimname::wildcard();
ASSERT_EQ(wildcard.type(), NameType::WILDCARD);
ASSERT_EQ(wildcard.name(), Symbol::dimname("*"));
ASSERT_EQ(wildcard.untagged_name(), Symbol::dimname("*"));
}
TEST(DimnameTest, createNormalName) {
auto foo = Symbol::dimname("foo");
auto dimname = Dimname::fromSymbol(foo);
ASSERT_EQ(dimname.type(), NameType::NORMAL);
ASSERT_EQ(dimname.name(), foo);
ASSERT_EQ(dimname.untagged_name(), foo);
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("invalid1")), c10::Error);
}
TEST(DimnameTest, createTaggedName) {
auto foo_bar = Symbol::dimname("foo.bar");
auto foo = Symbol::dimname("foo");
auto dimname = Dimname::fromSymbol(foo_bar);
ASSERT_EQ(dimname.type(), NameType::TAGGED);
ASSERT_EQ(dimname.name(), foo_bar);
ASSERT_EQ(dimname.untagged_name(), foo);
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname(".bar")), c10::Error);
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("foo.")), c10::Error);
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("foo.bar.baz")), c10::Error);
}
#endif
......@@ -20,6 +20,7 @@ VALGRIND=${VALGRIND:=ON}
./extension_backend_test
./xla_tensor_test
./tensor_iterator_test
./Dimname_test
if [[ -x ./cudnn_test ]]; then
./cudnn_test
fi
......
Markdown 格式
0% or
您添加了 0 到此讨论。请谨慎行事。
先完成此消息的编辑!
想要评论请 注册