Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fold kwarg for arrow.get #1139

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions arrow/api.py
Expand Up @@ -25,6 +25,7 @@ def get(
*,
locale: str = DEFAULT_LOCALE,
tzinfo: Optional[TZ_EXPR] = None,
fold: Optional[int] = 0,
normalize_whitespace: bool = False,
) -> Arrow:
... # pragma: no cover
Expand All @@ -35,6 +36,7 @@ def get(
*args: int,
locale: str = DEFAULT_LOCALE,
tzinfo: Optional[TZ_EXPR] = None,
fold: Optional[int] = 0,
normalize_whitespace: bool = False,
) -> Arrow:
... # pragma: no cover
Expand All @@ -56,6 +58,7 @@ def get(
*,
locale: str = DEFAULT_LOCALE,
tzinfo: Optional[TZ_EXPR] = None,
fold: Optional[int] = 0,
normalize_whitespace: bool = False,
) -> Arrow:
... # pragma: no cover
Expand All @@ -68,6 +71,7 @@ def get(
*,
locale: str = DEFAULT_LOCALE,
tzinfo: Optional[TZ_EXPR] = None,
fold: Optional[int] = 0,
normalize_whitespace: bool = False,
) -> Arrow:
... # pragma: no cover
Expand All @@ -80,6 +84,7 @@ def get(
*,
locale: str = DEFAULT_LOCALE,
tzinfo: Optional[TZ_EXPR] = None,
fold: Optional[int] = 0,
normalize_whitespace: bool = False,
) -> Arrow:
... # pragma: no cover
Expand Down
42 changes: 28 additions & 14 deletions arrow/arrow.py
Expand Up @@ -159,7 +159,8 @@ def __init__(
second: int = 0,
microsecond: int = 0,
tzinfo: Optional[TZ_EXPR] = None,
**kwargs: Any,
*,
fold: int = 0,
) -> None:
if tzinfo is None:
tzinfo = dateutil_tz.tzutc()
Expand All @@ -174,8 +175,6 @@ def __init__(
elif isinstance(tzinfo, str):
tzinfo = parser.TzinfoParser.parse(tzinfo)

fold = kwargs.get("fold", 0)

self._datetime = dt_datetime(
year, month, day, hour, minute, second, microsecond, tzinfo, fold=fold
)
Expand Down Expand Up @@ -210,7 +209,7 @@ def now(cls, tzinfo: Optional[dt_tzinfo] = None) -> "Arrow":
dt.second,
dt.microsecond,
dt.tzinfo,
fold=getattr(dt, "fold", 0),
fold=dt.fold,
)

@classmethod
Expand All @@ -236,7 +235,7 @@ def utcnow(cls) -> "Arrow":
dt.second,
dt.microsecond,
dt.tzinfo,
fold=getattr(dt, "fold", 0),
fold=dt.fold,
)

@classmethod
Expand Down Expand Up @@ -273,7 +272,7 @@ def fromtimestamp(
dt.second,
dt.microsecond,
dt.tzinfo,
fold=getattr(dt, "fold", 0),
fold=dt.fold,
)

@classmethod
Expand All @@ -299,11 +298,16 @@ def utcfromtimestamp(cls, timestamp: Union[int, float, str]) -> "Arrow":
dt.second,
dt.microsecond,
dateutil_tz.tzutc(),
fold=getattr(dt, "fold", 0),
fold=dt.fold,
)

@classmethod
def fromdatetime(cls, dt: dt_datetime, tzinfo: Optional[TZ_EXPR] = None) -> "Arrow":
def fromdatetime(
cls,
dt: dt_datetime,
tzinfo: Optional[TZ_EXPR] = None,
fold: Optional[int] = None,
) -> "Arrow":
"""Constructs an :class:`Arrow <arrow.arrow.Arrow>` object from a ``datetime`` and
optional replacement timezone.

Expand All @@ -326,6 +330,9 @@ def fromdatetime(cls, dt: dt_datetime, tzinfo: Optional[TZ_EXPR] = None) -> "Arr
else:
tzinfo = dt.tzinfo

if fold is None:
fold = dt.fold

return cls(
dt.year,
dt.month,
Expand All @@ -335,7 +342,7 @@ def fromdatetime(cls, dt: dt_datetime, tzinfo: Optional[TZ_EXPR] = None) -> "Arr
dt.second,
dt.microsecond,
tzinfo,
fold=getattr(dt, "fold", 0),
fold=fold,
)

@classmethod
Expand All @@ -355,10 +362,14 @@ def fromdate(cls, date: date, tzinfo: Optional[TZ_EXPR] = None) -> "Arrow":

@classmethod
def strptime(
cls, date_str: str, fmt: str, tzinfo: Optional[TZ_EXPR] = None
cls,
date_str: str,
fmt: str,
tzinfo: Optional[TZ_EXPR] = None,
fold: Optional[int] = None,
) -> "Arrow":
"""Constructs an :class:`Arrow <arrow.arrow.Arrow>` object from a date string and format,
in the style of ``datetime.strptime``. Optionally replaces the parsed timezone.
in the style of ``datetime.strptime``. Optionally replaces the parsed timezone and fold.

