Commit 0d6eb209 authored by Richard Zou, committed by Facebook GitHub Bot

Expose torch.empty(sizes, *, names, ...) to Python (#21648)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21648
ghimport-source-id: 583f155c

Differential Revision: D15804482

Pulled By: zou3519

fbshipit-source-id: f86520dda479100be2a752e4db8a902167413a83
Parent 71082187
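For reference, a minimal sketch of the Python surface this change exposes (assuming a build compiled with named tensor support, i.e. NAMEDTENSOR_ENABLED; the dimension names are arbitrary examples mirroring the tests below):

    import torch

    # torch.empty now accepts a keyword-only `names` argument.
    x = torch.empty(2, 3, names=('N', 'C'))
    print(x.names)   # ('N', 'C')

    # None stands for an unnamed (wildcard) dimension.
    y = torch.empty(2, 3, names=('N', None))
    print(y.names)   # ('N', None)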
import unittest
from common_utils import TestCase, run_tests
from common_cuda import TEST_CUDA
import torch
import sys


def namedtensor_enabled():
......@@ -10,11 +12,98 @@ skipIfNamedTensorDisabled = \
    unittest.skipIf(not namedtensor_enabled(),
                    'PyTorch not compiled with namedtensor support')


def pass_name_to_python_arg_parser(name):
    x = torch.empty(2, names=(name,))


class TestNamedTensor(TestCase):
    @skipIfNamedTensorDisabled
    def test_trivial(self):
        pass

    def _test_factory(self, factory, device):
        x = factory([], device=device)
        self.assertEqual(x.names, ())

        x = factory(1, 2, 3, device=device)
        self.assertEqual(x.names, (None, None, None))

        x = factory(1, 2, 3, names=None, device=device)
        self.assertEqual(x.names, (None, None, None))

        x = factory(1, 2, 3, names=('N', 'T', 'D'), device=device)
        self.assertEqual(x.names, ('N', 'T', 'D'))

        x = factory(1, 2, 3, names=('N', None, 'D'), device=device)
        self.assertEqual(x.names, ('N', None, 'D'))

        with self.assertRaisesRegex(RuntimeError,
                                    'must contain alphabetical characters and/or underscore'):
            x = factory(2, names=('?',), device=device)

        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
            x = factory(2, 1, names=('N',), device=device)

        with self.assertRaisesRegex(TypeError, 'invalid combination of arguments'):
            x = factory(2, 1, names='N', device=device)

    @skipIfNamedTensorDisabled
    def test_empty(self):
        self._test_factory(torch.empty, 'cpu')

    @skipIfNamedTensorDisabled
    @unittest.skipIf(not TEST_CUDA, 'no CUDA')
    def test_empty_cuda(self):
        self._test_factory(torch.empty, 'cuda')

    @skipIfNamedTensorDisabled
    def test_using_seen_interned_string_doesnt_bump_refcount(self):
        def see_name():
            seen_name = 'N'
            pass_name_to_python_arg_parser(seen_name)

        see_name()
        seen_name = 'N'
        old_refcnt = sys.getrefcount(seen_name)
        pass_name_to_python_arg_parser(seen_name)
        new_refcnt = sys.getrefcount(seen_name)
        self.assertEqual(new_refcnt, old_refcnt)

    @skipIfNamedTensorDisabled
    def test_using_unseen_interned_string_bumps_refcount_permanently(self):
        # Please don't use this as a name in a different test.
        unseen_name = 'abcdefghi'
        old_refcnt = sys.getrefcount(unseen_name)
        pass_name_to_python_arg_parser(unseen_name)
        new_refcnt = sys.getrefcount(unseen_name)
        self.assertEqual(new_refcnt, old_refcnt + 1)

    @skipIfNamedTensorDisabled
    def test_using_unseen_uninterned_string_refcounts(self):
        # Please don't use this as a name in a different test.
        # non-compile-time constants are not interned
        unseen_name = ''.join(['abc', 'def', 'ghi', 'jkl'])
        interned_unseen_name = 'abcdefghijkl'
        self.assertFalse(unseen_name is interned_unseen_name)

        old_uninterned_refcnt = sys.getrefcount(unseen_name)
        old_interned_refcnt = sys.getrefcount(interned_unseen_name)

        pass_name_to_python_arg_parser(unseen_name)

        new_uninterned_refcnt = sys.getrefcount(unseen_name)
        new_interned_refcnt = sys.getrefcount(interned_unseen_name)

        # Internally, PyTorch should not hold a reference to the uninterned string
        self.assertEqual(new_uninterned_refcnt, old_uninterned_refcnt)

        # Instead, we should hold a new reference to the interned version.
        self.assertEqual(new_interned_refcnt, old_interned_refcnt + 1)


if __name__ == '__main__':
    run_tests()
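The refcount tests above lean on a CPython implementation detail: string constants that look like identifiers are interned when the code is compiled, while strings built at runtime (e.g. via ''.join) are not. A quick illustration of that behavior in plain Python 3, assuming CPython (the exact strings here are arbitrary and hypothetical):

    import sys

    a = 'abcdefghijkl'                          # compile-time constant: interned
    b = ''.join(['abc', 'def', 'ghi', 'jkl'])   # built at runtime: not interned
    print(a == b, a is b)        # True False
    print(sys.intern(b) is a)    # True: interning maps b onto the canonical object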
......@@ -42,7 +42,6 @@ SKIP_PYTHON_BINDINGS_SIGNATURES = [
    'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
    'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
    'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
    'empty(IntArrayRef, DimnameList?, TensorOptions)',
]

PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
......@@ -163,7 +162,6 @@ const auto options = TensorOptions()
    .pinned_memory(${pin_memory});
""")


def should_generate_python_binding(declaration):
    name = declaration['name']
    for pattern in SKIP_PYTHON_BINDINGS:
......@@ -290,6 +288,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):
        'const Type &': 'scalartype',
        'const THPLayout &': 'layout',
        'const Device &': 'device',
        'c10::optional<DimnameList>': 'toDimnameListOptional',
        'c10::optional<ScalarType>': 'scalartypeOptional',
        'c10::optional<Scalar>': 'scalarOptional',
        'c10::optional<int64_t>': 'toInt64Optional',
......@@ -346,6 +345,19 @@ def create_python_bindings(python_functions, has_self, is_module=False):
    if type_args and len(outputs) > 1:
        raise RuntimeError("Not supported: type dispatched parameter with multiple outputs")

    def unpack_variable(name, unpack_expr, typename):
        # optional<ArrayRef<T>> is special. The PythonArgParser returns an
        # optional<vector<T>>, which cannot be implicitly converted to
        # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
        if typename == 'c10::optional<DimnameList>':
            result = """\
                auto __{name} = {expr};
                c10::optional<{typ}> {name} = __{name} ? c10::make_optional({typ}(__{name}.value())) : c10::nullopt;
            """.format(name=name, expr=unpack_expr, typ='DimnameList')
            return [line.strip() for line in result.split('\n')]

        return ['auto {} = {};'.format(name, unpack_expr)]

    def parse_arg(arg, arg_index, unpack_args=False):
        name = arg['name']
        typename = arg['type']
......@@ -365,7 +377,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):
        expr = 'r.{}({})'.format(unpack, arg_index)
        if unpack_args:
            body.append('auto {} = {};'.format(name, expr))
            body.extend(unpack_variable(name, expr, typename))
            expr = name

        dispatch_type = typename
......@@ -633,6 +645,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):
                'simple_type': 'bool',
            }
            python_binding_arguments.append(requires_grad_arg)

        return python_binding_arguments

    def emit_namedtuple_return_type_def(declaration, next_index):
......
......@@ -170,6 +170,7 @@ def add_torch_libs():
        "torch/csrc/MemoryFormat.cpp",
        "torch/csrc/Module.cpp",
        "torch/csrc/PtrWrapper.cpp",
        "torch/csrc/python_dimname.cpp",
        "torch/csrc/Size.cpp",
        "torch/csrc/Storage.cpp",
        "torch/csrc/TypeInfo.cpp",
......
......@@ -126,6 +126,7 @@ def type_to_python(typename, size=None):
        'void*': '_int',  # data_ptr
        'void': 'None',
        'std::string': 'str',
        'DimnameList': 'List[Union[str, None]]',
    }[typename]

    return typename
......
......@@ -51,6 +51,7 @@ set(TORCH_PYTHON_SRCS
    ${TORCH_SRC_DIR}/csrc/Generator.cpp
    ${TORCH_SRC_DIR}/csrc/Layout.cpp
    ${TORCH_SRC_DIR}/csrc/MemoryFormat.cpp
    ${TORCH_SRC_DIR}/csrc/python_dimname.cpp
    ${TORCH_SRC_DIR}/csrc/Module.cpp
    ${TORCH_SRC_DIR}/csrc/PtrWrapper.cpp
    ${TORCH_SRC_DIR}/csrc/Size.cpp
......
......@@ -311,6 +311,37 @@ PyObject *THPVariable_get_ndim(THPVariable *self)
  END_HANDLE_TH_ERRORS
}

#ifdef NAMEDTENSOR_ENABLED
PyObject *THPVariable_get_names(THPVariable *self)
{
  HANDLE_TH_ERRORS
  // The long-term plan is to return a list of (python) torch.Dimname.
  // However, for now, return a list of string.
  size_t size = self->cdata.dim();
  THPObjectPtr tuple(PyTuple_New(size));
  if (!tuple) throw python_error();

  if (!self->cdata.is_named()) {
    for (size_t i = 0; i < size; ++i) {
      PyTuple_SET_ITEM(tuple.get(), i, Py_None);
    }
    return tuple.release();
  }

  const auto dimnames = self->cdata.names().value();
  for (size_t i = 0; i < size; ++i) {
    PyObject* str = Py_None;
    if (dimnames[i].type() != at::NameType::WILDCARD) {
      str = THPUtils_packString(dimnames[i].name().toUnqualString());
      if (!str) throw python_error();
    }
    PyTuple_SET_ITEM(tuple.get(), i, str);
  }
  return tuple.release();
  END_HANDLE_TH_ERRORS
}
#endif

int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
{
  HANDLE_TH_ERRORS
......@@ -452,6 +483,9 @@ static struct PyGetSetDef THPVariable_properties[] = {
  {"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr},
  {"device", (getter)THPVariable_device, nullptr, nullptr, nullptr},
  {"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr},
#ifdef NAMEDTENSOR_ENABLED
  {"names", (getter)THPVariable_get_names, nullptr, nullptr, nullptr},
#endif
  {nullptr}
};
......
#ifdef NAMEDTENSOR_ENABLED
#include <torch/csrc/python_dimname.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils/python_strings.h>
#include <c10/util/flat_hash_map.h>

namespace torch {

struct InternedStringsTable {
  InternedStringsTable() = default;
  ~InternedStringsTable();
  InternedStringsTable(const InternedStringsTable&) = delete;
  InternedStringsTable& operator=(InternedStringsTable const&) = delete;
  InternedStringsTable(InternedStringsTable&&) = delete;
  InternedStringsTable& operator=(InternedStringsTable&&) = delete;

  at::optional<at::Dimname> lookup(PyObject* obj);
  // Precondition: obj is an interned python string.
  void addMapping(PyObject* obj, at::Dimname dimname);

 private:
  ska::flat_hash_map<PyObject*, at::Dimname> py_interned_string_to_dimname_;
};

InternedStringsTable kPyInternedStringToDimname;

InternedStringsTable::~InternedStringsTable() {
  for (auto it = py_interned_string_to_dimname_.begin();
       it != py_interned_string_to_dimname_.end(); ++it) {
    // See Note [References to python interned strings]
    Py_DECREF(it->first);
  }
}

at::optional<at::Dimname> InternedStringsTable::lookup(PyObject* obj) {
  auto it = py_interned_string_to_dimname_.find(obj);
  if (it == py_interned_string_to_dimname_.end()) {
    return at::nullopt;
  }
  return it->second;
}

void InternedStringsTable::addMapping(PyObject* obj, at::Dimname dimname) {
  // Note [References to python interned strings]
  // If a Python interned string has no references to it, then it gets
  // deallocated, invalidating this mapping. Let's immortalize the string by
  // holding a refcount to it and releasing it in the destructor.
  Py_INCREF(obj);
  py_interned_string_to_dimname_.emplace(obj, dimname);
}

} // namespace torch

at::Dimname THPDimname_parse(PyObject* obj) {
  if (obj == Py_None) {
    return at::Dimname::wildcard();
  }

  if (!THPUtils_checkString(obj)) {
    throw torch::TypeError("expected None or string for Dimname but got %s", Py_TYPE(obj)->tp_name);
  }

  if (!THPUtils_isInterned(obj)) {
    // internStringInPlace decrefs obj and increfs the result. Because we're
    // not actually returning the result to the user, we need to undo these.
    // See https://docs.python.org/3/c-api/unicode.html#c.PyUnicode_InternInPlace
    Py_INCREF(obj);
    THPUtils_internStringInPlace(&obj);
    Py_DECREF(obj);
  }

  auto maybeDimname = torch::kPyInternedStringToDimname.lookup(obj);
  if (maybeDimname) {
    return *maybeDimname;
  }

  const auto name = THPUtils_unpackString(obj);
  auto dimname = at::Dimname::fromSymbol(at::Symbol::dimname(name));
  torch::kPyInternedStringToDimname.addMapping(obj, dimname);
  return dimname;
}
#endif
#pragma once
#ifdef NAMEDTENSOR_ENABLED
#include <torch/csrc/python_headers.h>
#include <ATen/Dimname.h>
at::Dimname THPDimname_parse(PyObject* obj);
#endif
......@@ -32,6 +32,8 @@ static std::unordered_map<std::string, ParameterType> type_map = {
  {"MemoryFormat", ParameterType::MEMORY_FORMAT},
  {"Device", ParameterType::DEVICE},
  {"std::string", ParameterType::STRING},
  {"Dimname", ParameterType::DIMNAME},
  {"DimnameList", ParameterType::DIMNAME_LIST},
};

// Default arg name translations for compatibility with NumPy.
......@@ -157,6 +159,7 @@ bool FunctionParameter::check(PyObject* obj) {
      }
      return false;
    }
    case ParameterType::DIMNAME_LIST:
    case ParameterType::TENSOR_LIST: return six::isTuple(obj) || PyList_Check(obj);
    case ParameterType::INT_LIST: {
      if (PyTuple_Check(obj) || PyList_Check(obj)) {
......@@ -196,6 +199,9 @@ std::string FunctionParameter::type_name() const {
    case ParameterType::MEMORY_FORMAT: return "torch.memory_format";
    case ParameterType::DEVICE: return "torch.device";
    case ParameterType::STRING: return "str";
#ifdef NAMEDTENSOR_ENABLED
    case ParameterType::DIMNAME_LIST: return "tuple of names";
#endif
    default: throw std::runtime_error("unknown parameter type");
  }
}
......
......@@ -51,6 +51,9 @@
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/tracer.h>
#include <torch/csrc/jit/ir.h>
#ifdef NAMEDTENSOR_ENABLED
#include <torch/csrc/python_dimname.h>
#endif
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/numpy_stub.h>
#include <torch/csrc/utils/object_ptr.h>
......@@ -72,7 +75,8 @@ namespace torch {

enum class ParameterType {
  TENSOR, SCALAR, INT64, DOUBLE, TENSOR_LIST, INT_LIST, GENERATOR,
  BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STRING
  BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STRING,
  DIMNAME, DIMNAME_LIST,
};

struct FunctionParameter;
......@@ -136,6 +140,9 @@ struct PythonArgs {
  inline at::Device device(int i);
  inline at::Device deviceWithDefault(int i, const at::Device& default_device);
  inline c10::optional<at::Device> deviceOptional(int i);
#ifdef NAMEDTENSOR_ENABLED
  inline c10::optional<std::vector<at::Dimname>> toDimnameListOptional(int i);
#endif
  inline at::MemoryFormat toMemoryFormat(int i);
  inline std::string string(int i);
  inline PyObject* pyobject(int i);
......@@ -405,6 +412,22 @@ inline c10::optional<at::Device> PythonArgs::deviceOptional(int i) {
  return device(i);
}

#ifdef NAMEDTENSOR_ENABLED
inline c10::optional<std::vector<at::Dimname>> PythonArgs::toDimnameListOptional(int i) {
  if (!args[i]) return c10::nullopt;
  PyObject* arg = args[i];
  auto tuple = PyTuple_Check(arg);
  auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  std::vector<at::Dimname> res;
  res.reserve(size);
  for (int idx = 0; idx < size; idx++) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
    res.push_back(THPDimname_parse(obj));
  }
  return res;
}
#endif

inline at::MemoryFormat PythonArgs::toMemoryFormat(int i) {
  if (!args[i]) return at::MemoryFormat::Any;
  TORCH_CHECK(THPMemoryFormat_Check(args[i]), "memory_format arg must be an instance of the torch.memory_format");
......
......@@ -65,3 +65,21 @@ inline PyObject* THPUtils_internString(const std::string& str) {
  return PyUnicode_InternFromString(str.c_str());
#endif
}

// Precondition: THPUtils_checkString(obj) must be true
inline bool THPUtils_isInterned(PyObject* obj) {
#if PY_MAJOR_VERSION == 2
  return PyString_CHECK_INTERNED(obj);
#else
  return PyUnicode_CHECK_INTERNED(obj);
#endif
}

// Precondition: THPUtils_checkString(obj) must be true
inline void THPUtils_internStringInPlace(PyObject** obj) {
#if PY_MAJOR_VERSION == 2
  PyString_InternInPlace(obj);
#else
  PyUnicode_InternInPlace(obj);
#endif
}