[Auto Parallel] Add reshard cost and update estimator (#45118)
* update reshard cost and cost estimator

* add unittest

* add dropout cost

* fix import error

* fix reshard code style error

* improve unittest coverage
Caozhou1995 committed Aug 16, 2022
1 parent 933db9d commit 6a15d40
Showing 12 changed files with 659 additions and 26 deletions.
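
For orientation, the estimator surface that this commit exercises in its new unittests can be sketched as follows. This is a usage sketch, not part of the diff: `train_program` is assumed to be an annotated static Program whose distributed attributes live in the default distributed context, exactly as in the tests at the bottom of this page.

# Usage sketch mirroring the new unittests; `train_program` is assumed to exist.
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.cost import CostEstimator
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context

dist_context = get_default_distributed_context()
cluster = Cluster()
cluster.gen_default_config_cluster(device_count=8)  # build a default cluster config

cost_estimator = CostEstimator(train_program, cluster)
global_cost = cost_estimator.estimate(dist_context)  # repeated calls hit the cache
max_memory = cost_estimator._estimate_max_memory_by_dist_op(dist_context)
print(global_cost.time, max_memory)  # both expected to be > 0
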
19 changes: 19 additions & 0 deletions python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
@@ -148,6 +148,25 @@ def calc_time(self):
        return 0


@register_op_cost
class DropoutOpCost(CompOpCost):
    OP_TYPE = "dropout"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(DropoutOpCost, self).__init__(op=op,
                                            op_desc=op_desc,
                                            cluster=cluster)

    # For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden.
    def calc_flops(self):
        # NOTE: The actual formula will be filled in the future.
        return 0

    def calc_time(self):
        # NOTE: The actual formula will be filled in the future.
        return 0
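
As a quick illustration, the new class is picked up through the op-cost registry: a minimal lookup sketch, assuming `register_op_cost` records each class in the `_g_op_cost_factory` dict under its `OP_TYPE` (the operator implementations changed below import this factory for the same purpose) and that the all-default constructor is accepted by the base class.

# Sketch under the assumptions above; not part of the diff.
from paddle.distributed.auto_parallel.cost import _g_op_cost_factory

dropout_cost = _g_op_cost_factory["dropout"]()  # op, op_desc and cluster default to None
assert dropout_cost.calc_time() == 0            # placeholder until the real formula lands
assert dropout_cost.calc_flops() == 0
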


@register_op_cost
class ElementwiseAddOpCost(CompOpCost):
    OP_TYPE = "elementwise_add"
382 changes: 362 additions & 20 deletions python/paddle/distributed/auto_parallel/cost/estimate_cost.py

Large diffs are not rendered by default.

@@ -34,7 +34,8 @@
from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs
from ..cost import EmbeddingOpCost, EmbeddingGradOpCost, AllreduceSumOpCost, IdentityOpCost
from ..cost import EmbeddingOpCost, EmbeddingGradOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost


class DistributedEmbedding(DistributedOperatorImplContainer):
@@ -32,7 +32,7 @@
from ..cost import FillConstantBatchSizeLikeOpCost
from ..cost import build_comp_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs
from ..cost import AllreduceSumOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost


class DistributedFillConstantBatchSizeLike(DistributedOperatorImplContainer):
@@ -39,8 +39,9 @@
from .dist_default import DistributedDefaultImpl0
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs
from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost, IdentityOpCost, AllreduceSumOpCost
from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost
from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost


def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
@@ -24,11 +24,12 @@
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
from ..cost import AllreduceSumOpCost, _g_op_cost_factory
from ..cost import _g_op_cost_factory
from ..cost import build_comp_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs
from ..cost import SoftmaxOpCost, SoftmaxGradOpCost
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost


class DistributedSoftmax(DistributedOperatorImplContainer):
@@ -24,10 +24,11 @@
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
from ..cost import AllreduceSumOpCost, Transpose2OpCost, Transpose2GradOpCost
from ..cost import Transpose2OpCost, Transpose2GradOpCost
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost


class DistributedTranspose2(DistributedOperatorImplContainer):
206 changes: 206 additions & 0 deletions python/paddle/distributed/auto_parallel/reshard.py
@@ -2065,3 +2065,209 @@ def reshard(self):

        # Reset some variables when the remove operation has ended.
        Resharder.while_block_info = {}

    def get_cost(self, op, tensor, cluster):
        # NOTE: The program should be the serial_program, which has not been partitioned.
        global _g_special_ops
        not_supported_op_type = _g_special_ops + ["while"]
        reshard_op_cost = None
        if op.type in not_supported_op_type:
            return reshard_op_cost
        else:
            tensor_name = tensor.name
            if tensor_name == "lod_tensor_blocking_queue_0":
                return reshard_op_cost
            else:
                dist_tensor = self.dist_context.get_dist_tensor_for_program(
                    tensor)
                # Simplified processing: ignore the union process mesh and the output reshard.
                dist_op = self.dist_context.get_dist_op_for_program(op)
                dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
                    tensor.name)
                process_mesh = dist_op.dist_attr.process_mesh
                dist_attr = [process_mesh, dims_mapping]
                if dist_tensor is not None and self.need_reshard(
                        dist_tensor, dist_attr):
                    if tensor_name not in self._has_resharded:
                        self._has_resharded[tensor_name] = [dist_op]
                    else:
                        for item in self._has_resharded[tensor_name]:
                            item_dist_attr = item.dist_attr
                            item_dims_mapping = item_dist_attr.get_input_dims_mapping(
                                tensor_name)
                            item_process_mesh = item_dist_attr.process_mesh
                            if dims_mapping == item_dims_mapping and item_process_mesh == process_mesh:
                                return reshard_op_cost
                        self._has_resharded[tensor_name].append(dist_op)

                    reshard_op_desc = self.find_op_desc_seq(dist_tensor,
                                                            dist_attr,
                                                            serial=True)
                    dtype = dist_tensor.serial_tensor.dtype
                    reshard_op_cost = self.parse_op_desc_for_cost(
                        reshard_op_desc, dtype, cluster)

                return reshard_op_cost

    def _concat_partitions_for_cost(self, partition_tensor_list,
                                    partition_index, dtype, rank_id,
                                    local_rank_comp_cost, cluster):
        if not partition_tensor_list:
            partition_tensor_list.append(partition_index)
        else:
            i = 0
            has_concat = False
            while i < len(partition_tensor_list):
                concat_axis, first_order, new_partition = Resharder.compute_concat_info(
                    partition_tensor_list[i], partition_index)
                if concat_axis != -1:
                    has_concat = True
                    concat_desc = {}
                    concat_desc["op"] = "concat"
                    concat_desc["attrs"] = {"axis": concat_axis}
                    if first_order == 0:
                        concat_desc["inputs"] = {
                            "X": [(dtype, partition_tensor_list[i]),
                                  (dtype, partition_index)]
                        }
                    else:
                        concat_desc["inputs"] = {
                            "X": [(dtype, partition_index),
                                  (dtype, partition_tensor_list[i])]
                        }
                    partition_tensor_list.pop(i)
                    if rank_id not in local_rank_comp_cost:
                        local_rank_comp_cost[rank_id] = []
                    local_rank_comp_cost[rank_id].append(
                        ConcatOpCost(op_desc=concat_desc, cluster=cluster))
                    self._concat_partitions_for_cost(partition_tensor_list,
                                                     new_partition, dtype,
                                                     rank_id,
                                                     local_rank_comp_cost,
                                                     cluster)
                    break
                i += 1
            if not has_concat:
                partition_tensor_list.append(partition_index)
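
For intuition, the partition indices folded here are per-dimension [start, end] ranges (that is also how the SliceOpDesc branch below consumes them). A toy sketch of how two adjacent halves of a [4, 8] tensor would be merged into a single ConcatOpCost for rank 0; `resharder` and `cluster` are assumed to be prepared as in the unittests, and the axis-0 adjacency result relies on the behavior of Resharder.compute_concat_info used above.

# Toy sketch under the assumptions above; not part of the diff.
partition_tensor_list = []
local_rank_comp_cost = {}
left = [[0, 2], [0, 8]]    # rows 0-1, all columns
right = [[2, 4], [0, 8]]   # rows 2-3, all columns
resharder._concat_partitions_for_cost(partition_tensor_list, left,
                                      paddle.float32, 0,
                                      local_rank_comp_cost, cluster)
resharder._concat_partitions_for_cost(partition_tensor_list, right,
                                      paddle.float32, 0,
                                      local_rank_comp_cost, cluster)
# Expected: partition_tensor_list == [[[0, 4], [0, 8]]] and local_rank_comp_cost[0]
# holds one ConcatOpCost for the concat along axis 0.
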

    def parse_op_desc_for_cost(self, reshard_op_desc, dtype, cluster):

        def _get_idx(comm_ranks, group_ranks):
            res, is_the_same = None, False
            idx = 0
            while idx < len(comm_ranks):
                if comm_ranks[idx] == set(group_ranks):
                    is_the_same = True

                for rank in group_ranks:
                    if rank in comm_ranks[idx]:
                        res = idx
                        comm_ranks[idx].add(rank)
                if res is None:
                    idx += 1
                else:
                    break
            return res, is_the_same

        comm_context = CommContext(cluster)
        # Run communication ops before computation ops.
        # TODO: Communication cost is not calculated when the var has been transferred by the same group in the past.
        comm_costs = []
        comm_ranks = []
        local_rank_comp_cost = {}
        for key in reshard_op_desc:
            partition_tensor_list = []
            op_desc_list = reshard_op_desc[key]
            for op_desc in op_desc_list:
                if isinstance(op_desc, SendOpDesc):
                    group_ranks = [key, op_desc.dst]
                    shape = op_desc.shape
                    send_desc = build_comm_desc("send_v2", group_ranks, dtype,
                                                shape)
                    idx, is_the_same = _get_idx(comm_ranks, group_ranks)
                    if idx is None:
                        comm_costs.append([
                            (group_ranks,
                             SendOpCost(op_desc=send_desc,
                                        comm_context=comm_context))
                        ])
                        comm_ranks.append(set(group_ranks))
                    else:
                        if not is_the_same:
                            comm_costs[idx].append(
                                (group_ranks,
                                 SendOpCost(op_desc=send_desc,
                                            comm_context=comm_context)))
                elif isinstance(op_desc, AllGatherOpDesc):
                    # NOTE: fill_const and other auxiliary ops are not counted because their cost is very small.
                    group_ranks = op_desc.group
                    shape = op_desc.shape
                    allgather_desc = build_comm_desc("c_allgather", group_ranks,
                                                     dtype, shape)
                    split_inputs_shape = []
                    for idx, dim in enumerate(shape):
                        if idx == 0:
                            split_inputs_shape.append(dim * len(group_ranks))
                        else:
                            split_inputs_shape.append(dim)
                    idx, is_the_same = _get_idx(comm_ranks, group_ranks)
                    if idx is None:
                        comm_costs.append([
                            (group_ranks,
                             AllgatherOpCost(op_desc=allgather_desc,
                                             comm_context=comm_context))
                        ])
                        comm_ranks.append(set(group_ranks))
                    else:
                        if not is_the_same:
                            comm_costs[idx].append(
                                (group_ranks,
                                 AllgatherOpCost(op_desc=allgather_desc,
                                                 comm_context=comm_context)))
                    # Calculate the cost of the split op that follows the allgather.
                    if key not in local_rank_comp_cost:
                        local_rank_comp_cost[key] = []
                    split_desc = {}
                    split_desc["op"] = "split"
                    split_desc["inputs"] = {
                        "inputs": [(dtype, split_inputs_shape)]
                    }
                    split_desc["attrs"] = {"num": len(group_ranks), "axis": 0}
                    local_rank_comp_cost[key].append(
                        SplitOpCost(op_desc=split_desc, cluster=cluster))
                elif isinstance(op_desc, ConcatOpDesc):
                    partition_index_list = op_desc._partition_index_list
                    for idx, partition_index in enumerate(partition_index_list):
                        self._concat_partitions_for_cost(
                            partition_tensor_list, partition_index, dtype, key,
                            local_rank_comp_cost, cluster)

                elif isinstance(op_desc, SliceOpDesc):
                    if key not in local_rank_comp_cost:
                        local_rank_comp_cost[key] = []
                    assert len(
                        partition_tensor_list) == 1 or not partition_tensor_list
                    to_slice_tensor_shape = []
                    if len(partition_tensor_list) == 1:
                        for item in partition_tensor_list[0]:
                            to_slice_tensor_shape.append(item[1] - item[0])
                    else:
                        to_slice_tensor_shape = op_desc.shape
                    slice_desc = {}
                    slice_desc["op"] = "slice"
                    infer_flags = list(1 for i in range(len(op_desc.axes)))
                    slice_desc["attrs"] = {
                        "axes": op_desc.axes,
                        "starts": op_desc.starts,
                        "ends": op_desc.ends,
                        "infer_flags": infer_flags
                    }
                    slice_desc["inputs"] = {
                        "Input": [(dtype, to_slice_tensor_shape)]
                    }
                    local_rank_comp_cost[key].append(
                        SliceOpCost(op_desc=slice_desc, cluster=cluster))

        res = (comm_costs, local_rank_comp_cost)

        return res
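
To show where get_cost plugs in, here is a hedged sketch of the expected wiring (the updated estimator in estimate_cost.py is the intended caller): each input tensor of every op in the un-partitioned serial program is priced, and unsupported ops or the DataLoader queue variable yield None. The serial program, distributed programs, rank_id, dist_context and cluster are assumed to be prepared as in the unittests below.

# Sketch under the assumptions above; not part of the diff.
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                      dist_context, dist_params_grads)
for op in serial_main_program.global_block().ops:
    for var_name in op.input_arg_names:
        tensor = serial_main_program.global_block().vars[var_name]
        reshard_cost = resharder.get_cost(op, tensor, cluster)
        if reshard_cost is not None:
            comm_costs, local_rank_comp_cost = reshard_cost  # see parse_op_desc_for_cost
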
@@ -15,6 +15,9 @@
import paddle
import paddle.static as static
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cost import CostEstimator
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context


def train():
@@ -39,6 +42,30 @@ def train():
    _, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
        loss, start_program)

    # Add the cost estimator.
    dist_context = get_default_distributed_context()
    cluster = Cluster()
    for op in train_program.global_block().ops:
        dist_op = dist_context.get_dist_op_for_program(op)
        for var_name in op.input_arg_names:
            dims_mapping = dist_op.dist_attr.get_input_dims_mapping(var_name)
            if dims_mapping is None:
                dist_op.dist_attr.set_input_dims_mapping(
                    var_name, [
                        -1 for i in range(
                            len(train_program.global_block().vars[var_name].
                                shape))
                    ])
    cluster.gen_default_config_cluster(device_count=2)
    cost_estimator = CostEstimator(train_program, cluster)
    global_cost = cost_estimator.estimate(dist_context)
    max_memory = cost_estimator._estimate_max_memory_by_dist_op(dist_context)
    # Test the cache: the second calls should reuse cached results.
    global_cost = cost_estimator.estimate(dist_context)
    max_memory = cost_estimator._estimate_max_memory_by_dist_op(dist_context)
    assert global_cost.time > 0
    assert max_memory > 0

    places = static.cuda_places()
    loader.set_batch_generator(batch_generator_creator(), places=places)
    exe = paddle.static.Executor(places[0])
@@ -19,6 +19,7 @@

import paddle
import paddle.distributed.auto_parallel.cost as cost_model

from paddle.distributed.auto_parallel.cost.base_cost import build_comp_desc_from_op
from paddle.distributed.auto_parallel.cost.base_cost import build_comp_desc_str_for_predict
from paddle.distributed.auto_parallel.cost.base_cost import calc_time_by_modeling
@@ -29,6 +29,8 @@
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.cost import CostEstimator
from paddle.distributed.auto_parallel.cluster import Cluster

paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
@@ -196,6 +198,21 @@ def test_mlp_dpmppp(self):
        rank_id = 2
        dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
            train_program, startup_program, dist_context, rank_id)

        # Test the estimator.
        cluster = Cluster()
        cluster.gen_default_config_cluster(device_count=8)
        cost_estimator = CostEstimator(train_program, cluster)
        global_cost = cost_estimator.estimate(dist_context)
        max_memory = cost_estimator._estimate_max_memory_by_dist_op(
            dist_context)
        # Test the cache: the second calls should reuse cached results.
        global_cost = cost_estimator.estimate(dist_context)
        max_memory = cost_estimator._estimate_max_memory_by_dist_op(
            dist_context)
        assert global_cost.time > 0
        assert max_memory > 0

        resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                              dist_context, dist_params_grads)
        resharder.reshard()
