提交 6cf9ed4a 编写于 作者: Jerry Zhang's avatar Jerry Zhang 提交者: Facebook Github Bot

ConvBn2d/ConvBnReLU2d (#23357)

Summary:
Added _intrinsic.qat.ConvBn2d/_intrinsic.qat.ConvBnReLU2d.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/23357
ghstack-source-id: 87519573

Differential Revision: D16295500

fbshipit-source-id: 81e6d1d10d05bf6e343721fc5701d3d6bd7e07e6
上级 029c8e77
......@@ -228,7 +228,7 @@ class ManualConvLinearQATModel(torch.nn.Module):
class SubModForFusion(torch.nn.Module):
def __init__(self):
super(SubModForFusion, self).__init__()
self.conv = torch.nn.Conv2d(20, 20, 1)
self.conv = torch.nn.Conv2d(20, 20, 1, bias=None)
self.bn = torch.nn.BatchNorm2d(20)
def forward(self, x):
......@@ -239,9 +239,9 @@ class SubModForFusion(torch.nn.Module):
class ModForFusion(torch.nn.Module):
def __init__(self):
super(ModForFusion, self).__init__()
self.conv1 = torch.nn.Conv2d(10, 20, 5)
self.conv1 = torch.nn.Conv2d(10, 20, 5, bias=None)
self.bn1 = torch.nn.BatchNorm2d(20)
self.relu1 = torch.nn.ReLU()
self.relu1 = torch.nn.ReLU(inplace=False)
self.sub1 = SubModForFusion()
self.sub2 = SubModForFusion()
......
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
from torch.nn import Conv2d, BatchNorm2d, ReLU
from torch.nn._intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.quantization.QConfig import default_qat_qconfig
from torch.nn import Parameter
from common_utils import TestCase, run_tests
from hypothesis import given
from hypothesis import strategies as st
from functools import reduce
class IntrinsicQATModuleTest(TestCase):
@given(batch_size=st.integers(1, 3),
input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
height=st.integers(10, 16),
width=st.integers(7, 14),
output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 7),
kernel_w=st.integers(1, 7),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(['zeros', 'circular']),
use_relu=st.booleans(),
eps=st.sampled_from([1e-5, 1e-4, 1e-3, 0.01, 0.1]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans())
def test_conv_bn_relu(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
use_relu,
eps,
momentum,
freeze_bn
):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
dilation_h = dilation_w = dilation
conv_op = Conv2d(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
False, # No bias
padding_mode
).to(dtype=torch.float)
bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.float)
relu_op = ReLU()
cls = ConvBnReLU2d if use_relu else ConvBn2d
qat_op = cls(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
padding_mode,
eps,
momentum,
freeze_bn,
default_qat_qconfig.activation,
default_qat_qconfig.weight
).to(dtype=torch.float).disable_fake_quant()
# align inputs and internal parameters
input = torch.randn(batch_size, input_channels, height, width, dtype=torch.float)
input.requires_grad_()
conv_op.weight = Parameter(qat_op.weight)
bn_op.running_mean = qat_op.running_mean
bn_op.running_var = qat_op.running_var
bn_op.weight = qat_op.gamma
bn_op.bias = qat_op.beta
def compose(functions):
# functions are reversed for natural reading order
return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)
if not use_relu:
def relu_op(x):
return x
if freeze_bn:
def ref_op(x):
x = conv_op(x)
x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
(bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \
.reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
x = relu_op(x)
return x
else:
ref_op = compose([conv_op, bn_op, relu_op])
result_ref = ref_op(input)
result_actual = qat_op(input)
self.assertEqual(result_ref, result_actual)
# backward
dout = torch.randn(result_ref.size(), dtype=torch.float)
result_actual.backward(dout, retain_graph=True)
grad_ref = input.grad.cpu()
result_actual.backward(dout)
grad_actual = input.grad.cpu()
self.assertEqual(grad_ref, grad_actual)
if __name__ == '__main__':
run_tests()
from __future__ import absolute_import, division, print_function, unicode_literals
from .linear_relu import LinearReLU
from .conv_relu import ConvReLU2d
from .conv_fused import ConvBn2d, ConvBnReLU2d, ConvReLU2d
__all__ = [
'LinearReLU',
'ConvReLU2d',
'ConvBn2d',
'ConvBnReLU2d'
]
此差异已折叠。
from __future__ import absolute_import, division, print_function, unicode_literals
from torch.nn.qat import Conv2d as QATConv2d
from torch.nn._intrinsic import ConvReLU2d as NNConvReLU2d
from torch.quantization.QConfig import default_qat_qconfig
import torch.nn.functional as F
class ConvReLU2d(QATConv2d):
r"""
A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
FakeQuantize modules for both output activation and weight for
quantization aware training.
We adopt the same interface as :class:`~torch.nn.Conv2d`.
Similar to :class:`~torch.nn.Conv2d`, with FakeQuantize modules initialized to
default.
Attributes:
observer: fake quant module for output activation, it's called observer
to align with post training flow
weight_fake_quant: fake quant module for weight
"""
__FLOAT_MODULE__ = NNConvReLU2d
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros',
activation_fake_quant=default_qat_qconfig.activation,
weight_fake_quant=default_qat_qconfig.weight):
super(ConvReLU2d, self).__init__(in_channels, out_channels, kernel_size,
stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias, padding_mode=padding_mode)
self.observer = activation_fake_quant()
self.weight_fake_quant = weight_fake_quant()
def forward(self, input):
return self.observer(F.relu(conv2d_forward(input, self.padding_mode,
self.padding, self.weight_fake_quant(self.weight),
self.bias, self.stride, self.dilation, self.groups),
True))
from __future__ import absolute_import, division, print_function, unicode_literals
from torch.nn.qat import Linear as QATLinear
from torch.nn._intrinsic import LinearReLU2d as NNLinearReLU2d
from torch.nn._intrinsic import LinearReLU as NNLinearReLU
from torch.quantization.QConfig import default_qat_qconfig
import torch.nn.functional as F
......@@ -28,7 +28,7 @@ class LinearReLU(QATLinear):
>>> print(output.size())
torch.Size([128, 30])
"""
__FLOAT_MODULE__ = NNLinearReLU2d
__FLOAT_MODULE = NNLinearReLU
def __init__(self, in_features, out_features, bias=True,
activation_fake_quant=default_qat_qconfig.activation,
......
from __future__ import absolute_import, division, print_function, unicode_literals
from torch.nn import Conv2d as NNConv2d
from torch.quantization.QConfig import default_qat_qconfig
class Conv2d(NNConv2d):
r"""
......@@ -20,13 +19,13 @@ class Conv2d(NNConv2d):
weight_fake_quant: fake quant module for weight
"""
__FLOAT_MODULE__ = NNConv2d
__FLOAT_MODULE = NNConv2d
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros',
activation_fake_quant=default_qat_qconfig.activation,
weight_fake_quant=default_qat_qconfig.weight):
activation_fake_quant=None,
weight_fake_quant=None):
super(Conv2d, self).__init__(in_channels, out_channels, kernel_size,
stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias, padding_mode=padding_mode)
......@@ -44,8 +43,8 @@ class Conv2d(NNConv2d):
Args: `mod` a float module, either produced by torch.quantization utilities
or directly from user
"""
assert type(mod) == cls.__FLOAT_MODULE__, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
cls.__FLOAT_MODULE__.__name__
assert type(mod) == cls.__FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
cls.__FLOAT_MODULE.__name__
if not qconfig:
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
......
from __future__ import absolute_import, division, print_function, unicode_literals
from ...modules.linear import Linear as NNLinear
from torch.quantization.QConfig import default_qat_qconfig
import torch.nn.functional as F
class Linear(NNLinear):
......@@ -20,12 +19,11 @@ class Linear(NNLinear):
to align with post training flow
weight: fake quant module for weight
"""
__constants__ = ['bias', 'in_features', 'out_features']
__FLOAT_MODULE__ = NNLinear
__FLOAT_MODULE = NNLinear
def __init__(self, in_features, out_features, bias=True,
activation_fake_quant=default_qat_qconfig.activation,
weight_fake_quant=default_qat_qconfig.weight):
activation_fake_quant=None,
weight_fake_quant=None):
super(Linear, self).__init__(in_features, out_features, bias)
self.observer = activation_fake_quant()
self.weight_fake_quant = weight_fake_quant()
......@@ -40,8 +38,8 @@ class Linear(NNLinear):
Args: `mod` a float module, either produced by torch.quantization utilities
or directly from user
"""
assert type(mod) == cls.__FLOAT_MODULE__, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
cls.__FLOAT_MODULE__.__name__
assert type(mod) == cls.__FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
cls.__FLOAT_MODULE.__name__
if not qconfig:
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must has valid qconfig'
......
......@@ -21,6 +21,10 @@ def fuse_conv_bn(conv, bn):
"Conv and BN both must be in the same mode (train or eval)."
if conv.training:
assert conv.bias is None, 'Only support fusing Conv2d that does not have bias'
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
return torch.nn._intrinsic.ConvBn2d(conv, bn)
else:
return torch.nn.utils.fuse_conv_bn_eval(conv, bn)
......@@ -42,6 +46,7 @@ def fuse_conv_bn_relu(conv, bn, relu):
"Conv and BN both must be in the same mode (train or eval)."
if conv.training:
assert not relu.inplace, 'We only support fusion of non-inplace ReLU.'
return torch_fused.ConvBnReLU2d(conv, bn, relu)
else:
return torch_fused.ConvReLU2d(
......
Markdown 格式
0% or
您添加了 0 到此讨论。请谨慎行事。
先完成此消息的编辑!
想要评论请 注册