Skip to content

Commit

Permalink
ENH: rewrite backports from Python std lib in a fashion that can be a…
Browse files Browse the repository at this point in the history
…utomatically cleaned up with pyupgrade
  • Loading branch information
neutrinoceros committed Sep 27, 2021
1 parent ef7d320 commit f5b5b19
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 35 deletions.
52 changes: 25 additions & 27 deletions yt/geometry/geometry_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@
)


def cached_chunk_property(func):
n = f"_{func.__name__}"

def cached_func(self):
if self._cache and getattr(self, n, None) is not None:
return getattr(self, n)
if self.data_size is None:
tr = self._accumulate_values(n[1:])
else:
tr = func(self)
if self._cache:

setattr(self, n, tr)
return tr

return property(cached_func)


class Index(ParallelAnalysisInterface, abc.ABC):
"""The base index class"""

Expand Down Expand Up @@ -243,26 +261,6 @@ def _chunk(self, dobj, chunking_style, ngz=0, **kwargs):
raise NotImplementedError


def cached_property(func):
# TODO: remove this once minimal supported version of Python reaches 3.8
# and replace with functools.cached
n = f"_{func.__name__}"

def cached_func(self):
if self._cache and getattr(self, n, None) is not None:
return getattr(self, n)
if self.data_size is None:
tr = self._accumulate_values(n[1:])
else:
tr = func(self)
if self._cache:

setattr(self, n, tr)
return tr

return property(cached_func)


class YTDataChunk:
def __init__(
self,
Expand Down Expand Up @@ -298,7 +296,7 @@ def _accumulate_values(self, method):
self.data_size = arrs.shape[0]
return arrs

@cached_property
@cached_chunk_property
def fcoords(self):
if self._fast_index is not None:
ci = self._fast_index.select_fcoords(self.dobj.selector, self.data_size)
Expand All @@ -317,7 +315,7 @@ def fcoords(self):
ind += c.shape[0]
return ci

@cached_property
@cached_chunk_property
def icoords(self):
if self._fast_index is not None:
ci = self._fast_index.select_icoords(self.dobj.selector, self.data_size)
Expand All @@ -334,7 +332,7 @@ def icoords(self):
ind += c.shape[0]
return ci

@cached_property
@cached_chunk_property
def fwidth(self):
if self._fast_index is not None:
ci = self._fast_index.select_fwidth(self.dobj.selector, self.data_size)
Expand All @@ -353,7 +351,7 @@ def fwidth(self):
ind += c.shape[0]
return ci

@cached_property
@cached_chunk_property
def ires(self):
if self._fast_index is not None:
ci = self._fast_index.select_ires(self.dobj.selector, self.data_size)
Expand All @@ -370,12 +368,12 @@ def ires(self):
ind += c.size
return ci

@cached_property
@cached_chunk_property
def tcoords(self):
self.dtcoords
return self._tcoords

@cached_property
@cached_chunk_property
def dtcoords(self):
ct = np.empty(self.data_size, dtype="float64")
cdt = np.empty(self.data_size, dtype="float64")
Expand All @@ -392,7 +390,7 @@ def dtcoords(self):
ind += gt.size
return cdt

@cached_property
@cached_chunk_property
def fcoords_vertex(self):
nodes_per_elem = self.dobj.index.meshes[0].connectivity_indices.shape[1]
dim = self.dobj.ds.dimensionality
Expand Down
8 changes: 1 addition & 7 deletions yt/visualization/plot_window.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from collections import defaultdict
from functools import wraps
from numbers import Number
Expand Down Expand Up @@ -48,18 +49,11 @@
)
from .plot_modifications import callback_registry

import sys # isort: skip

if sys.version_info < (3, 10):
# this function is deprecated in more_itertools
# because it is superseded by the standard library
from more_itertools import zip_equal
else:

def zip_equal(*args):
# FUTURE: when only Python 3.10+ is supported,
# drop this conditional and call the builtin zip
# function directly where due
return zip(*args, strict=True)


Expand Down
23 changes: 22 additions & 1 deletion yt/visualization/volume_rendering/old_camera.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import builtins
import sys
from copy import deepcopy

import numpy as np

from yt.config import ytcfg
from yt.data_objects.api import ImageArray
from yt.funcs import ensure_numpy_array, get_num_threads, get_pbar, is_sequence, mylog
from yt.geometry.geometry_handler import cached_property
from yt.units.yt_array import YTArray
from yt.utilities.amr_kdtree.api import AMRKDTree
from yt.utilities.exceptions import YTNotInsideNotebook
Expand Down Expand Up @@ -35,6 +35,27 @@

from .transfer_functions import ProjectionTransferFunction

if sys.version_info >= (3, 8):
from functools import cached_property
else:

def cached_property(func):
n = f"_{func.__name__}"

def cached_func(self):
if self._cache and getattr(self, n, None) is not None:
return getattr(self, n)
if self.data_size is None:
tr = self._accumulate_values(n[1:])
else:
tr = func(self)
if self._cache:

setattr(self, n, tr)
return tr

return property(cached_func)


def get_corners(le, re):
return np.array(
Expand Down

0 comments on commit f5b5b19

Please sign in to comment.