:param date_str: the date string.
:param fmt: the format string using datetime format codes.
Expand All @@ -376,6 +387,9 @@ def strptime(
if tzinfo is None:
tzinfo = dt.tzinfo

if fold is None:
fold = dt.fold

return cls(
dt.year,
dt.month,
Expand All @@ -385,7 +399,7 @@ def strptime(
dt.second,
dt.microsecond,
tzinfo,
fold=getattr(dt, "fold", 0),
fold=fold,
)

@classmethod
Expand Down Expand Up @@ -413,7 +427,7 @@ def fromordinal(cls, ordinal: int) -> "Arrow":
dt.second,
dt.microsecond,
dt.tzinfo,
fold=getattr(dt, "fold", 0),
fold=dt.fold,
)

# factories: ranges and spans
Expand Down Expand Up @@ -1087,7 +1101,7 @@ def to(self, tz: TZ_EXPR) -> "Arrow":
dt.second,
dt.microsecond,
dt.tzinfo,
fold=getattr(dt, "fold", 0),
fold=dt.fold,
)

# string output and formatting
Expand Down
32 changes: 23 additions & 9 deletions arrow/factory.py
Expand Up @@ -40,6 +40,7 @@ def get(
*,
locale: str = DEFAULT_LOCALE,
tzinfo: Optional[TZ_EXPR] = None,
fold: Optional[int] = None,
normalize_whitespace: bool = False,
) -> Arrow:
... # pragma: no cover
Expand All @@ -61,6 +62,7 @@ def get(
*,
locale: str = DEFAULT_LOCALE,
tzinfo: Optional[TZ_EXPR] = None,
fold: Optional[int] = None,
normalize_whitespace: bool = False,
) -> Arrow:
... # pragma: no cover
Expand All @@ -73,6 +75,7 @@ def get(
*,
locale: str = DEFAULT_LOCALE,
tzinfo: Optional[TZ_EXPR] = None,
fold: Optional[int] = None,
normalize_whitespace: bool = False,
) -> Arrow:
... # pragma: no cover
Expand All @@ -85,6 +88,7 @@ def get(
*,
locale: str = DEFAULT_LOCALE,
tzinfo: Optional[TZ_EXPR] = None,
fold: Optional[int] = None,
normalize_whitespace: bool = False,
) -> Arrow:
... # pragma: no cover
Expand All @@ -96,6 +100,11 @@ def get(self, *args: Any, **kwargs: Any) -> Arrow:
:param tzinfo: (optional) a :ref:`timezone expression <tz-expr>` or tzinfo object.
Replaces the timezone unless using an input form that is explicitly UTC or specifies
the timezone in a positional argument. Defaults to UTC.
:param fold: (optional) an ``int`` value of 0 or 1.
Replaces the fold value, used to disambiguate repeated wall times.
Used only when the first argument is an Arrow instance/datetime/datetime string,
or datetime constructor kwargs were provided.

:param normalize_whitespace: (optional) a ``bool`` specifying whether or not to normalize
redundant whitespace (spaces, tabs, and newlines) in a datetime string before parsing.
Defaults to false.
Expand Down Expand Up @@ -196,14 +205,19 @@ def get(self, *args: Any, **kwargs: Any) -> Arrow:
arg_count = len(args)
locale = kwargs.pop("locale", DEFAULT_LOCALE)
tz = kwargs.get("tzinfo", None)
fold = kwargs.get("fold")
normalize_whitespace = kwargs.pop("normalize_whitespace", False)

# if kwargs given, send to constructor unless only tzinfo provided
if len(kwargs) > 1:
# if kwargs given, send to constructor unless only tzinfo and/or fold provided
if len(kwargs) > 2:
arg_count = 3

# either tzinfo or fold kwarg is not provided
elif len(kwargs) == 2 and None in (tz, fold):
arg_count = 3

