Avoid calling logging.basicConfig (pytorch#86959)
SherlockNoMad authored and atalman committed Oct 21, 2022
1 parent 55c76ba commit 43a6cc1
Showing 2 changed files with 10 additions and 10 deletions.
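Calling logging.basicConfig() at import time, as the removed lines in both files did, configures the process-wide root logger: the first call attaches a handler and sets the root level, and every later basicConfig() call is silently ignored while the root logger still has handlers. A library doing this can pre-empt the logging configuration of the application that imports it. A minimal standalone sketch of the pitfall (not part of this commit; the app.log filename and the "app" logger name are illustrative only):

    import logging

    # imagine this line runs deep inside a library, at import time
    logging.basicConfig(level=logging.WARNING)

    # later the application tries to set up its own logging...
    logging.basicConfig(level=logging.DEBUG, filename="app.log")

    # ...but that second call is a no-op because the root logger already
    # has a handler, so DEBUG records are still dropped and nothing is
    # written to app.log
    logging.getLogger("app").debug("never reaches app.log")

The replacement pattern keeps configuration local: each module gets its own logger via logging.getLogger(__name__) and only sets a level on that logger, leaving the root logger untouched.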
torch/fx/passes/backends/nvfuser.py (10 changes: 5 additions & 5 deletions)
@@ -17,8 +17,8 @@

 import logging
 
-logging.basicConfig(level=logging.WARNING)
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
 
 def aten_to_dtype(self, dtype: torch.dtype, **kwargs):
     if len(kwargs) > 0 or not dtype:
@@ -244,26 +244,26 @@ def lower_to_prims_and_execute(self, graph_module: GraphModule, *args, **kwargs)
# "lowering to prims" and "trace execution" are grouped into this function, as they are both input dependent

if graph_module in self.prim_decomp_cache:
logging.debug("prim_decomp_cache hit!")
logger.debug("prim_decomp_cache hit!")
prim_module = self.prim_decomp_cache[graph_module]
else:
prim_graph = torch.fx.Graph()
DecompositionInterpreter(graph_module, prim_graph, decomposition_table=aten2prim_decomp).run(*args, **kwargs)
prim_module = torch.fx.GraphModule(graph_module, prim_graph)
self.prim_decomp_cache[graph_module] = prim_module

logging.debug("Lower to prims graph: ", prim_module.code)
logger.debug("Lower to prims graph: ", prim_module.code)

# invokes trace executor for running the prim graph
return execute(prim_module, *args, executor="nvfuser")

def compile(self, graph_module: GraphModule) -> GraphModule:
# entry function for nvFuser backend
logging.debug("Compiling graph_module: ", graph_module.code)
logger.debug("Compiling graph_module: ", graph_module.code)

# FX graph based partitioning based on nvfuser supported ops
if graph_module in self.partitioner_cache:
logging.debug("partitioner_cache hit!")
logger.debug("partitioner_cache hit!")
fused_graph_module = self.partitioner_cache[graph_module]
else:
partitioner = CapabilityBasedPartitioner(
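Beyond removing the basicConfig call, the diff also switches the module-level helpers logging.debug(...) to logger.debug(...). The module-level functions log through the root logger and, per the standard-library documentation, call basicConfig() themselves if the root logger has no handlers, so leaving those call sites in place could still have configured the root logger as a side effect. Routing records through the named module logger instead leaves handler setup entirely to the importing application. A minimal standalone sketch of the distinction (not part of the diff):

    import logging

    logging.debug("via the root logger; may implicitly call basicConfig()")

    logger = logging.getLogger(__name__)
    logger.debug("via this module's logger; no global side effects")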
torch/fx/passes/infra/partitioner.py (10 changes: 5 additions & 5 deletions)
@@ -11,8 +11,8 @@
 import logging
 import itertools
 
-logging.basicConfig(level=logging.WARNING)
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
 
 class Partition:
     def __init__(self, id: int = None, nodes: Iterable[Node] = None):
@@ -120,7 +120,7 @@ def assign(node: Node, id: Optional[int] = None):
                 else:
                     partitions_by_id[id].add_node(node)
 
-        logging.debug("Proposing partitions...")
+        logger.debug("Proposing partitions...")
 
         # visit candidates in reversed topological order
         for node in reversed(candidates):
@@ -210,14 +210,14 @@ def assign(node: Node, id: Optional[int] = None):
         for id in partitions_to_remove:
             del partitions_by_id[id]
 
-        logging.debug("Partitions proposed:")
+        logger.debug("Partitions proposed:")
         for id, partition in partitions_by_id.items():
-            logging.debug(f"partition #{id}", [node.name for node in partition.nodes])
+            logger.debug(f"partition #{id}", [node.name for node in partition.nodes])
 
         return list(partitions_by_id.values())
 
     def fuse_partitions(self, partitions: List[Partition]) -> GraphModule:
-        logging.debug("Fusing partitions...")
+        logger.debug("Fusing partitions...")
         # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ]
         return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions])


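With this change, downstream code that wants to see debug output from these passes configures logging itself. A possible sketch, assuming the module paths shown in the diff and that the application owns root-logger setup:

    import logging

    # the library no longer calls basicConfig, so the application's own
    # configuration takes effect
    logging.basicConfig(level=logging.DEBUG)

    # both modules set their loggers to WARNING when they are imported,
    # so lower the level per module (after importing them) to opt in
    logging.getLogger("torch.fx.passes.infra.partitioner").setLevel(logging.DEBUG)
    logging.getLogger("torch.fx.passes.backends.nvfuser").setLevel(logging.DEBUG)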