From 43a6cc1f5013eba42313caa39a705c3249daa6d6 Mon Sep 17 00:00:00 2001
From: Sherlock Huang
Date: Fri, 14 Oct 2022 16:11:15 +0000
Subject: [PATCH] Avoid calling logging.basicConfig (#86959)

Fixes https://github.com/pytorch/pytorch/issues/85952

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86959
Approved by: https://github.com/xwang233, https://github.com/davidberard98
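Background, as a minimal standalone sketch (not part of the patch itself, and assuming these modules are imported as library code): `logging.basicConfig` configures the process-wide root logger and attaches a stderr handler on its first call, so a library module running it at import time can silently override the importing application's logging setup. Setting the level on the module's own `__name__` logger, as this patch does, keeps the effect local.

```python
import logging

# What the patch removes: this configures the *root* logger for the whole
# process (and installs a stderr StreamHandler on first call) at import time.
# logging.basicConfig(level=logging.WARNING)

# What the patch keeps/adds: restrict the level to this module's logger and
# leave root/handler configuration to the application.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

logger.debug("suppressed: DEBUG is below this logger's WARNING level")
logger.warning("emitted via the root logger's handlers (or Python's last-resort stderr handler)")
```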
---
 torch/fx/passes/backends/nvfuser.py  | 10 +++++-----
 torch/fx/passes/infra/partitioner.py | 10 +++++-----
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/torch/fx/passes/backends/nvfuser.py b/torch/fx/passes/backends/nvfuser.py
index cbded1de0cde..fdb1dd9a3320 100644
--- a/torch/fx/passes/backends/nvfuser.py
+++ b/torch/fx/passes/backends/nvfuser.py
@@ -17,8 +17,8 @@
 
 import logging
 
-logging.basicConfig(level=logging.WARNING)
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
 
 def aten_to_dtype(self, dtype: torch.dtype, **kwargs):
     if len(kwargs) > 0 or not dtype:
@@ -244,7 +244,7 @@ def lower_to_prims_and_execute(self, graph_module: GraphModule, *args, **kwargs)
         # "lowering to prims" and "trace execution" are grouped into this function, as they are both input dependent
 
         if graph_module in self.prim_decomp_cache:
-            logging.debug("prim_decomp_cache hit!")
+            logger.debug("prim_decomp_cache hit!")
             prim_module = self.prim_decomp_cache[graph_module]
         else:
             prim_graph = torch.fx.Graph()
@@ -252,18 +252,18 @@ def lower_to_prims_and_execute(self, graph_module: GraphModule, *args, **kwargs)
             prim_module = torch.fx.GraphModule(graph_module, prim_graph)
             self.prim_decomp_cache[graph_module] = prim_module
 
-            logging.debug("Lower to prims graph: ", prim_module.code)
+            logger.debug("Lower to prims graph: ", prim_module.code)
 
         # invokes trace executor for running the prim graph
         return execute(prim_module, *args, executor="nvfuser")
 
     def compile(self, graph_module: GraphModule) -> GraphModule:
         # entry function for nvFuser backend
-        logging.debug("Compiling graph_module: ", graph_module.code)
+        logger.debug("Compiling graph_module: ", graph_module.code)
 
         # FX graph based partitioning based on nvfuser supported ops
         if graph_module in self.partitioner_cache:
-            logging.debug("partitioner_cache hit!")
+            logger.debug("partitioner_cache hit!")
             fused_graph_module = self.partitioner_cache[graph_module]
         else:
             partitioner = CapabilityBasedPartitioner(
diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py
index 18a665b88ede..db6dc1bd979d 100644
--- a/torch/fx/passes/infra/partitioner.py
+++ b/torch/fx/passes/infra/partitioner.py
@@ -11,8 +11,8 @@
 import logging
 import itertools
 
-logging.basicConfig(level=logging.WARNING)
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
 
 class Partition:
     def __init__(self, id: int = None, nodes: Iterable[Node] = None):
@@ -120,7 +120,7 @@ def assign(node: Node, id: Optional[int] = None):
             else:
                 partitions_by_id[id].add_node(node)
 
-        logging.debug("Proposing partitions...")
+        logger.debug("Proposing partitions...")
 
         # visit candidates in reversed topological order
         for node in reversed(candidates):
@@ -210,14 +210,14 @@ def assign(node: Node, id: Optional[int] = None):
 
         for id in partitions_to_remove:
             del partitions_by_id[id]
 
-        logging.debug("Partitions proposed:")
+        logger.debug("Partitions proposed:")
         for id, partition in partitions_by_id.items():
-            logging.debug(f"partition #{id}", [node.name for node in partition.nodes])
+            logger.debug(f"partition #{id}", [node.name for node in partition.nodes])
 
         return list(partitions_by_id.values())
 
     def fuse_partitions(self, partitions: List[Partition]) -> GraphModule:
-        logging.debug("Fusing partitions...")
+        logger.debug("Fusing partitions...")
         # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ]
         return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions])