
Post-mortem: why an easy workflow was horribly non-performant, and what we could do to make it easier for users to write fast dask code #301

Open
crusaderky opened this issue Jan 20, 2023 · 9 comments


crusaderky commented Jan 20, 2023

Executive summary

Today, the user experience of a typical novice to intermediate dask.dataframe user can be very poor. Building a workflow that is supposedly very straightforward can result in an extremely non-performant output with a tendency to randomly kill off workers. At the end of this post you'll find 13 remedial actions, 10 of which can be sensibly achieved in a few weeks, which can drastically improve the user experience.

Introduction

I recently went through a demo notebook, written by a data scientist, whose purpose is to showcase dask.dataframe to new dask users through a real-life use case. The notebook does what I would call light data preprocessing on a 40 GiB parquet dataset of NYC taxis, with the purpose of later feeding them into a machine learning algorithm.

The first time I ran it, the notebook ran in 25 minutes and required hosts mounting a bare minimum of 32 GiB RAM each. After a long day of tweaking it, I brought it down to 2 minutes runtime and 8 GiB per host RAM requirements.

The problem is this: the workflow implemented by the notebook is not rocket science. Getting to a performant implementation is something you would expect to be a stress-free exercise for an average data scientist; instead it took a day's worth of anger-debugging from a dask maintainer to make it work properly.

This thread is a post-mortem of my experience with it, detailing all the roadblocks that both the original coder and myself hit, with proposals on how to make it a smooth sail in the future.

The algorithm

What's implemented is a standard machine learning pre-processing flow:

  1. Load a 16-billion-row parquet dataset from s3
  2. Discard unneeded columns
  3. Discard rows containing malformed data
  4. Discard rows containing outliers
  5. Join with a tiny pandas.DataFrame (265 rows) containing domain mapping
  6. Convert all domain-based columns to categories, with global domains
  7. Write to parquet on s3

Implementation

Data loading and column manipulation

This first part loads up the dataframe, generates a few extra columns as a function of other columns, and drops unnecessary columns.
The dataset is publicly accessible - you may reproduce this on your own.

Original code

import dask.dataframe as dd
import distributed

client = distributed.Client(...)
ddf = dd.read_parquet(
    "s3://coiled-datasets/prefect-dask/nyc-uber-lyft/processed_data.parquet"
)
print(f"size of the total dataset is:  {len(ddf.index)}")

ddf = ddf.assign(accessible_vehicle=1)
ddf.accessible_vehicle = ddf.accessible_vehicle.where(ddf.on_scene_datetime.isnull(), 0)
ddf = ddf.assign(pickup_month=ddf.pickup_datetime.dt.month)
ddf = ddf.assign(pickup_dow=ddf.pickup_datetime.dt.dayofweek)
ddf = ddf.assign(pickup_hour=ddf.pickup_datetime.dt.hour)

ddf = ddf.drop(
    columns=[
        "on_scene_datetime",
        "request_datetime",
        "pickup_datetime",
        "dispatching_base_num",
        "originating_base_num",
        "shared_request_flag",
        "shared_match_flag",
        "dropoff_datetime",
        "base_passenger_fare",
        "bcf",
        "sales_tax",
        "tips",
        "driver_pay",
        "access_a_ride_flag",
        "wav_match_flag",
        "wav_request_flag",
    ]
)
ddf = ddf.reset_index(drop=True)

ddf["airport_fee"] = ddf["airport_fee"].replace("None", 0)
ddf["airport_fee"] = ddf["airport_fee"].replace("nan", 0)
ddf["airport_fee"] = ddf["airport_fee"].astype(float)
ddf["airport_fee"] = ddf["airport_fee"].fillna(0)

To an intermediate user's eye, this looks OK.
But it is very, very bad:

  1. All string data is read in Python object format, which is excruciatingly slow to process. Switching to PyArrow was not painless; see dask#9840 ("Config option dataframe.dtype_backend: pyarrow doesn't seem to work").
  2. That innocent-looking print statement right after read_parquet() reads the whole dataset into memory and then discards it.
  3. The code reads the whole thing into memory and then drops unneeded columns. However, parquet allows you to efficiently cherry-pick individual columns, leaving the rest untouched on disk.
  4. Last but not least: the partitions on disk are very uneven in size, with the smallest being 22 MiB and the largest weighing a whopping 836 MiB. This is what caused the memory requirement of 32 GiB per host. If you repartition() them into smaller chunks, the whole thing becomes a lot more manageable, even though the initial load still requires computing everything at once.
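Here is a rough illustration of point 4. It is a toy calculation, not dask's actual repartitioning algorithm. For each partition, it counts how many pieces the partition must be split into so that no piece exceeds a target of 100 MiB, which is close to the partition_size="100MB" used in the revised code below. The two sizes are the extremes quoted above. Real partition_size-based repartitioning also merges small partitions, and this sketch ignores that.

```python
import math


def split_plan(sizes_mib, target_mib=100):
    """For each input partition, return how many pieces of at most
    target_mib it would be split into (small partitions are not merged here)."""
    return [max(1, math.ceil(size / target_mib)) for size in sizes_mib]


def peak_partition(sizes_mib, plan):
    """Size of the largest piece, assuming each partition is split evenly."""
    return max(size / n for size, n in zip(sizes_mib, plan))


sizes = [22, 836]  # smallest and largest on-disk partitions mentioned above, in MiB
plan = split_plan(sizes)
print(plan)                         # [1, 9]
print(peak_partition(sizes, plan))  # roughly 92.9, down from 836
```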

Revised code

Before starting the cluster:
(This is Coiled-specific; other clusters will require you to set the config manually on all workers.)

import dask
import dask.dataframe as dd
import distributed
import pandas as pd

dask.config.set({"dataframe.dtype_backend": "pyarrow"})
client = distributed.Client(...)

# Workaround to https://github.com/dask/dask/issues/9840
from distributed import WorkerPlugin
class SetPandasOptions(WorkerPlugin):
    def setup(self, worker):
        pd.set_option("string_storage", "pyarrow")
pd.set_option("string_storage", "pyarrow")  # Set on the client
_ = client.register_worker_plugin(SetPandasOptions())  # Set on the workers
# End workaround

ddf = dd.read_parquet(
    "s3://coiled-datasets/prefect-dask/nyc-uber-lyft/processed_data.parquet",
    index=False,
    columns=[
        "hvfhs_license_num",
        "PULocationID",
        "DOLocationID",
        "trip_miles",
        "trip_time",
        "tolls",
        "congestion_surcharge",
        "airport_fee",
        "wav_request_flag",
        "on_scene_datetime",
        "pickup_datetime",
    ],
)
ddf = ddf.repartition(partition_size="100MB")
ddf = ddf.assign(
    accessible_vehicle=ddf.on_scene_datetime.isnull(),
    pickup_month=ddf.pickup_datetime.dt.month,
    pickup_dow=ddf.pickup_datetime.dt.dayofweek,
    pickup_hour=ddf.pickup_datetime.dt.hour,
)
ddf = ddf.drop(columns=["on_scene_datetime", "pickup_datetime"])
ddf["airport_fee"] = ddf["airport_fee"].replace("None", 0).astype(float).fillna(0)
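The revised code also collapses the four airport_fee statements into a single chain. A quick pandas check on made-up values suggests the two versions are equivalent. The check assumes, as the original code implies, that missing fees appear as the strings "None" and "nan". astype(float) already parses "nan" into NaN, and fillna(0) then replaces it, so the explicit replace("nan", 0) is redundant.

```python
import pandas as pd

# Tiny made-up stand-in for the airport_fee column, with missing values
# stored as the strings "None" and "nan"
fee = pd.Series(["None", "nan", "2.5", "1.25"], dtype=object)

# Original four-step version
a = fee.replace("None", 0)
a = a.replace("nan", 0)
a = a.astype(float)
a = a.fillna(0)

# Revised one-liner: float("nan") becomes NaN, which fillna(0) then replaces
b = fee.replace("None", 0).astype(float).fillna(0)

print(a.tolist())  # [0.0, 0.0, 2.5, 1.25]
print(b.tolist())  # [0.0, 0.0, 2.5, 1.25]
```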

Note that the call to repartition() reads the whole thing into memory and then discards it. This takes a substantial amount of time, but it's the best I could do. It is wholly avoidable, though. I also feel that a novice user should not have to deal with oversized chunks themselves:

When repartition() is unavoidable, because it's in the middle of a computation, it could avoid computing everything:

As for PyArrow strings: I am strongly convinced they should be the default.
I understand that setting PyArrow strings on by default would cause dask to deviate from pandas. I think it's a case where the deviation is worthwhile - pandas doesn't need to cope with the GIL and serialization!

Being forced to think ahead with column slicing is another interesting discussion. It could be avoided if each column was a separate dask key. For the record, this is exactly what xarray.Dataset does. Alternatively, High level expressions (dask#7933) would allow rewriting the dask graph on the fly. With either of these changes, the columns you don't need would never leave the disk (as long as you set chunk_size=1). Also, it would mean that len(ddf) would have to load a single column (seconds) instead of the whole thing (minutes).

I appreciate that introducing splitting by column in dask.dataframe would be a very major effort - but I think it's very likely worth the price.
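The column-splitting argument can be made concrete with a toy model in plain Python. ColumnarLazyFrame and make_loader are hypothetical names, not part of dask or xarray. In this model every column is its own lazily loaded key. len() only has to load one column, and selecting a column never touches the others.

```python
# Toy model, not dask's API: each column is a separate lazily-loaded key,
# so an operation only loads the columns it actually touches.
loaded = []  # records which columns were actually read


def make_loader(name, values):
    """Hypothetical loader standing in for reading one column from disk."""
    def load():
        loaded.append(name)
        return values
    return load


class ColumnarLazyFrame:
    def __init__(self, loaders):
        self.loaders = loaders  # column name -> zero-argument loader

    def __getitem__(self, columns):
        # Column selection is free: it only narrows the set of loaders
        return ColumnarLazyFrame({c: self.loaders[c] for c in columns})

    def __len__(self):
        # All columns have the same length, so loading one of them is enough
        first = next(iter(self.loaders))
        return len(self.loaders[first]())

    def compute(self):
        return {c: load() for c, load in self.loaders.items()}


frame = ColumnarLazyFrame({
    "trip_time": make_loader("trip_time", [300, 600, 900]),
    "trip_miles": make_loader("trip_miles", [1.0, 2.5, 4.0]),
    "tips": make_loader("tips", [0, 1, 2]),
})

print(len(frame), loaded)  # 3 ['trip_time']
loaded.clear()
print(frame[["trip_miles"]].compute(), loaded)  # {'trip_miles': [1.0, 2.5, 4.0]} ['trip_miles']
```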

A much cheaper fix to len():

Drop rows

After column manipulation, we move to row filtering:

Original code

ddf = ddf.dropna(how="any")

original_rowcount = len(ddf.index)

# Remove outliers
lower_bound = 0
Q3 = ddf["trip_time"].quantile(0.75)
print(f"Q3 is:  {Q3.compute()}")
upper_bound = Q3 + (1.5 * (Q3 - lower_bound))
print(f"Upper bound is:  {upper_bound.compute()}")

ddf = ddf.loc[(ddf["trip_time"] >= lower_bound) & (ddf["trip_time"] <= upper_bound)]

ddf = ddf.repartition(partition_size="100MB").persist()
print(
    "Fraction of dataset left after removing outliers:",
    len(ddf.index) / original_rowcount,
)

This snippet recomputes everything so far (load from s3 AND column preprocessing), 😶 FIVE 😶 TIMES 😶:

  1. It performs yet another call to len(ddf.index).
  2. It computes the Q3 on a column.
  3. It computes upper_bound.

Again, if the graph was split by columns or was rewritten on the fly by high level expressions, these three would be much less of a problem.

  4. repartition(partition_size=...) under the hood calls compute() and then discards everything.
  5. persist() recomputes everything from the beginning one more time.

The one thing that is inexpensive is the final call to len(ddf.index), because it's immediately after a persist().
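The cost of these repeated computations can be reproduced with a toy lazy-evaluation model in plain Python. Lazy is a hypothetical stand-in, not dask's implementation. Each compute() re-runs the whole chain back to load(), unless the data was persisted first. The persisted version keeps the full result in memory, which is the trade-off persist() makes.

```python
# Toy model of lazy evaluation (not dask's implementation): every compute()
# re-runs the whole chain of functions unless an intermediate was persisted.
calls = {"load": 0}


def load():
    calls["load"] += 1  # stands in for the expensive read from s3
    return list(range(1, 101))


class Lazy:
    def __init__(self, fn, *deps):
        self.fn, self.deps = fn, deps

    def compute(self):
        # Recursively compute all dependencies, then apply fn
        return self.fn(*(d.compute() for d in self.deps))

    def persist(self):
        value = self.compute()      # computed once...
        return Lazy(lambda: value)  # ...then held in memory


def third_quartile(xs):
    return sorted(xs)[int(0.75 * len(xs))]


data = Lazy(load)
count = Lazy(len, data)
q3 = Lazy(third_quartile, data)
count.compute(); q3.compute(); count.compute()
print(calls["load"])  # 3: every compute() re-read the data

calls["load"] = 0
data = data.persist()
count = Lazy(len, data)
q3 = Lazy(third_quartile, data)
count.compute(); q3.compute(); count.compute()
print(calls["load"])  # 1: only persist() read the data
```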

Revised code

ddf = ddf.dropna(how="any")

# Remove outliers
lower_bound = 0
Q3 = ddf['trip_time'].quantile(0.75)
upper_bound = Q3 + (1.5 * (Q3 - lower_bound))
ddf = ddf.loc[(ddf["trip_time"] >= lower_bound) & (ddf["trip_time"] <= upper_bound)]

The repartition() call is no longer there. I already called it once, and filtering rows can only make partitions smaller, so they are now guaranteed to be no larger than before.
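The outlier filter can be sanity-checked on a small pandas Series of made-up trip times. The rule Q3 + 1.5 * (Q3 - lower_bound) is a variant of Tukey's fence that uses the fixed lower bound of 0 in place of Q1. Note that dask's Series.quantile() is approximate on a partitioned column, whereas pandas computes it exactly.

```python
import pandas as pd

trip_time = pd.Series([10, 20, 30, 40, 1000])  # made-up trip times

lower_bound = 0
q3 = trip_time.quantile(0.75)                  # 40.0 (linear interpolation)
upper_bound = q3 + (1.5 * (q3 - lower_bound))  # 100.0

kept = trip_time.loc[(trip_time >= lower_bound) & (trip_time <= upper_bound)]
print(kept.tolist())  # [10, 20, 30, 40]
```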

As before, I've removed the print() statements outright. If I had to retain them, I would push them further down, right after a call to persist(), so that the computation is only done once:

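original_rowcount = ddf.shape[0]  # lazy row count; nothing is computed yet
# ... drop rows
new_rowcount = ddf.shape[0]
# ...
# Persist the dataframe together with both counts, so that the parts of the
# graph they share (load from s3, preprocessing) are only computed once
ddf, original_rowcount, new_rowcount = dask.persist(ddf, original_rowcount, new_rowcount)
# The counts are now small in-memory results, so computing them is cheap
original_rowcount, new_rowcount = dask.compute(original_rowcount, new_rowcount)
print(
    "Fraction of dataset left after removing outliers:",
    new_rowcount / original_rowcount
)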

Joins

Download the "Taxi Zone Lookup Table (CSV)" from https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page, and save it to data/taxi+_zone_lookup.csv.

Original code

taxi_df = pd.read_csv("data/taxi+_zone_lookup.csv", usecols=["LocationID", "Borough"])

ddf = dd.merge(ddf, taxi_df, left_on="PULocationID", right_on="LocationID", how="inner")
ddf = ddf.rename(columns={"Borough": "PUBorough"})
ddf = ddf.drop(columns="LocationID")

ddf = dd.merge(ddf, taxi_df, left_on="DOLocationID", right_on="LocationID", how="inner")
ddf = ddf.rename(columns={"Borough": "DOBorough"})
ddf = ddf.drop(columns="LocationID")


BOROUGH_MAPPING = {
    "Manhattan": "Superborough 1",
    "Bronx": "Superborough 1",
    "EWR": "Superborough 1",
    "Brooklyn": "Superborough 2",
    "Queens": "Superborough 2",
    "Staten Island": "Superborough 3",
    "Unknown": "Unknown",
}


def make_cross_borough_cat(df):
    PUSuperborough = [BOROUGH_MAPPING.get(i) for i in df.PUBorough.tolist()]
    DOSuperborough = [BOROUGH_MAPPING.get(i) for i in df.DOBorough.tolist()]
    PUSuperborough_DOSuperborough_Pair = [
        f"{i}-{j}" for i, j in zip(PUSuperborough, DOSuperborough)
    ]
    return df.assign(PUSuperborough_DOSuperborough=PUSuperborough_DOSuperborough_Pair)


ddf = ddf.map_partitions(lambda df: make_cross_borough_cat(df))

Again, this sneaks more pure-Python strings into the dataframe.
This would be solved by force-casting them to string[pyarrow] in from_pandas() (which merge() calls under the hood):

It also makes a typical mistake: it uses pure-Python code for the sake of simplicity instead of going through a pandas join. This is already non-performant in pandas, and when you move to dask it also hogs the GIL.

The sensible approach to limiting this problem would be to show the user a plot of how much time they spent on each task:

In addition to the above ticket, we need documentation and evangelism. Users should learn that they can look at Prometheus to debug a non-performant workflow, and which metrics they should look at first. At the moment, Prometheus is covered by a single, hard-to-notice page in the dask docs.

The second join is performed on a much bigger dataset than necessary. There's no real solution to this - this is an abstract algorithmic optimization that the developer could have noticed themselves.

Finally, the PUSuperborough and DOSuperborough columns could be dropped. I know they're unnecessary from reading the next section about categorization. Again, there's nothing we (dask devs) can do here.
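For comparison, here is a small pandas sketch on made-up rows. It does the same superborough pairing without the per-row Python loop. Each borough column is mapped through BOROUGH_MAPPING with Series.map(), and the two results are joined with Series.str.cat(). The revised code below goes one step further: it applies the mapping to the 265-row lookup table before the joins, so the big dataframe only needs the final string concatenation.

```python
import pandas as pd

BOROUGH_MAPPING = {
    "Manhattan": "Superborough 1",
    "Bronx": "Superborough 1",
    "EWR": "Superborough 1",
    "Brooklyn": "Superborough 2",
    "Queens": "Superborough 2",
    "Staten Island": "Superborough 3",
    "Unknown": "Unknown",
}

# A few made-up rows, as they might look after the two joins
df = pd.DataFrame({
    "PUBorough": ["Manhattan", "Queens", "Staten Island"],
    "DOBorough": ["Bronx", "Manhattan", "Unknown"],
})

# Pure-Python version, as in make_cross_borough_cat(): one interpreter loop per row
slow = [
    f"{BOROUGH_MAPPING.get(i)}-{BOROUGH_MAPPING.get(j)}"
    for i, j in zip(df.PUBorough.tolist(), df.DOBorough.tolist())
]

# Column-wise version: map both columns, then concatenate with a separator
fast = df.PUBorough.map(BOROUGH_MAPPING).str.cat(
    df.DOBorough.map(BOROUGH_MAPPING), sep="-"
)

print(fast.tolist())
# ['Superborough 1-Superborough 1', 'Superborough 2-Superborough 1',
#  'Superborough 3-Unknown']
```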

Revised code

taxi_zone_lookup = pd.read_csv(
    "data/taxi+_zone_lookup.csv", usecols=["LocationID", "Borough"]
)
BOROUGH_MAPPING = {
    "Manhattan": "Superborough 1",
    "Bronx": "Superborough 1",
    "EWR": "Superborough 1",
    "Brooklyn": "Superborough 2",
    "Queens": "Superborough 2",
    "Staten Island": "Superborough 3",
    "Unknown": "Unknown",
}

taxi_zone_lookup["Superborough"] = [
    BOROUGH_MAPPING[k] for k in taxi_zone_lookup["Borough"]
]
taxi_zone_lookup = taxi_zone_lookup.astype(
    {"Borough": "string[pyarrow]", "Superborough": "string[pyarrow]"}
)

ddf = dd.merge(ddf, taxi_zone_lookup, left_on="PULocationID", right_on="LocationID", how="inner")
ddf = ddf.rename(columns={"Borough": "PUBorough", "Superborough": "PUSuperborough"})
ddf = ddf.drop(columns="LocationID")

ddf = dd.merge(ddf, taxi_zone_lookup, left_on="DOLocationID", right_on="LocationID", how="inner")
ddf = ddf.rename(columns={"Borough": "DOBorough", "Superborough": "DOSuperborough"})
ddf = ddf.drop(columns="LocationID")

ddf["PUSuperborough_DOSuperborough"] = ddf.PUSuperborough.str.cat(
    ddf.DOSuperborough, sep="-"
)
ddf = ddf.drop(columns=["PUSuperborough", "DOSuperborough"])

Categorization

This dataset is going to be fed into a machine learning engine, so everything that can be converted into a domain should be:

ddf = ddf.repartition(partition_size="100MB").persist()

categories = [
    "hvfhs_license_num",
    "PULocationID",
    "DOLocationID",
    "accessible_vehicle",
    "pickup_month",
    "pickup_dow",
    "pickup_hour",
    "PUBorough",
    "DOBorough",
    "PUSuperborough_DOSuperborough",
]
ddf[categories] = ddf[categories].astype("category")
ddf = ddf.categorize(columns=categories)

The first call to repartition() computes everything since the previous call to persist() - this includes the pure-python joins - and then discards it. Then, persist() computes it again.

The call to astype("category") is unnecessary. categorize(), while OK in this case because it comes (accidentally?) just after a persist(), is a major pain point on its own:

Revised code

categories = [
    "hvfhs_license_num",
    "PULocationID",
    "DOLocationID",
    "wav_request_flag",
    "accessible_vehicle",
    "pickup_month",
    "pickup_dow",
    "pickup_hour",
    "PUBorough",
    "DOBorough",
    "PUSuperborough_DOSuperborough",
]

# Read https://github.com/dask/dask/issues/9847
ddf = ddf.astype(dict.fromkeys(categories, "category"))
ddf = ddf.persist()
ddf = ddf.categorize(categories)
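Categories need to be global. If each partition is categorized on its own, the same integer code means different things in different partitions. The pandas sketch below uses made-up values. It computes the union of categories across two "partitions" and applies one shared CategoricalDtype. This is roughly the full pass over the data that categorize() has to perform, which is why it is costly unless the data was persisted first.

```python
import pandas as pd

# Two "partitions" of the same column, categorized independently
part1 = pd.Series(["Manhattan", "Queens"]).astype("category")
part2 = pd.Series(["Bronx"]).astype("category")
print(list(part1.cat.categories), list(part2.cat.categories))
# ['Manhattan', 'Queens'] ['Bronx']: code 0 means different things in each

# Union of the categories across all partitions, applied as one shared dtype
global_dtype = pd.CategoricalDtype(
    sorted(set(part1.cat.categories) | set(part2.cat.categories))
)
part1, part2 = part1.astype(global_dtype), part2.astype(global_dtype)

combined = pd.concat([part1, part2], ignore_index=True)
print(combined.dtype == global_dtype)  # True: concatenation keeps the categorical dtype
print(combined.cat.codes.tolist())     # [1, 2, 0]
```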

Disk write

Original code

ddf = ddf.repartition(partition_size="100MB")
ddf.to_parquet(
    "s3://coiled-datasets/prefect-dask/nyc-uber-lyft/feature_table.parquet",
    overwrite=True,
)

The final call to repartition() recomputes everything since persist() and then discards it - in this case, not that much.

Revised code

ddf = ddf.persist().repartition(partition_size="100MB")

# Workaround to https://github.com/apache/arrow/issues/33727
ddf = ddf.astype(
    {
        col: pd.CategoricalDtype(dt.categories.astype(object))
        for col, dt in ddf.dtypes.items()
        if isinstance(dt, pd.CategoricalDtype)
        and dt.categories.dtype == "string[pyarrow]"
    }
)

ddf.to_parquet(
    "s3://coiled-datasets/prefect-dask/nyc-uber-lyft/feature_table.parquet",
    overwrite=True,
)
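The arrow workaround above can be tried on a plain pandas DataFrame. To avoid depending on pyarrow, this sketch builds categories with the generic "string" dtype. It then tests for pd.StringDtype instead of comparing against "string[pyarrow]". That substitution is an assumption of this sketch, not part of the original workaround.

```python
import pandas as pd

# A categorical column whose categories use a pandas string dtype
# (generic "string" here, so the example runs without pyarrow)
dtype = pd.CategoricalDtype(pd.Index(["EWR", "Queens"], dtype="string"))
df = pd.DataFrame({"PUBorough": pd.Series(["Queens", "EWR"]).astype(dtype)})

# Same idea as the workaround: rebuild each such dtype with object categories
df = df.astype(
    {
        col: pd.CategoricalDtype(dt.categories.astype(object))
        for col, dt in df.dtypes.items()
        if isinstance(dt, pd.CategoricalDtype)
        and isinstance(dt.categories.dtype, pd.StringDtype)
    }
)
print(df["PUBorough"].cat.categories.dtype)  # object
print(df["PUBorough"].tolist())              # ['Queens', 'EWR']
```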

Action points recap

Low effort

Intermediate effort

  • PyArrow should be opt-out, not opt-in. This requires fixing all the outstanding pain points with it.
  • Push the use of Prometheus as the go-to performance debugging tool.

High effort

Very high effort

@crusaderky crusaderky changed the title Post-mortem: why an easy workflow was horribly unperformant, and what we could do to make it easier for user to write fast dask code Post-mortem: why an easy workflow was horribly non-performant, and what we could do to make it easier for users to write fast dask code Jan 20, 2023
@mrocklin
Member

@crusaderky this is great. Maybe a good blogpost?

@gjoseph92

@crusaderky this is an excellent and very valuable writeup. Thanks for taking the time to do this.

Being forced to think ahead with column slicing is another interesting discussion. It could be avoided if each column was a separate dask key

We've also talked about doing column pruning automatically. This is one of the optimizations we hope would be enabled by high level expressions / revising high level graphs dask/dask#7933 cc @rjzamora. Might be worth adding that to the list.

@crusaderky
Author

We've also talked about doing column pruning automatically. This is one of the optimizations we hope would be enabled by high level expressions / revising high level graphs dask/dask#7933 cc @rjzamora. Might be worth adding that to the list.

Updated post.

@martindurant
Member

We have mentioned this before, but dask-awkward has a nice example of following the metadata through many layers to be able to prune column loads. It's enabled by the rich awkward "typetracer" (a no-data array), which they implemented specifically for this purpose. It just works through many operations where fake data or zero-length arrays might not.

@gjoseph92

I'm curious about a couple human / API design aspects of this that I think are also worth looking into.

Two things we see a lot of, that don't need to be there:

  1. repeated re-computation
  2. persist

There are 4 unnecessary recomputations for print statements, and 3 recomputations for repartition operations.

The print statements I find really interesting. I'm going to assume the author didn't realize how much time these added. (As in, they weren't so important for the workflow that they were worth leaving in.) Just removing those, without any of the other changes or fixes you've suggested, would still speed up the code a ton!

So how could we make it easier to recognize how expensive this is?

  1. We could track recently-released keys on the scheduler. Using some heuristic, if you submit a graph that contains a lot of recently-released keys, we send a warning message to the client that you may be recomputing data, and suggesting changes to make (use persist, don't call compute multiple times, etc.). I'm skeptical how useful/effective this would be.

  2. In general, it's hard to map what dask is doing back to user code. When dask is running a task, you don't know what line number in your code that task came from. This can make debugging dask performance feel pretty impenetrable. Now obviously here, running any Python profiler on the script would highlight these compute calls.

    But perhaps because it's so impenetrable, people don't think to even try to understand or profile performance? (Not to mention that many dask users might not be familiar with profiling in the first place.) So the more we make it easier to understand performance, the more users will feel empowered to think about it and tweak things themselves. I imagine that right now, a lot of users call compute and hope for the best. How the time is spent and why during that compute feels like a black box (or a bunch of colors flying around on a dashboard that are hard to understand).

    I wonder if we could make something (on the dashboard?) that gives a profile of your own code and how much dask work each line does. I'm imagining maybe line-profiler style, like the scalene GUI? We wouldn't actually run a profiler on the client—we'd basically have a symbol table mapping from dask keys to lines in user code, then show runtime / memory? / task count / transfers / spill / etc aggregated over tasks for that line. We'd also have a way of splitting it up by compute call, which would highlight repeated computations. (Note that a natural place to put this symbol table would also be High Level Expressions dask#7933).

    This is a big thing that would be very generally useful far beyond unnecessary print statements. But if we had it, it would probably help you find the unnecessary print statements.
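Here is a minimal sketch of the first idea. It assumes a hypothetical scheduler-side hook; RecentlyReleased, release() and check_graph() are illustrative names, not part of distributed. The sketch remembers a bounded number of recently released keys. It flags a newly submitted graph when a large fraction of that graph's keys is among them. The threshold and the bound are arbitrary here, and choosing them well is the open question.

```python
from collections import OrderedDict


class RecentlyReleased:
    """Hypothetical bookkeeping (not an existing distributed API): remember the
    last `maxlen` released keys and flag graphs that overlap heavily with them."""

    def __init__(self, maxlen=100_000, threshold=0.5):
        self.keys = OrderedDict()
        self.maxlen = maxlen
        self.threshold = threshold

    def release(self, key):
        self.keys[key] = None
        self.keys.move_to_end(key)
        if len(self.keys) > self.maxlen:
            self.keys.popitem(last=False)  # forget the oldest release

    def check_graph(self, graph_keys):
        """Return a warning if too many submitted keys were released recently."""
        graph_keys = list(graph_keys)
        if not graph_keys:
            return None
        recomputed = sum(k in self.keys for k in graph_keys)
        if recomputed / len(graph_keys) >= self.threshold:
            return (
                f"{recomputed} of {len(graph_keys)} tasks were computed and released "
                "recently; consider persist() or a single compute() call"
            )
        return None


tracker = RecentlyReleased(maxlen=1000, threshold=0.5)
for i in range(10):
    tracker.release(("read-parquet", i))

# Re-submitting the same reads plus one new task triggers the warning
print(tracker.check_graph([("read-parquet", i) for i in range(10)] + [("len", 0)]))
# Unrelated keys do not
print(tracker.check_graph([("read-parquet", i) for i in range(20, 30)]))  # None
```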

Another thought: print statements are useful. Even in a world where the author recognized these repeated computes were adding a lot of time, they still might want to print out Q3 and original_rowcount just to see them. Maybe we could have a dask.log or dask.print function to make it easier to log delayed values without calling compute?

Finally, persist seems to be sprinkled liberally across the original code. To me, persist is a bad idea in most cases because it means you're now manually managing memory: dask can't stream operations in constant memory, and release chunks that are no longer needed, because you've pinned them all. But I feel like I see many users add persists all over their code, and I want to know why. Is there something in the docs? It feels like a knob people reach for when they're not happy with performance (because performance is a black box, see above), but why?

@rjzamora
Member

It feels like a knob people reach for when they're not happy with performance (because performance is a black box, see above), but why?

I agree that using persist can be problematic in many cases, but is also something that the documentation recommends: https://docs.dask.org/en/stable/best-practices.html#persist-when-you-can

@jakirkham
Member

This is a really nice write-up!

What do people here think about caching? I agree that when computations happen can be surprising to an end user. At a minimum, workflows could avoid repeating a computation unnecessarily.

I frequently point users (even experienced developers) to graphchain. Perhaps Dask would benefit from baking this in.

cc @lsorber (in case you have thoughts on any of this)
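To illustrate the general idea behind such a cache, here is a sketch. It is not graphchain's implementation, and cached_call is a hypothetical helper. Each result is keyed on a hash of the function and its inputs, so repeated identical calls are served from the cache instead of being recomputed. A real cache would also need to invalidate entries when the function's code or the underlying data changes, and this sketch does not attempt that.

```python
import hashlib
import pickle

cache = {}  # stands in for an on-disk cache
runs = []   # records which functions actually executed


def cached_call(fn, *args):
    """Hypothetical helper: run fn(*args) unless an identical call (same
    function name and pickled arguments) is already in the cache."""
    token = hashlib.sha256(
        pickle.dumps((fn.__module__, fn.__qualname__, args))
    ).hexdigest()
    if token not in cache:
        runs.append(fn.__name__)
        cache[token] = fn(*args)
    return cache[token]


def expensive_load(path):
    return [len(path)] * 3  # placeholder for reading a dataset


def row_count(rows):
    return len(rows)


for _ in range(3):  # e.g. three print() statements that each trigger a compute
    rows = cached_call(expensive_load, "example.parquet")
    n = cached_call(row_count, rows)

print(runs)  # ['expensive_load', 'row_count']
```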

@GenevieveBuckley
Copy link
Collaborator

@crusaderky this is great. Maybe a good blogpost?

It'd also make a pretty fascinating short talk for one of the dask demo days.

@crusaderky
Author

About the pure-Python code: @gjoseph92, I wonder if it would be feasible to offer a simple flag in coiled.Cluster, e.g. low_level_profile=True. It would let us populate a Grafana plot of time spent waiting in GIL contention, broken down by user function (the lowest-level callable directly visible in the dask graph?). Of course, enabling such a flag would come at a ~2x performance cost. I don't know if we're capable of real-time C-level profiling, though.
