提交 52b95fd4 编写于 作者: davidriazati's avatar davidriazati 提交者: Facebook Github Bot

Include recursive class compilations in error call stack (#23454)

Summary:
Previously these were left out which would lead to confusing messages,
now it looks something like:

```
torch.jit.frontend.UnsupportedNodeError: import statements aren't
supported
:
at ../test.py:13:9
    def bad_fn(self):
        import pdb
        ~~~~~~ <--- HERE
'__torch__.X' is being compiled since it was called from 'fn'
at ../test.py:16:12
def fn(x):
    return X(10)
           ~~~~ <--- HERE
```

Fixes #23453
](https://our.intern.facebook.com/intern/diff/16526027/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23454

Pulled By: driazati

Differential Revision: D16526027

fbshipit-source-id: 109f2968430dbf51ee91b1b3409badfd557d19a4
上级 696642ae
......@@ -13491,6 +13491,18 @@ a")
self.assertTrue('forward' in dir(M()))
@unittest.skipIf(PY2, "kwarg expansion requires Python 3")
def test_kwarg_expansion_error(self):
@torch.jit.ignore
def something_else(h, i):
pass
def fn(x):
something_else(**x)
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "keyword-arg expansion is not supported"):
torch.jit.script(fn)
def test_inferred_error_msg(self):
"""
Test that when we get a type mismatch on a function where we inferred
......@@ -13675,6 +13687,22 @@ class TestRecursiveScript(JitTestCase):
t = torch.ones(2, 2)
self.assertEqual(a_script_fn(t, t, t), t + t + t)
def test_error_stack_class(self):
class X(object):
def bad_fn(self):
import pdb # noqa
def fn(x):
return X(10)
try:
torch.jit.script(fn)
except Exception as e:
checker = FileCheck()
checker.check("import statements")
checker.check("is being compiled since it was called from")
checker.run(str(e))
def test_module_basic(self):
class Other(torch.nn.Module):
__constants__ = ['x']
......
......@@ -504,23 +504,34 @@ std::shared_ptr<SugaredValue> toSugaredValue(
py::str qualifiedName =
py::module::import("torch.jit").attr("_qualified_name")(obj);
auto pyCu = get_python_cu();
if (auto classType = pyCu->get_class(c10::QualifiedName(qualifiedName))) {
auto qualname = c10::QualifiedName(qualifiedName);
if (auto classType = pyCu->get_class(qualname)) {
return std::make_shared<ClassValue>(classType);
} else {
// If we can't get the source code for the type, it's implemented in C and
// probably part of the standard library, so give up and leave it as a
// call to Python
bool can_compile_class = py::cast<bool>(
py::module::import("torch._jit_internal").attr("can_compile_class")(obj));
bool can_compile_class =
py::cast<bool>(py::module::import("torch._jit_internal")
.attr("can_compile_class")(obj));
if (can_compile_class) {
// Register class
auto rcb = py::module::import("torch._jit_internal")
.attr("createResolutionCallbackForClassMethods")(obj);
// We're starting a new compilation, so update the error call stack in
// case it fails
ErrorReport::CallStack::push_function(qualname.name());
ErrorReport::CallStack::update_pending_range(loc);
py::module::import("torch.jit")
.attr("_compile_and_register_class")(obj, rcb, qualifiedName);
// Compilation was successful, so pop this entry off the stack
ErrorReport::CallStack::pop_function();
// Return class
auto newClassType = pyCu->get_class(c10::QualifiedName(qualifiedName));
auto newClassType = pyCu->get_class(qualname);
AT_ASSERT(
newClassType,
"Class '",
......
......@@ -1132,6 +1132,7 @@ def script(obj, optimize=True, _frames_up=0, _rcb=None):
"""
if not _enabled:
return obj
if optimize is not None:
warnings.warn("`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead")
......
......@@ -120,7 +120,6 @@ def get_type_line(source):
type_lines = list(filter(lambda line: type_comment in line[1], lines))
lines_with_type = list(filter(lambda line: 'type' in line[1], lines))
if len(type_lines) == 0:
type_pattern = re.compile('#[\t ]*type[\t ]*:')
wrong_type_lines = list(filter(lambda line: type_pattern.search(line[1]), lines))
......
......@@ -111,7 +111,7 @@ class UnsupportedNodeError(NotSupportedError):
offending_node.col_offset + range_len)
feature_name = pretty_node_names.get(node_type, node_type.__name__)
msg = "{} aren't supported".format(feature_name)
super(NotSupportedError, self).__init__(source_range, msg)
super(UnsupportedNodeError, self).__init__(source_range, msg)
class FrontendTypeError(FrontendError):
......@@ -435,6 +435,8 @@ class ExprBuilder(Builder):
for kw in expr.keywords:
kw_expr = build_expr(ctx, kw.value)
# XXX: we could do a better job at figuring out the range for the name here
if not kw.arg:
raise NotSupportedError(kw_expr.range(), 'keyword-arg expansion is not supported')
kwargs.append(Attribute(Ident(kw_expr.range(), kw.arg), kw_expr))
return Apply(func, args, kwargs)
......
Markdown 格式
0% or
您添加了 0 到此讨论。请谨慎行事。
先完成此消息的编辑!
想要评论请 注册