Update on "Prevent module full_backward_hook from erroring in double backward"


See https://docs.google.com/document/d/1tFZKYdsSzRBJ7Di7SWt8X8fSg-E3eiUPwomMF10UyhM/edit for more details on the question: 'should module full_backward_hooks be called every time the gradients w.r.t. the module inputs are computed, or only when the "backward for the module" has been computed?'

Fixes #88312
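For context, a minimal sketch of the pattern the title describes: a module full_backward_hook taking part in a double backward. This snippet is my own illustration (the module, hook body, and shapes are assumptions, not taken from the PR or the issue):

import torch
import torch.nn as nn

mod = nn.Linear(3, 3)
# A full backward hook; returning None leaves the gradients unchanged.
mod.register_full_backward_hook(lambda module, grad_input, grad_output: None)

x = torch.randn(2, 3, requires_grad=True)
out = mod(x).sum()
# First backward, keeping the graph so it can be differentiated again.
(gx,) = torch.autograd.grad(out, x, create_graph=True)
# Second (double) backward, the kind of call the title says should not error.
gx.sum().backward()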


[ghstack-poisoned]
soulitzer committed Nov 13, 2022
2 parents 33b0406 + d65ad9c commit eb2dfbd
Showing 143 changed files with 3,075 additions and 3,541 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/vision.txt
@@ -1 +1 @@
ffd5a567eb90abf6b5555063da434d3c130d540f
deba056203d009fec6b58afb9fa211f6ee3328c8
11 changes: 5 additions & 6 deletions aten/src/ATen/PythonTorchFunctionTLS.cpp
@@ -26,12 +26,12 @@ int64_t PythonTorchFunctionTLS::stack_len() {
return pythonTorchFunctionState.stack_.size();
}

void PythonTorchFunctionTLS::set_disabled_state(TorchFunctionDisabledState disabled_state) {
pythonTorchFunctionState.disabled_state_ = disabled_state;
void PythonTorchFunctionTLS::set_disabled(bool disabled) {
pythonTorchFunctionState.disabled_ = disabled;
}

TorchFunctionDisabledState PythonTorchFunctionTLS::get_disabled_state() {
return pythonTorchFunctionState.disabled_state_;
bool PythonTorchFunctionTLS::is_disabled() {
return pythonTorchFunctionState.disabled_;
}

void PythonTorchFunctionTLS::set_state(const PythonTorchFunctionTLS& state) {
@@ -43,8 +43,7 @@ const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() {
}

bool torch_function_mode_enabled() {
return PythonTorchFunctionTLS::get_disabled_state() != TorchFunctionDisabledState::ALL_DISABLED &&
PythonTorchFunctionTLS::stack_len() > 0;
return PythonTorchFunctionTLS::stack_len() > 0;
}

} // namespace impl
12 changes: 5 additions & 7 deletions aten/src/ATen/PythonTorchFunctionTLS.h
@@ -6,11 +6,9 @@
namespace at {
namespace impl {

enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED };

struct TORCH_API PythonTorchFunctionTLS {
static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
static TorchFunctionDisabledState get_disabled_state();
static void set_disabled(bool);
static bool is_disabled();

static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
static const std::shared_ptr<SafePyObject> pop_stack();
@@ -22,11 +20,11 @@ struct TORCH_API PythonTorchFunctionTLS {

private:
// The mode TLS is split into
// - disabled_state, which says which part of torch function are disabled
// - disabled_, which says whether or not to disable all torch function
// modes
// - stack_, which is a vector of modes representing the stack of user
// defined modes
TorchFunctionDisabledState disabled_state_ =
TorchFunctionDisabledState::ENABLED;
bool disabled_;
std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
};

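As an aside, the "stack of user defined modes" mentioned in the comment above is the stack of torch function modes pushed from Python. A minimal, illustrative mode (not part of this diff) looks like:

import torch
from torch.overrides import TorchFunctionMode

class LogMode(TorchFunctionMode):
    # Logs every torch function call made while the mode is active.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print("calling", getattr(func, "__name__", func))
        return func(*args, **kwargs)

with LogMode():
    torch.add(torch.ones(2), torch.ones(2))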
5 changes: 4 additions & 1 deletion aten/src/ATen/core/PythonFallbackKernel.cpp
@@ -74,10 +74,13 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
(*interpreter)->dispatch(op, stack);
return;
}
} else if (ivalue.isTensorList() || (ivalue.isOptionalTensorList() && !ivalue.isNone())) {
} else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
// NB: use toListRef as it doesn't induce refcount bumps (toTensorListRef
// is not a thing)
for (const auto& nv : ivalue.toListRef()) {
if (nv.isNone()) {
continue;
}
auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter();
if (interpreter) {
(*interpreter)->dispatch(op, stack);
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -239,6 +239,7 @@ namespace c10 {
_(onnx, LSTM) \
_(onnx, MatMul) \
_(onnx, Min) \
_(onnx, Max) \
_(onnx, Mul) \
_(onnx, Pow) \
_(onnx, RNN) \
17 changes: 13 additions & 4 deletions aten/src/ATen/cuda/Atomic.cuh
@@ -6,6 +6,10 @@

#include <ATen/NumericUtils.h>

#if !(defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#include <cuda_bf16.h>
#endif

template <typename T>
struct AtomicFPOp;

@@ -219,10 +223,15 @@ static inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val)
}

static inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) {
return AtomicFPOp<at::BFloat16>()(address, val,
[](at::BFloat16 bsum, at::BFloat16 val) {
return bsum + val;
});
#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
return AtomicFPOp<at::BFloat16>()(address, val,
[](at::BFloat16 bsum, at::BFloat16 val) {
return bsum + val;
});
#else
__nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val));
return *reinterpret_cast<c10::BFloat16*>(&r);
#endif
}

#if defined(CUDA_VERSION) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
2 changes: 1 addition & 1 deletion aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -63,7 +63,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
OP_DECOMPOSE2(bitwise_or, Scalar);
OP_DECOMPOSE2(bitwise_xor, Scalar);
OP_DECOMPOSE(broadcast_tensors);
OP_DECOMPOSE(broadcast_to);
m.impl("broadcast_to", native::broadcast_to_symint);
OP_DECOMPOSE(cartesian_prod);
OP_DECOMPOSE(cdist);
OP_DECOMPOSE(clip);
12 changes: 12 additions & 0 deletions aten/src/ATen/native/Copy.cpp
@@ -220,6 +220,18 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
return at::metal::metal_copy_(self, src);
}

// Exit early if self and src are views of the same data
const bool is_same_data = (
self.is_alias_of(src) &&
self.storage_offset() == src.storage_offset() &&
self.strides().equals(src.strides()) &&
self.sizes().equals(src.sizes()) &&
self.scalar_type() == src.scalar_type()
);
if (is_same_data) {
return self;
}


auto iter = TensorIteratorConfig()
.add_output(self)
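To illustrate the early-exit added above (my reading of its user-visible effect, not text from the PR): a copy_ between two tensors that alias the same data with matching offset, strides, sizes, and dtype can now return immediately:

import torch

x = torch.arange(6.).reshape(2, 3)
y = x.view(2, 3)          # same storage, offset, strides, sizes, and dtype as x
y.copy_(x)                # self and src alias the same data, so nothing to do
print(torch.equal(x, y))  # True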
17 changes: 12 additions & 5 deletions aten/src/ATen/native/TensorShape.cpp
@@ -537,8 +537,8 @@ Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) {
return at::sparse_coo_tensor(new_indices, new_values, size)._coalesced_(is_coalesced);
}

Tensor broadcast_to(const Tensor& self, IntArrayRef size) {
return self.expand(size);
Tensor broadcast_to_symint(const Tensor& self, SymIntArrayRef size) {
return self.expand_symint(size);
}

std::vector<Tensor> broadcast_tensors(TensorList tensors) {
@@ -1196,6 +1196,8 @@ Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t
return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous);
}

// Should just use narrow_copy_out, but this API is used internally at Meta:
// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561
Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
auto output = at::empty_like(self);
return narrow_copy_dense_cpu_out(self, dim, start, length, output);
@@ -1205,9 +1207,10 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_
int64_t allDim = self.dim();
int64_t end = start+length;
TORCH_CHECK(allDim > 0, "narrow() cannot be applied to a 0-dim tensor.");
TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
TORCH_CHECK(dim >= 0 && dim < allDim,
"Dimension ", dim, " out of range. Expecting 0 <= dim < ", allDim, ".");
TORCH_CHECK(start >= 0 && length >= 0 && end <= self.size(dim),
TORCH_CHECK(start >= 0 && end <= self.size(dim),
"Invalid range to narrow. range(start, start+length) must be a subset of range(0, ", self.size(dim), ").")
Tensor indices = self._indices();
int64_t sparse_dim = self.sparse_dim();
@@ -1235,6 +1238,8 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_
return newTensor._coalesced_(self.is_coalesced());
}

// Should just use narrow_copy_out, but this API is used internally at Meta:
// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561
Tensor& narrow_copy_dense_cpu_out(
const Tensor& self, int64_t dim, int64_t start, int64_t length, Tensor& output
) {
@@ -1318,22 +1323,24 @@

Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
auto cur_size = self.size(dim);
if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
start = maybe_wrap_dim(start, cur_size);
}
TORCH_CHECK(length >= 0 && start <= cur_size - length,
TORCH_CHECK(start <= cur_size - length,
"start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
return at::slice(self, dim, start, start + length, 1);
}

Tensor narrow_symint(const Tensor& self, int64_t dim, SymInt start, SymInt length) {
TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
auto cur_size = self.sym_size(dim);
if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
start = maybe_wrap_dim(start, cur_size);
}
TORCH_CHECK(length >= 0 && start <= cur_size - length,
TORCH_CHECK(start <= cur_size - length,
"start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
return at::slice_symint(self, dim, start, start + length, 1);
}
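The narrow() changes above pull the non-negative-length requirement into its own TORCH_CHECK so it is reported explicitly. A small illustration from Python, with behaviour inferred from the checks shown rather than from the PR discussion:

import torch

t = torch.arange(10)
print(t.narrow(0, 2, 3))   # tensor([2, 3, 4]): 3 elements starting at index 2
try:
    t.narrow(0, 2, -1)     # a negative length is rejected up front
except RuntimeError as err:
    print(err)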
48 changes: 47 additions & 1 deletion aten/src/ATen/native/cuda/KernelUtils.cuh
@@ -1,6 +1,10 @@
#pragma once
#include <ATen/cuda/Atomic.cuh>

#if !(defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#include <cuda_bf16.h>
#endif

namespace at {
namespace native {

@@ -66,7 +70,49 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
template <
typename scalar_t,
typename index_t,
typename std::enable_if<!std::is_same<c10::Half, scalar_t>::value>::type* =
typename std::enable_if<std::is_same<c10::BFloat16, scalar_t>::value>::type* =
nullptr>
__device__ __forceinline__ void fastSpecializedAtomicAdd(
scalar_t* tensor,
index_t index,
const index_t numel,
scalar_t value) {
#if ( \
(defined(USE_ROCM)) || \
(defined(CUDA_VERSION) && (CUDA_VERSION < 11000)) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
gpuAtomicAddNoReturn(
reinterpret_cast<at::BFloat16*>(tensor) + index,
static_cast<at::BFloat16>(value));
#else
// Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned)
__nv_bfloat16* target_addr = reinterpret_cast<__nv_bfloat16*>(tensor + index);
bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__nv_bfloat162) == 0);

if (low_byte && index < (numel - 1)) {
__nv_bfloat162 value2;
value2.x = *reinterpret_cast<__nv_bfloat16*>(&value);
value2.y = __int2bfloat16_rz(0);
atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);

} else if (!low_byte && index > 0) {
__nv_bfloat162 value2;
value2.x = __int2bfloat16_rz(0);
value2.y = *reinterpret_cast<__nv_bfloat16*>(&value);
atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);

} else {
atomicAdd(
reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value));
}
#endif
}


template <
typename scalar_t,
typename index_t,
typename std::enable_if<!std::is_same<c10::Half, scalar_t>::value && !std::is_same<c10::BFloat16, scalar_t>::value >::type* =
nullptr>
__device__ __forceinline__ void fastSpecializedAtomicAdd(
scalar_t* tensor,
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cudnn/RNN.cpp
@@ -70,7 +70,7 @@ Tensor _cudnn_init_dropout_state(double dropout, bool train, int64_t dropout_see
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);

AT_ERROR("_cudnn_init_dropout_state: ATen not compiled with cuDNN support");
}
4 changes: 3 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -1195,8 +1195,10 @@
device_check: NoCheck
device_guard: False

- func: broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
- func: broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)
variants: function, method
dispatch:
CompositeImplicitAutograd: broadcast_to_symint

- func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
variants: function
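The schema change above switches broadcast_to from int[] to SymInt[] sizes and dispatches it to broadcast_to_symint, which still reduces to expand, so the user-facing behaviour should be unchanged. A quick reminder of what the op does:

import torch

x = torch.tensor([1, 2, 3])
y = torch.broadcast_to(x, (2, 3))  # view-like expand, no data copy
print(y)
# tensor([[1, 2, 3],
#         [1, 2, 3]])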
1 change: 1 addition & 0 deletions benchmarks/dynamo/common.py
@@ -156,6 +156,7 @@
"hrnet_w18", # accuracy
"lcnet_0500", # accuracy
"levit_128", # levit_128
"poolformer_m36",
"rexnet_100", # accuracy
"swin_base_patch4_window7_224",
"twins_pcpvt_base", # time out
20 changes: 11 additions & 9 deletions benchmarks/dynamo/dist_util.py
@@ -20,6 +20,9 @@
except ImportError:
from torchbench import setup_torchbench_cwd

from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead
from transformers.models.t5.modeling_t5 import T5Block


def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
@@ -122,26 +125,25 @@ def check_fn(submodule):
)


# from transformers.models.t5.modeling_t5 import T5Block

MODEL_FSDP_WRAP = {
ToyModel: (MyModule,)
# TODO T5: (T5Block,)
"toy_model": (MyModule,),
"hf_Bert": (BertLayer, BertLMPredictionHead),
"hf_T5": (T5Block,),
}


def apply_fsdp(model, use_checkpointing=False, use_wrap_policy=True):
blocks = MODEL_FSDP_WRAP[model.__class__]

def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True):
wrap_policy = None
blocks = MODEL_FSDP_WRAP[
"toy_model" if model.__class__ is ToyModel else args.torchbench_model
]
if use_wrap_policy:
# transformer policy is really a generic policy that wraps modules of specified classes
wrap_policy = functools.partial(
transformer_auto_wrap_policy, transformer_layer_cls=blocks
)

model = FSDP(model, auto_wrap_policy=wrap_policy)
model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)
if use_checkpointing:
fsdp_checkpointing_base(model, blocks)

return model
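For readers unfamiliar with the wrap policy used in apply_fsdp above, here is a self-contained single-process CPU sketch of the same wiring using the gloo backend. The module class, model, and port are placeholders of my own choosing, and this is meant as an illustration rather than a supported benchmark configuration:

import functools
import os

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12355")
dist.init_process_group("gloo", rank=0, world_size=1)

# Wrap each TransformerEncoderLayer in its own FSDP unit, mirroring how
# MODEL_FSDP_WRAP supplies per-model block classes to the policy above.
policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)
model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=32, nhead=4), num_layers=2)
model = FSDP(model, auto_wrap_policy=policy, use_orig_params=True)

dist.destroy_process_group()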
5 changes: 4 additions & 1 deletion benchmarks/dynamo/distributed.py
@@ -50,6 +50,7 @@ def move_tensor(maybe_tensor):

if args.fsdp:
model = apply_fsdp(
args,
model,
use_checkpointing=args.fsdp_checkpoint,
use_wrap_policy=args.fsdp_wrap,
@@ -160,7 +161,9 @@ def experiment(fn, key, world_size, results):
)
args = parser.parse_args()

model_name = "ToyModel" if args.toy_model else args.torchbench_model
model_name = args.torchbench_model
if args.toy_model:
model_name = "ToyModel"
model, inputs = get_model(args)

fn = partial(run_model, args, model, inputs)
2 changes: 2 additions & 0 deletions c2_defs.bzl
@@ -166,6 +166,7 @@ def get_c2_fbandroid_xplat_compiler_flags():
# T95767731 -- remove this once all builds are on at least llvm-13
"-Wno-unknown-warning-option",
"-Wno-unused-but-set-variable",
"-DHAVE_MMAP",
]

if get_c2_strip_glog():
@@ -392,6 +393,7 @@ def c2_cxx_library(**kwargs):
args = get_c2_default_cxx_args()
args.update(kwargs)
args.setdefault("platforms", (ANDROID, APPLE, CXX, WINDOWS))

fb_xplat_cxx_library(
labels = [
"supermodule:android/default/caffe2",