Commit 696642ae authored by davidriazati, committed by Facebook Github Bot

Change docs to use recursive script API (#21612)

Summary:
Use the recursive script API in the existing docs

TODO:
* Migration guide for 1.1 -> 1.2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21612

Pulled By: driazati

Differential Revision: D16553734

fbshipit-source-id: fb6be81a950224390bd5d19b9b3de2d97b3dc515
Parent bfee46f8
@@ -7,18 +7,21 @@ TorchScript
.. currentmodule:: torch.jit
TorchScript is a way to create serializable and optimizable models from PyTorch code.
Any code written in TorchScript can be saved from a Python
Any TorchScript program can be saved from a Python
process and loaded in a process where there is no Python dependency.
We provide tools to incrementally transition a model from a pure Python program
to a TorchScript program that can be run independently from Python, for instance, in a standalone C++ program.
This makes it possible to train models in PyTorch using familiar tools and then export
the model via TorchScript to a production environment where it is not a good idea to run models as Python programs
to a TorchScript program that can be run independently from Python, such as in a standalone C++ program.
This makes it possible to train models in PyTorch using familiar tools in Python and then export
the model via TorchScript to a production environment where Python programs may be disadvantageous
for performance and multi-threading reasons.
Creating TorchScript Code
--------------------------
.. autofunction:: script
.. autofunction:: trace
.. autoclass:: ScriptModule
:members:
@@ -27,14 +30,13 @@ Creating TorchScript Code
.. autofunction:: load
.. autofunction:: trace
Mixing Tracing and Scripting
----------------------------
In many cases either tracing or scripting is an easier approach for converting a model to TorchScript.
We allow you to compose tracing and scripting to suit the particular requirements
Tracing and scripting can be composed to suit the particular requirements
of a part of a model.
Scripted functions can call traced functions. This is particularly useful when you need
@@ -77,7 +79,7 @@ Example::
traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))
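The surrounding example is elided in this hunk; a minimal sketch of the pattern it describes might look like the following. The bodies of ``bar`` and ``foo`` here are hypothetical stand-ins, not the original example's code:

```python
import torch

# A plain function converted via tracing (hypothetical stand-in for `bar`).
def bar(x, y, z):
    return x + y * z

traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

# A scripted function with data-dependent control flow can call the traced one.
@torch.jit.script
def foo(x, y, z):
    if x.sum() > 0:
        return traced_bar(x, y, z)
    else:
        return traced_bar(z, y, x)
```

Tracing captures one fixed execution path, so putting the control flow in a scripted caller keeps the branch data-dependent at runtime.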
This composition also works for ``ScriptModule``\s as well, where it can be used to generate
This composition also works for ``nn.Module``\s, where it can be used to generate
a submodule using tracing that can be called from the methods of a script module:
Example::
@@ -85,7 +87,7 @@ Example::
import torch
import torchvision
class MyScriptModule(torch.jit.ScriptModule):
class MyScriptModule(torch.nn.Module):
def __init__(self):
super(MyScriptModule, self).__init__()
self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
@@ -93,10 +95,11 @@ Example::
self.resnet = torch.jit.trace(torchvision.models.resnet18(),
torch.rand(1, 3, 224, 224))
@torch.jit.script_method
def forward(self, input):
return self.resnet(input - self.means)
my_script_module = torch.jit.script(MyScriptModule())
TorchScript Language Reference
-------------------------------
......
@@ -171,14 +171,14 @@ def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP):
Save an offline version of this module for use in a separate process. The saved
module serializes all of the methods, submodules, parameters, and attributes of this
module. It can be loaded into the C++ API using ``torch::jit::load(filename)`` or into the Python
API with ``torch.jit.load(filename)``.
API with :func:`load <torch.jit.load>`.
To be able to save a module, it must not make any calls to native Python functions.
This means that all submodules must be subclasses of ``torch.jit.ScriptModule`` as well.
.. DANGER::
All modules, no matter their device, are always loaded onto the CPU during loading.
This is different from :func:`torch.load`'s semantics and may change in the future.
This is different from :func:`load <torch.jit.load>`'s semantics and may change in the future.
Arguments:
m: a ScriptModule to save
@@ -195,7 +195,15 @@ def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP):
Example: ::
m = torch.jit.ScriptModule()
import torch
import io
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
m = torch.jit.script(MyModule())
# Save to file
torch.jit.save(m, 'scriptmodule.pt')
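Because ``torch.jit.save`` accepts a file-like object as well as a path, a full save/load round trip can be sketched in memory; this is a minimal sketch reusing the ``MyModule`` definition from the example above:

```python
import io
import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

m = torch.jit.script(MyModule())

# Save to an in-memory buffer instead of a file, then load it back.
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)
```

The loaded module behaves like the original scripted module, with no dependency on the Python source of ``MyModule``.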
@@ -1051,7 +1059,77 @@ def _compile_and_register_class(obj, rcb, qualified_name):
_add_script_class(obj, qualified_name)
def script(obj, optimize=None, _frames_up=0, _rcb=None):
def script(obj, optimize=True, _frames_up=0, _rcb=None):
r"""
Scripting a function or ``nn.Module`` will inspect the source code, compile
it as TorchScript code using the TorchScript compiler, and return a ``ScriptModule`` or
``torch._C.Function``.
**Scripting a function**
The ``@torch.jit.script`` decorator will construct a ``torch._C.Function``.
Example (scripting a function)::
import torch
@torch.jit.script
def foo(x, y):
if x.max() > y.max():
r = x
else:
r = y
return r
**Scripting an nn.Module**
Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively
compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses
features supported in TorchScript, no changes to the original module code should be necessary.
Example (scripting a simple module with a Parameter)::
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
# This parameter will be copied to the new ScriptModule
self.weight = torch.nn.Parameter(torch.rand(N, M))
# When this submodule is used, it will be compiled
self.linear = torch.nn.Linear(N, M)
def forward(self, input):
output = self.weight.mv(input)
# This calls the `forward` method of the `nn.Linear` module, which will
# cause the `self.linear` submodule to be compiled to a `ScriptModule` here
output = self.linear(output)
return output
scripted_module = torch.jit.script(MyModule())
Example (scripting a module with traced submodules)::
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
# torch.jit.trace produces ScriptModules for conv1 and conv2
self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
def forward(self, input):
input = F.relu(self.conv1(input))
input = F.relu(self.conv2(input))
return input
scripted_module = torch.jit.script(MyModule())
To compile a method other than ``forward`` (and recursively compile anything it calls), add
the ``@torch.jit.export`` decorator to the method.
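A minimal sketch of ``@torch.jit.export``; the module and method names here are hypothetical:

```python
import torch

class MyModule(torch.nn.Module):
    @torch.jit.export
    def some_entry_point(self, x):
        # Compiled even though it is never called from `forward`.
        return x + 1

    def forward(self, x):
        return x * 2

scripted = torch.jit.script(MyModule())
```

Both ``forward`` and ``some_entry_point`` are callable on the resulting ``ScriptModule``.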
"""
if not _enabled:
return obj
if optimize is not None:
@@ -1414,60 +1492,15 @@ if _enabled:
**Scripting:**
You can write TorchScript code directly using Python syntax. You do this
using the ``@torch.jit.script`` decorator (for functions) or
``@torch.jit.script_method`` decorator (for methods) on subclasses of
``ScriptModule``. With this decorator the body of the annotated function is
directly translated into TorchScript. TorchScript itself is a subset of
the Python language, so not all features in Python work, but we provide
enough functionality to compute on tensors and do control-dependent
operations.
Example (scripting a function)::
import torch
@torch.jit.script
def foo(x, y):
if x.max() > y.max():
r = x
else:
r = y
return r
.. note::
A ``@torch.jit.script`` decorator will construct a ``ScriptModule`` with a single
``forward`` method that implements the function. The resulting
``ScriptModule`` has no parameters or attributes.
Example (scripting a simple module with a Parameter)::
import torch
class MyModule(torch.jit.ScriptModule):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
@torch.jit.script_method
def forward(self, input):
return self.weight.mv(input)
Example (scripting a module with traced submodules)::
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyScriptModule(torch.jit.ScriptModule):
def __init__(self):
super(MyScriptModule, self).__init__()
# torch.jit.trace produces ScriptModules for conv1 and conv2
self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
@torch.jit.script_method
def forward(self, input):
input = F.relu(self.conv1(input))
input = F.relu(self.conv2(input))
return input
using the ``@torch.jit.script`` decorator for functions and modules. You can
also call ``torch.jit.script`` directly with the function or module you wish to
compile. On functions, the body of the function is compiled to TorchScript. If
applied to an ``nn.Module``, by default the ``forward`` method and any methods it
calls are compiled, and all buffers and Parameters of the original module are copied
to a new ``ScriptModule``. You should not need to construct a ``ScriptModule`` manually.
TorchScript itself is a subset of the Python language, so not all
features in Python work, but we provide enough functionality to compute on
tensors and do control-dependent operations.
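Calling ``torch.jit.script`` directly, as described above, might look like the following minimal sketch (``fn`` is a hypothetical function with data-dependent control flow):

```python
import torch

def fn(x, y):
    # Data-dependent branch: tracing could not capture both paths,
    # but scripting compiles the control flow itself.
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

# Direct call instead of using the @torch.jit.script decorator.
scripted_fn = torch.jit.script(fn)
```

The result is a compiled function that preserves the branch, whichever inputs it later receives.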
"""
def __init__(self, optimize=None, _qualified_name=None, _compilation_unit=None, _cpp_module=None):
if _qualified_name is None:
......