Skip to content

Commit

Permalink
ENH: correct dataframe check (#255)
Browse files Browse the repository at this point in the history
* ENH: correct dataframe check

* CLN: change assert comment
  • Loading branch information
hongyeehh committed Jun 27, 2021
1 parent 1545840 commit 215c6fc
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 28 deletions.
8 changes: 6 additions & 2 deletions trackintel/analysis/location_identification.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def location_identifier(spts, method="FREQ", pre_filter=True, **pre_filter_kwarg
--------
>>> ti.analysis.location_identifier(spts, pre_filter=True, method="FREQ")
"""
assert spts.as_staypoints
# assert validity of staypoints
spts.as_staypoints

spts = spts.copy()
if "location_id" not in spts.columns:
raise KeyError(
Expand Down Expand Up @@ -119,7 +121,9 @@ def pre_filter_locations(
>> mask = ti.analysis.pre_filter_locations(spts)
>> spts = spts[mask]
"""
assert spts.as_staypoints
# assert validity of staypoints
spts.as_staypoints

spts = spts.copy()
if isinstance(thresh_loc_time, str):
thresh_loc_time = pd.to_timedelta(thresh_loc_time)
Expand Down
19 changes: 14 additions & 5 deletions trackintel/io/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def read_positionfixes_csv(*args, columns=None, tz=None, index_col=object(), crs
pfs = gpd.GeoDataFrame(df, geometry="geom")
if crs:
pfs.set_crs(crs, inplace=True)
assert pfs.as_positionfixes

# assert validity of positionfixes
pfs.as_positionfixes
return pfs


Expand Down Expand Up @@ -195,7 +197,9 @@ def read_triplegs_csv(*args, columns=None, tz=None, index_col=object(), crs=None
tpls = gpd.GeoDataFrame(df, geometry="geom")
if crs:
tpls.set_crs(crs, inplace=True)
assert tpls.as_triplegs

# assert validity of triplegs
tpls.as_triplegs
return tpls


Expand Down Expand Up @@ -291,7 +295,9 @@ def read_staypoints_csv(*args, columns=None, tz=None, index_col=object(), crs=No
stps = gpd.GeoDataFrame(df, geometry="geom")
if crs:
stps.set_crs(crs, inplace=True)
assert stps.as_staypoints

# assert validity of staypoints
stps.as_staypoints
return stps


Expand Down Expand Up @@ -376,7 +382,9 @@ def read_locations_csv(*args, columns=None, index_col=object(), crs=None, **kwar
locs = gpd.GeoDataFrame(df, geometry="center")
if crs:
locs.set_crs(crs, inplace=True)
assert locs.as_locations

# assert validity of locations
locs.as_locations
return locs


Expand Down Expand Up @@ -464,7 +472,8 @@ def read_trips_csv(*args, columns=None, tz=None, index_col=object(), **kwargs):
if not pd.api.types.is_datetime64tz_dtype(trips[col]):
trips[col] = _localize_timestamp(dt_series=trips[col], pytz_tzinfo=tz, col_name=col)

assert trips.as_trips
# assert validity of trips
trips.as_trips
return trips


Expand Down
17 changes: 11 additions & 6 deletions trackintel/io/from_geopandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def read_positionfixes_gpd(gdf, tracked_at="tracked_at", user_id="user_id", geom
if not pd.api.types.is_datetime64tz_dtype(pfs[col]):
pfs[col] = _localize_timestamp(dt_series=pfs[col], pytz_tzinfo=tz, col_name=col)

assert pfs.as_positionfixes
# assert validity of positionfixes
pfs.as_positionfixes
return pfs


Expand Down Expand Up @@ -104,7 +105,8 @@ def read_staypoints_gpd(
if not pd.api.types.is_datetime64tz_dtype(stps[col]):
stps[col] = _localize_timestamp(dt_series=stps[col], pytz_tzinfo=tz, col_name=col)

assert stps.as_staypoints
# assert validity of staypoints
stps.as_staypoints
return stps


Expand Down Expand Up @@ -159,7 +161,8 @@ def read_triplegs_gpd(
if not pd.api.types.is_datetime64tz_dtype(tpls[col]):
tpls[col] = _localize_timestamp(dt_series=tpls[col], pytz_tzinfo=tz, col_name=col)

assert tpls.as_triplegs
# assert validity of triplegs
tpls.as_triplegs
return tpls


Expand Down Expand Up @@ -229,7 +232,8 @@ def read_trips_gpd(
if not pd.api.types.is_datetime64tz_dtype(trips[col]):
trips[col] = _localize_timestamp(dt_series=trips[col], pytz_tzinfo=tz, col_name=col)

assert trips.as_trips
# assert validity of trips
trips.as_trips
return trips


Expand Down Expand Up @@ -271,7 +275,8 @@ def read_locations_gpd(gdf, user_id="user_id", center="center", mapper={}):
locs = gdf.rename(columns=columns)
locs = locs.set_geometry("center")

assert locs.as_locations
# assert validity of locations
locs.as_locations
return locs


Expand Down Expand Up @@ -332,6 +337,6 @@ def read_tours_gpd(
# if not pd.api.types.is_datetime64tz_dtype(trs[col]):
# trs[col] = localize_timestamp(dt_series=trs[col], pytz_tzinfo=tz, col_name=col)

# assert trs.as_tours
# trs.as_tours
# return trs
pass
35 changes: 22 additions & 13 deletions trackintel/io/postgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def read_positionfixes_postgis(conn_string, table_name, geom_col="geom", *args,
pfs = gpd.GeoDataFrame.from_postgis("SELECT * FROM %s" % table_name, conn, geom_col=geom_col, *args, **kwargs)
finally:
conn.close()
assert pfs.as_positionfixes

# assert validity of positionfixes
pfs.as_positionfixes
return pfs


Expand Down Expand Up @@ -106,7 +108,9 @@ def read_triplegs_postgis(conn_string, table_name, geom_col="geom", *args, **kwa
)
finally:
conn.close()
assert pfs.as_triplegs

# assert validity of triplegs
pfs.as_triplegs
return pfs


Expand Down Expand Up @@ -150,7 +154,7 @@ def write_triplegs_postgis(


def read_staypoints_postgis(conn_string, table_name, geom_col="geom", *args, **kwargs):
"""Reads staypoints from a PostGIS database.
"""Read staypoints from a PostGIS database.
Parameters
----------
Expand All @@ -172,13 +176,15 @@ def read_staypoints_postgis(conn_string, table_name, geom_col="geom", *args, **k
engine = create_engine(conn_string)
conn = engine.connect()
try:
pfs = gpd.GeoDataFrame.from_postgis(
stps = gpd.GeoDataFrame.from_postgis(
"SELECT * FROM %s" % table_name, conn, geom_col=geom_col, index_col="id", *args, **kwargs
)
finally:
conn.close()
assert pfs.as_staypoints
return pfs

# assert validity of staypoints
stps.as_staypoints
return stps


def write_staypoints_postgis(staypoints, conn_string, table_name, schema=None, sql_chunksize=None, if_exists="fail"):
Expand Down Expand Up @@ -253,13 +259,14 @@ def read_locations_postgis(conn_string, table_name, geom_col="geom", *args, **kw
)
finally:
conn.close()
assert locs.as_locations

# assert validity of locations
locs.as_locations
return locs


def write_locations_postgis(locations, conn_string, table_name, schema=None, sql_chunksize=None, if_exists="fail"):
"""Stores locations to PostGIS. Usually, this is directly called on a locations
GeoDataFrame (see example below).
"""Store locations to PostGIS. Usually, this is directly called on a locations GeoDataFrame (see example below).
Parameters
----------
Expand Down Expand Up @@ -295,7 +302,7 @@ def write_locations_postgis(locations, conn_string, table_name, schema=None, sql


def read_trips_postgis(conn_string, table_name, *args, **kwargs):
"""Reads trips from a PostGIS database.
"""Read trips from a PostGIS database.
Parameters
----------
Expand All @@ -314,11 +321,13 @@ def read_trips_postgis(conn_string, table_name, *args, **kwargs):
engine = create_engine(conn_string)
conn = engine.connect()
try:
trps = pd.read_sql("SELECT * FROM %s" % table_name, conn, index_col="id", *args, **kwargs)
trips = pd.read_sql("SELECT * FROM %s" % table_name, conn, index_col="id", *args, **kwargs)
finally:
conn.close()
assert trps.as_trips
return trps

# assert validity of trips
trips.as_trips
return trips


def write_trips_postgis(trips, conn_string, table_name, schema=None, sql_chunksize=None, if_exists="fail"):
Expand Down
4 changes: 2 additions & 2 deletions trackintel/preprocessing/positionfixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,8 @@ def generate_triplegs(pfs_input, stps_input, method="between_staypoints", gap_th
tpls = tpls.set_geometry("geom")
tpls.crs = pfs.crs

# check the correctness of the generated tpls
assert tpls.as_triplegs
# assert validity of triplegs
tpls.as_triplegs

if case == 2:
pfs.drop(columns="staypoint_id", inplace=True)
Expand Down

0 comments on commit 215c6fc

Please sign in to comment.