# tzinfo kwarg is not provided
if len(kwargs) == 1 and tz is None:
# tzinfo and fold kwargs are both not provided
elif len(kwargs) == 1 and tz is fold is None:
arg_count = 3

# () -> now, @ tzinfo or utc
Expand Down Expand Up @@ -235,11 +249,11 @@ def get(self, *args: Any, **kwargs: Any) -> Arrow:

# (Arrow) -> from the object's datetime @ tzinfo
elif isinstance(arg, Arrow):
return self.type.fromdatetime(arg.datetime, tzinfo=tz)
return self.type.fromdatetime(arg.datetime, tzinfo=tz, fold=fold)

# (datetime) -> from datetime @ tzinfo
elif isinstance(arg, datetime):
return self.type.fromdatetime(arg, tzinfo=tz)
return self.type.fromdatetime(arg, tzinfo=tz, fold=fold)

# (date) -> from date @ tzinfo
elif isinstance(arg, date):
Expand All @@ -252,7 +266,7 @@ def get(self, *args: Any, **kwargs: Any) -> Arrow:
# (str) -> parse @ tzinfo
elif isinstance(arg, str):
dt = parser.DateTimeParser(locale).parse_iso(arg, normalize_whitespace)
return self.type.fromdatetime(dt, tzinfo=tz)
return self.type.fromdatetime(dt, tzinfo=tz, fold=fold)

# (struct_time) -> from struct_time
elif isinstance(arg, struct_time):
Expand All @@ -274,7 +288,7 @@ def get(self, *args: Any, **kwargs: Any) -> Arrow:

# (datetime, tzinfo/str) -> fromdatetime @ tzinfo
if isinstance(arg_2, (dt_tzinfo, str)):
return self.type.fromdatetime(arg_1, tzinfo=arg_2)
return self.type.fromdatetime(arg_1, tzinfo=arg_2, fold=fold)
else:
raise TypeError(
f"Cannot parse two arguments of types 'datetime', {type(arg_2)!r}."
Expand All @@ -295,7 +309,7 @@ def get(self, *args: Any, **kwargs: Any) -> Arrow:
dt = parser.DateTimeParser(locale).parse(
args[0], args[1], normalize_whitespace
)
return self.type.fromdatetime(dt, tzinfo=tz)
return self.type.fromdatetime(dt, tzinfo=tz, fold=fold)

else:
raise TypeError(
Expand Down
11 changes: 8 additions & 3 deletions tests/test_arrow.py
Expand Up @@ -74,9 +74,6 @@ def test_init_with_fold(self):
before = arrow.Arrow(2017, 10, 29, 2, 0, tzinfo="Europe/Stockholm")
after = arrow.Arrow(2017, 10, 29, 2, 0, tzinfo="Europe/Stockholm", fold=1)

assert hasattr(before, "fold")
assert hasattr(after, "fold")
Comment on lines -77 to -78
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This basically checks for presence of the Arrow.fold property, seems unnecessary.


# PEP-495 requires the comparisons below to be true
assert before == after
assert before.utcoffset() != after.utcoffset()
Expand Down Expand Up @@ -183,6 +180,14 @@ def test_strptime(self):
2013, 2, 3, 12, 30, 45, tzinfo=tz.gettz("Europe/Paris")
)

def test_strptime_with_fold(self):

formatted = datetime(2013, 2, 3, 12, 30, 45).strftime("%Y-%m-%d %H:%M:%S")

result = arrow.Arrow.strptime(formatted, "%Y-%m-%d %H:%M:%S", fold=1)
assert result._datetime == datetime(2013, 2, 3, 12, 30, 45, tzinfo=tz.tzutc())
assert result.fold == 1

def test_fromordinal(self):

timestamp = 1607066909.937968
Expand Down
19 changes: 19 additions & 0 deletions tests/test_factory.py
Expand Up @@ -97,12 +97,31 @@ def test_one_arg_arrow(self):

assert arw == result

def test_one_arg_arrow_with_fold(self):

arw = self.factory.utcnow()
result = self.factory.get(arw, fold=1)

# fold is ignored for comparison
assert arw.fold == 0
assert result.fold == 1
assert arw == result

def test_one_arg_datetime(self):

dt = datetime.utcnow().replace(tzinfo=tz.tzutc())

assert self.factory.get(dt) == dt

def test_one_arg_datetime_with_fold(self):

dt = datetime.utcnow().replace(tzinfo=tz.tzutc())
result = self.factory.get(dt, fold=1)

assert dt.fold == 0
assert result.fold == 1
assert result == dt

def test_one_arg_date(self):

d = date.today()
Expand Down