Commit 65a89472 authored by Michael Suo, committed by Facebook Github Bot

Put all modules in the global Python CU

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23154

Test Plan: Imported from OSS

Differential Revision: D16441913

Pulled By: suo

fbshipit-source-id: a79f2c3e06a33cbd79b2e3333f16c069f356f451
Parent e366af7d
@@ -108,7 +108,7 @@ struct TORCH_API CompilationUnit {
       std::shared_ptr<Graph> graph,
       bool shouldMangle = false) {
     if (shouldMangle) {
-      name = c10::QualifiedName(name.prefix(), mangle(name.name()));
+      name = mangle(name);
     }
     auto fn = torch::make_unique<Function>(
         std::move(name), std::move(graph), nullptr);
@@ -210,6 +210,16 @@ struct TORCH_API CompilationUnit {
     classDict_.clear();
   }
 
+  // [name mangling] All code objects must have a unique qualified name in a
+  // CompilationUnit. In Python, sometimes functions won't have unique
+  // qualified names (for example, nested functions). So we mangle Python
+  // functions to ensure that they are uniquely named.
+  //
+  // We also use mangling to distinguish different Module instances. Since each
+  // Module is a singleton class instance, different instances of the same
+  // Python Module will have different types but the same qualified name.
+  c10::QualifiedName mangle(const c10::QualifiedName& name) const;
+
  private:
   std::unique_ptr<Function> define(
       const c10::optional<c10::QualifiedName>& prefix,
@@ -241,12 +251,7 @@ struct TORCH_API CompilationUnit {
   // module's compilation unit.
   std::vector<c10::NamedTypePtr> classes_;
 
-  // [name mangling] All code objects must have a unique qualified name in a
-  // CompilationUnit. In Python, sometimes functions won't have unique
-  // qualified names (for example, nested functions). So we mangle Python
-  // functions to ensure that they are uniquely named.
   mutable size_t mangleIndex_ = 0;
-  std::string mangle(const std::string& name) const;
 };
 } // namespace script
......
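The uniqueness rule in the [name mangling] comment above is the heart of this change. Here is a minimal, self-contained Python sketch of that rule; the `ToyCompilationUnit` class is hypothetical and exists only to illustrate the naming scheme, the real logic is the C++ in this diff:

    class ToyCompilationUnit:
        """Toy model of a CompilationUnit: one code object per qualified name."""

        def __init__(self):
            self.classes = {}
            self.mangle_index = 0  # mirrors mangleIndex_

        def register_class(self, qualname):
            if qualname in self.classes:
                # Same Python Module class instantiated again: mangle the
                # qualified name to keep it unique, as described above.
                prefix, basename = qualname.rsplit(".", 1)
                qualname = "{}.___torch_mangle_{}.{}".format(
                    prefix, self.mangle_index, basename)
                self.mangle_index += 1
            self.classes[qualname] = object()
            return qualname

    cu = ToyCompilationUnit()
    cu.register_class("__main__.MyModule")  # '__main__.MyModule'
    cu.register_class("__main__.MyModule")  # '__main__.___torch_mangle_0.MyModule'
    cu.register_class("__main__.MyModule")  # '__main__.___torch_mangle_1.MyModule'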
@@ -3069,22 +3069,29 @@ CompilationUnit::CompilationUnit(const std::string& source)
   define(c10::nullopt, source, nativeResolver(), nullptr);
 }
 
 // Mangle a qualified name so that it is globally unique.
-std::string CompilationUnit::mangle(const std::string& name) const {
+c10::QualifiedName CompilationUnit::mangle(const c10::QualifiedName& name) const {
   static const std::string manglePrefix = "___torch_mangle_";
-  std::string mangledName;
-  auto pos = name.find(manglePrefix);
-  if (pos != std::string::npos) {
-    // If the name is already mangled, avoid re-appending the prefix.
-    mangledName.reserve(name.size());
-    // Append the part of the name up to the end of the prefix
-    mangledName.append(name, 0, pos + manglePrefix.size());
-    mangledName.append(std::to_string(mangleIndex_++));
-  } else {
-    mangledName = c10::str(name, manglePrefix, std::to_string(mangleIndex_++));
-  }
-  return mangledName;
+  std::vector<std::string> atoms = name.atoms();
+
+  // Search for an already-existing mangle namespace.
+  // If the name is already mangled, just bump the integer.
+  for (auto& atom : atoms) {
+    auto pos = atom.find(manglePrefix);
+    if (pos != std::string::npos) {
+      std::string newAtom;
+      newAtom.reserve(atom.size());
+      // Append the part of the name up to the end of the prefix
+      newAtom.append(atom, 0, pos + manglePrefix.size());
+      newAtom.append(std::to_string(mangleIndex_++));
+      atom = newAtom;
+      return QualifiedName(atoms);
+    }
+  }
+
+  // Otherwise add a mangle namespace right before the basename
+  TORCH_INTERNAL_ASSERT(!atoms.empty());
+  atoms.insert(atoms.end() - 1, manglePrefix + std::to_string(mangleIndex_++));
+  return QualifiedName(atoms);
 }
 
 std::unique_ptr<Function> CompilationUnit::define(
@@ -3122,8 +3129,7 @@ std::unique_ptr<Function> CompilationUnit::define(
     // If `shouldMangle` is set, we should generate a unique name for this
     // function if there is already an existing one.
     if (auto fn = find_function(name)) {
-      auto newBase = mangle(name.name());
-      name = QualifiedName(name.prefix(), newBase);
+      name = mangle(name);
     }
   }
   auto fn = torch::make_unique<Function>(
......
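For readers skimming the C++ above, here is a line-for-line Python sketch of the new mangle algorithm (assumed equivalent behavior, not part of the patch): bump the integer if a mangle namespace already exists anywhere in the name, otherwise insert one right before the basename.

    MANGLE_PREFIX = "___torch_mangle_"
    _mangle_index = 0  # mirrors the CompilationUnit's mutable mangleIndex_

    def mangle(qualname):
        global _mangle_index
        atoms = qualname.split(".")
        # If some atom is already a mangle namespace, just bump the integer.
        for i, atom in enumerate(atoms):
            pos = atom.find(MANGLE_PREFIX)
            if pos != -1:
                # Keep everything up to the end of the prefix, swap in a new index.
                atoms[i] = atom[:pos + len(MANGLE_PREFIX)] + str(_mangle_index)
                _mangle_index += 1
                return ".".join(atoms)
        # Otherwise insert a mangle namespace right before the basename.
        assert atoms
        atoms.insert(len(atoms) - 1, MANGLE_PREFIX + str(_mangle_index))
        _mangle_index += 1
        return ".".join(atoms)

    mangle("__main__.MyModule")                    # '__main__.___torch_mangle_0.MyModule'
    mangle("__main__.___torch_mangle_0.MyModule")  # '__main__.___torch_mangle_1.MyModule'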
@@ -347,7 +347,7 @@ void initJitScriptBindings(PyObject* module) {
   // Methods here are prefixed with _ since they should not be
   // public.
   py::class_<Module>(m, "ScriptModule")
-      .def(py::init<std::string, std::shared_ptr<CompilationUnit>>())
+      .def(py::init<std::string, std::shared_ptr<CompilationUnit>, bool>())
       .def(
           "save",
           [](Module& m,
@@ -390,7 +390,7 @@ void initJitScriptBindings(PyObject* module) {
             for (auto& callback : rcbs) {
              resolvers.push_back(pythonResolver(callback));
             }
-            const auto prefix = QualifiedName(m.name());
+            const auto& prefix = m.name();
             const auto self = ModuleSelf(m, py_m);
             m.class_compilation_unit()->define(prefix, defs, resolvers, &self);
             // Stitch in default arguments for each Def if provided
@@ -418,11 +418,6 @@ void initJitScriptBindings(PyObject* module) {
           },
           py::keep_alive<0, 1>())
       .def("_register_parameter", &Module::register_parameter)
-      .def(
-          "_get_functions",
-          [](Module& self) {
-            return self.class_compilation_unit()->get_functions();
-          })
       .def(
           "_register_attribute",
           [](Module& self, std::string name, TypePtr type, py::object value) {
......
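The extra `bool` in the `ScriptModule` binding above is what lets the Python layer request mangling when it builds modules in the shared CU. A sketch of what the new constructor permits; note that `torch._C.ScriptModule` is a private API, shown only to illustrate the change:

    import torch

    cu = torch._C.CompilationUnit()
    # Two modules with the same qualified name can now live in one
    # CompilationUnit: passing True asks the C++ side to mangle the
    # second class name instead of colliding with the first.
    a = torch._C.ScriptModule("__main__.MyModule", cu, True)
    b = torch._C.ScriptModule("__main__.MyModule", cu, True)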
@@ -14,20 +14,30 @@ namespace script {
 static ModulePtr create_module_object(
     c10::QualifiedName class_name,
-    std::shared_ptr<CompilationUnit> cu) {
+    std::shared_ptr<CompilationUnit> cu,
+    bool shouldMangle = false) {
+  if (shouldMangle && cu->get_class(class_name) != nullptr) {
+    class_name = cu->mangle(class_name);
+  }
   auto cls = ClassType::create(std::move(class_name), cu, /*is_module=*/true);
   cu->register_class(cls);
   return c10::ivalue::Object::create(
       c10::StrongTypePtr(std::move(cu), std::move(cls)), 0);
 }
 
 Module::Module(c10::QualifiedName class_name)
-    : module_value_(create_module_object(std::move(class_name), std::make_shared<CompilationUnit>())) {}
+    : module_value_(create_module_object(
+          std::move(class_name),
+          std::make_shared<CompilationUnit>())) {}
 
 Module::Module(
     c10::QualifiedName class_name,
-    std::shared_ptr<CompilationUnit> cu)
-    : module_value_(
-          create_module_object(std::move(class_name), std::move(cu))) {}
+    std::shared_ptr<CompilationUnit> cu,
+    bool shouldMangle)
+    : module_value_(create_module_object(
+          std::move(class_name),
+          std::move(cu),
+          shouldMangle)) {}
 
 ModulePtr Module::module_object() const {
   if (!module_value_) {
@@ -257,14 +267,14 @@ void Module::copy_into(
     }
   }
 
-  for (auto& fn : class_compilation_unit()->get_functions()) {
-    curr.clone_method(*this, fn->qualname(), type_remap);
+  for (auto& fn : type()->methods()) {
+    curr.clone_method(*this, *fn, type_remap);
   }
 }
 
 void Module::clone_method(
     const Module& orig,
-    const QualifiedName& orig_method_name,
+    const Function& method,
     const std::unordered_map<TypePtr, TypePtr>& type_remap) {
   // type remapping - when we copy method implementations from one module
   // singleton to another, we need to update the types of the self arguments
@@ -282,14 +292,13 @@ void Module::clone_method(
       return in;
     return it->second;
   };
-  const Function& fn =
-      orig.class_compilation_unit()->get_function(orig_method_name);
-  auto graph = fn.graph()->copy();
+  auto graph = method.graph()->copy();
   graph->remapTypes(type_remap_fn);
-  auto schema = fn.getSchema().cloneWithRemappedTypes(type_remap_fn);
-  const auto this_method_name = getNameForMethod(orig_method_name.name());
+  auto schema = method.getSchema().cloneWithRemappedTypes(type_remap_fn);
+  const auto this_method_name = getNameForMethod(method.name());
   auto copied =
       class_compilation_unit()->create_function(this_method_name, graph);
+  type()->addMethod(copied);
   copied->setSchema(std::move(schema));
 }
 
@@ -305,8 +314,7 @@ void Module::clone_method(const Module& orig, const std::string& name) {
       to_scan.emplace_back(s.to_module(), entry.second.get_module(s.name()));
     }
   }
-  const auto orig_method_name = QualifiedName(orig.name(), name);
-  return clone_method(orig, orig_method_name, type_remap);
+  return clone_method(orig, orig.get_method(name).function(), type_remap);
 }
 
 void Module::train(bool on) {
......
@@ -109,7 +109,10 @@ struct TORCH_API Method {
 struct TORCH_API Module {
   explicit Module(c10::QualifiedName class_name);
-  Module(c10::QualifiedName, std::shared_ptr<CompilationUnit> cu);
+  Module(
+      c10::QualifiedName,
+      std::shared_ptr<CompilationUnit> cu,
+      bool shouldMangle = false);
   // module_value_ is null and will be lazily initialized if needed
   Module() {}
   Module(ModulePtr module_value) : module_value_(std::move(module_value)) {}
@@ -204,7 +207,7 @@ struct TORCH_API Module {
   const std::vector<Method> get_methods() const {
     return fmap(
-        class_compilation_unit()->get_functions(),
+        type()->methods(),
         [&](Function* func) {
           return Method(module_object(), func);
         });
@@ -230,9 +233,11 @@ struct TORCH_API Module {
     return c10::nullopt;
   }
 
   c10::optional<Method> find_method(const std::string& basename) const {
-    if (const auto fn = class_compilation_unit()->find_function(
-            getNameForMethod(basename))) {
-      return Method(module_object(), fn);
+    for (Function* fn : type()->methods()) {
+      if (fn->name() == basename) {
+        return Method(module_object(), fn);
+      }
     }
     return c10::nullopt;
   }
@@ -314,7 +319,7 @@ struct TORCH_API Module {
   void clone_method(const Module& orig, const std::string& name);
 
   at::optional<EntityType> kind_of(const std::string& name) const {
-    if (class_compilation_unit()->find_function(getNameForMethod(name))) {
+    if (find_method(name)) {
       return EntityType::METHOD;
     }
     if (auto offset = type()->findAttributeSlot(name)) {
@@ -355,7 +360,7 @@ struct TORCH_API Module {
  private:
   void clone_method(
       const Module& orig,
-      const QualifiedName& orig_method_name,
+      const Function& method,
       const std::unordered_map<TypePtr, TypePtr>& type_remap);
 
   c10::QualifiedName getNameForMethod(std::string basename) const {
......
@@ -506,6 +506,7 @@ def _check_trace(check_inputs, func, traced_func, check_tolerance,
             check_trace=False,
             _force_outplace=force_outplace,
             _module_class=_module_class,
+            _compilation_unit=torch._C.CompilationUnit(),
         )
         check_mod_func = check_mod._c._get_method(traced_func.name)
         inputs = inputs[traced_func.name]
@@ -654,10 +655,10 @@ def make_tuple(example_inputs):
     return example_inputs
 
-def make_module(mod, _module_class):
+def make_module(mod, _module_class, _compilation_unit):
     if _module_class is None:
         _module_class = TopLevelTracedModule
-    return _module_class(mod)
+    return _module_class(mod, _compilation_unit=_compilation_unit)
 
 def wrap_check_inputs(check_inputs):
     if check_inputs is None:
@@ -672,7 +673,8 @@ def trace(func,
           check_inputs=None,
           check_tolerance=1e-5,
           _force_outplace=False,
-          _module_class=None):
+          _module_class=None,
+          _compilation_unit=_python_cu):
     """
    Trace a function and return an executable ``ScriptModule`` or ``torch.jit._C.Function``
    that will be optimized using just-in-time compilation.
@@ -805,7 +807,8 @@ def trace_module(mod,
                  check_inputs=None,
                  check_tolerance=1e-5,
                  _force_outplace=False,
-                 _module_class=None):
+                 _module_class=None,
+                 _compilation_unit=_python_cu):
     """
    Trace a module and return an executable ``ScriptModule`` that will be optimized
    using just-in-time compilation.
@@ -871,7 +874,6 @@ def trace_module(mod,
        module = torch.jit.trace_module(n, inputs)
 
     """
-
     if not _enabled:
         return mod
     if optimize is not None:
@@ -885,7 +887,7 @@ def trace_module(mod,
     if not isinstance(inputs, dict):
         raise AttributeError("expected a dictionary of (method_name, input) pairs")
 
-    module = make_module(mod, _module_class)
+    module = make_module(mod, _module_class, _compilation_unit)
 
     for method_name, example_inputs in inputs.items():
         # this is needed since Module.__call__ sets up some extra tracing
@@ -1471,7 +1473,7 @@ if _enabled:
             if _qualified_name is None:
                 _qualified_name = type(self).__name__
             if _compilation_unit is None:
-                _compilation_unit = torch._C.CompilationUnit()
+                _compilation_unit = _python_cu
             if optimize is not None:
                 warnings.warn("`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution()` instead")
@@ -1480,7 +1482,7 @@ if _enabled:
             if _cpp_module is not None:
                 self.__dict__['_c'] = _cpp_module
             else:
-                self.__dict__['_c'] = torch._C.ScriptModule(_qualified_name, _compilation_unit)
+                self.__dict__['_c'] = torch._C.ScriptModule(_qualified_name, _compilation_unit, True)
             Module.__init__(self)
             self._parameters = OrderedParameterDict(self._c)
@@ -1801,9 +1803,10 @@ for name, method in _get_methods(torch.nn.Module):
 class TracedModule(ScriptModule):
     __frozen = False
 
-    def __init__(self, orig, id_set=None):
+    def __init__(self, orig, id_set=None, _compilation_unit=None):
         # XXX: orig can be a nn.Module or a function!
-        super(TracedModule, self).__init__()
+        super(TracedModule, self).__init__(_qualified_name=_jit_internal._qualified_name(orig.__class__),
+                                           _compilation_unit=_compilation_unit)
         if id_set is None:
             id_set = set()
......
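Taken together, the Python changes mean every traced or scripted module now defaults into the single global `_python_cu`, with mangling keeping repeated classes apart, while `_check_trace` deliberately isolates its throwaway re-trace in a fresh `CompilationUnit`. A small usage sketch of the expected behavior on a build that includes this patch:

    import torch
    import torch.nn as nn

    class M(nn.Module):
        def forward(self, x):
            return x + 1

    x = torch.rand(3)
    # Both traced modules land in the shared Python CompilationUnit;
    # the second class name is mangled so the two types stay distinct.
    m1 = torch.jit.trace(M(), x)
    m2 = torch.jit.trace(M(), x)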