提交 7c404fa5 编写于 作者: Michael Suo's avatar Michael Suo

[jit] don't try to set training after ScriptModule has been initialized.

Now when initializing a ScriptModule during the torch.jit.load()
process, there is already a cpp module backing the thing. That means
that setting training will overwrite whatever the initialized
ScriptModule had.

This PR splits apart the common "set up internal state" part of the
Module __init__ and calls that from ScriptModule.__init__ and
Module.__init__, leaving the "nn.Module-specific" part (setting
`self.training`) for the nn.Module __init__

ghstack-source-id: 9b2ba8a15c43cf230363e4cd10ba4ad3ac4931f7
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23680
上级 7a7cfcbf
......@@ -1517,7 +1517,9 @@ if _enabled:
else:
self.__dict__['_c'] = torch._C.ScriptModule(_qualified_name, _compilation_unit, True)
Module.__init__(self)
Module._construct(self)
Module.__setattr__(self, "training", True)
self._parameters = OrderedParameterDict(self._c)
self._buffers = OrderedBufferDict(self._c)
self._modules = OrderedModuleDict(self._c)
......@@ -1564,7 +1566,7 @@ if _enabled:
# to improve invocation performance
self.__dict__[attr] = script_method
return script_method
return Module.__getattr__(self, attr)
return super(ScriptModule, self).__getattr__(attr)
def __setattr__(self, attr, value):
if attr not in self._constants_set:
......
......@@ -69,6 +69,15 @@ class Module(object):
_version = 1
def __init__(self):
self._construct()
# initialize self.training separately from the rest of the internal
# state, as it is managed differently by nn.Module and ScriptModule
self.training = True
def _construct(self):
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
torch._C._log_api_usage_once("python.nn_module")
self._backend = thnn_backend
self._parameters = OrderedDict()
......@@ -79,7 +88,6 @@ class Module(object):
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True
def forward(self, *input):
r"""Defines the computation performed at every call.
......
Markdown 格式
0% or
您添加了 0 到此讨论。请谨慎行事。
先完成此消息的编辑!
想要评论请 注册