Skip to content

Commit

Permalink
add sparse SyncBatchNorm
Browse files Browse the repository at this point in the history
  • Loading branch information
zkh2016 committed Jun 14, 2022
1 parent 084f2a9 commit f8c44d2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
3 changes: 2 additions & 1 deletion python/paddle/incubate/sparse/nn/__init__.py
Expand Up @@ -15,14 +15,15 @@
from . import functional

from .layer.activation import ReLU
from .layer.norm import BatchNorm
from .layer.norm import BatchNorm, SyncBatchNorm
from .layer.conv import Conv3D
from .layer.conv import SubmConv3D
from .layer.pooling import MaxPool3D

__all__ = [
'ReLU',
'BatchNorm',
'SyncBatchNorm',
'Conv3D',
'SubmConv3D',
'MaxPool3D',
Expand Down
4 changes: 3 additions & 1 deletion python/paddle/incubate/sparse/nn/layer/norm.py
Expand Up @@ -27,6 +27,8 @@

import paddle
import warnings
from paddle.nn.layer.norm import _BatchNormBase
from paddle.framework import no_grad


class BatchNorm(paddle.nn.BatchNorm1D):
Expand Down Expand Up @@ -171,7 +173,7 @@ def __init__(self,
name=None):
super(SyncBatchNorm,
self).__init__(num_features, momentum, epsilon, weight_attr,
bias_attr, data_format, None, name)
bias_attr, data_format, name)

def forward(self, x):
out = super(SyncBatchNorm, self).forward(x.values())
Expand Down

0 comments on commit f8c44d2

Please sign in to comment.