提交 c384fbf4 编写于 作者: Wanchao Liang's avatar Wanchao Liang 提交者: Facebook Github Bot

support torch._C._get_tracing_state in script

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23248

Test Plan: Imported from OSS

Differential Revision: D16466588

Pulled By: wanchaol

fbshipit-source-id: 3c3d5dec2cea2f9cb080eadaef457cc62ac3fbe0
上级 e1f89859
......@@ -7809,6 +7809,17 @@ a")
f = io.BytesIO()
torch.onnx._export(m, (x, seq_lens), f, verbose=False)
def test_script_get_tracing_state(self):
def test_if_tracing(x):
if torch._C._get_tracing_state():
return x + 1
else:
return x - 1
inp = torch.randn(3, 3)
self.checkScript(test_if_tracing, (inp,))
def test_script_outputs(self):
with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
@torch.jit.script
......
......@@ -464,6 +464,13 @@ RegisterOperators reg({
"Tensor? unsorted_indices) -> (Tensor, Tensor, Tensor?, Tensor?)",
[](Stack& stack) { return 0; },
aliasAnalysisFromSchema()),
Operator(
"aten::_get_tracing_state() -> bool",
[](Stack& stack) {
push(stack, false);
return 0;
},
aliasAnalysisFromSchema()),
Operator(
"aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b) -> Tensor(a!)",
[](Stack& stack) {
......
......@@ -2003,6 +2003,7 @@ def _get_builtin_table():
(torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"),
(torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
(torch.nn.utils.rnn.get_packed_sequence, "aten::_pack_sequence"),
(torch._C._get_tracing_state, "aten::_get_tracing_state"),
(warnings.warn, "aten::warn"),
]
......
Markdown 格式
0% or
您添加了 0 到此讨论。请谨慎行事。
先完成此消息的编辑!
想要评论请 注册