Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BinaryPrecisionRecallCurve for large datasets (>100 million samples) #1309

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from

Conversation

jpcbertoldo
Copy link

@jpcbertoldo jpcbertoldo commented Nov 3, 2022

What does this PR do?

This PR provides an alternative to this warning:

rank_zero_warn(
    "Metric `PrecisionRecallCurve` will save all targets and predictions in buffer."
    " For large datasets this may lead to large memory footprint."
)

Context

Currently, BinaryPrecisionRecallCurve (and related methods by consequence) has two operation modes, which I will call "computed-thresholds" (arg thresholds=None) and "given-thresholds" (otherwise).

"computed-thresholds"

All possible thresholds in the sample ares used in compute(). It has a high memory consumption if the number of instances is high because all the preds and targets are kept in memory during the updates.

How much is "high" memory consumption?

I will consider 100 million samples the order of magnitude where "high" starts.
This roughly corresponds to ~750Mb of for the tensor preds under the "computed-thresholds" mode: 1e8 (samples) * 4 (bytes/float32) = 375Mb.

In my research I'd like to compute roc curves of pixel-wise classifications with 1024 x 1024 resolution images, and I have ~2000 in test, which would account for 1024 * 1024 * 2000 * 4 (byte/float32) ~= 7.8Gb at least .

"given-thresholds"

Thresholds are pre-defined, so all possible binarizations are known in advance. At each update all are tested, giving as many confusion matrices.

The memory consumption is low because and a function of the number of thresholds
If num_thresholds = 10e6, then it uses roughly 10e6 * (2 *2) (confmat shape) * 8 (bytes/long) ~= 300mb which is much lower even with 1 million thresholds.
However, it requires the user to know meaningful threshols in advance, which is often not the case.

A possible (inconvenient) solution would be to estimate the thresholds on a first call then compute the curve.

Solution

I propose to create a hybrid strategy where the update switches from "computed-thresholds" to "given-thresholds" dinamically.

Given a budget (number of instances), the keeps preds and targets in the state until that budget is reached, then estimates meaningful thresholds and compute the confusion matrices at that point. From there it behaves like "given-thresholds" mode.

API sketch

# the arg `thresholds` can be a str 
# `budget` is the max number of instances seen before the mode switch
metric = BinaryPrecisionRecallCurve(thresholds="dynamic", budget=10e8)  # 100 million for instance

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

@SkafteNicki
Copy link
Member

Hi @jpcbertoldo, thanks for proposing this enhancement (I have also seen your thread on slack)
I do like the idea, but I wonder if it is possible to still fit this feature into the thresholds argument, so we do not have to change the outer interface again. We could reserve str input to this feature such that:

metric = BinaryPrecisionRecallCurve(thresholds="12345")
metric = BinaryPrecisionRecallCurve(thresholds="250mb")
metric = BinaryPrecisionRecallCurve(thresholds="3gb")

all enabled this feature. What do you think about that?

@jpcbertoldo
Copy link
Author

Hi @jpcbertoldo, thanks for proposing this enhancement (I have also seen your thread on slack) I do like the idea, but I wonder if it is possible to still fit this feature into the thresholds argument, so we do not have to change the outer interface again. We could reserve str input to this feature such that:

metric = BinaryPrecisionRecallCurve(thresholds="12345")
metric = BinaryPrecisionRecallCurve(thresholds="250mb")
metric = BinaryPrecisionRecallCurve(thresholds="3gb")

all enabled this feature. What do you think about that?

Yeah makes sense. I was actually thinking of considering just the case threshols="200mb", I think it makes a lot more sense than thinking about numbers of instances.

@jpcbertoldo
Copy link
Author

@SkafteNicki by the way, I just saw this

New Features:

  1. Submit a github issue - describe what is the motivation of such feature (adding the use case or an example is helpful).
  2. Let's discuss to determine the feature scope.

Could we skip this step for this time? : )

@jpcbertoldo jpcbertoldo force-pushed the jpcbertoldo/roc-for-large-datasets branch from b98edf7 to d5a8309 Compare November 3, 2022 13:31
@SkafteNicki
Copy link
Member

@SkafteNicki by the way, I just saw this

New Features:

  1. Submit a github issue - describe what is the motivation of such feature (adding the use case or an example is helpful).
  2. Let's discuss to determine the feature scope.

Could we skip this step for this time? : )

That is normally how we like it. But since you already had implemented some, no need to open a issue about this. It is recommended because we do not want people to implement a lot before talking with us, because we may decide it is not a feature we want and therefore their work would go to waste.

@SkafteNicki SkafteNicki added the enhancement New feature or request label Nov 3, 2022
@SkafteNicki SkafteNicki added this to In progress in Classification refactor via automation Nov 3, 2022
@SkafteNicki SkafteNicki added this to the v0.11 milestone Nov 3, 2022
Copy link
Member

@SkafteNicki SkafteNicki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you made any thoughts on extending this to multiclass / multilabel?

@@ -14,6 +14,7 @@

from typing import List, Optional, Sequence, Tuple, Union

import humanfriendly
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We really do not want to introduce any new dependencies.
The conversion from mb and gb seems to be something we can do ourself?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw that one coming haha.

So, i put it anyway because it feels like the kind of functionality pruned to silly mistakes while a tiny library like this has it neatly packed in.

I can try to make a minimal version of it based on the library's source code. Is that a better solution?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets make it conditional, if user already have it, then use it

if module_available("humanfriendly"):
    import humanfriendly
else:
    humanfriendly = None

Comment on lines +132 to +137
class _ComputationMode(Enum):
"""Internal state of the dynamic mode."""

BINNED = "binned"
NON_BINNED = "non-binned"
NON_BINNED_DYNAMIC = "non-binned-dynamic"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems weird to me having a class inside a class def?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's kind of rare indeed but i usually do this in such cases where it is strictly only used internally 🤷‍♂️

Should i pop it out?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes please

Comment on lines +56 to +66
def _validate_memory_budget(budget: int):
if budget <= 0:
raise ValueError("Budget must be larger than 0.")

if _budget_bytes_to_nsamples(budget) <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES:
warnings.warn(
f"Budget is relatively small ({humanfriendly.format_size(budget, binary=True)}). "
"The dynamic mode is recommended for bigger samples."
)

return budget
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we only have the thresholds argument i guess all this logic can be moved to the _adjust_threshold_arg function

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the multiclass/label cases the estimation (number of samples) <-> (memory consumption) would be different.

I don't have a very strong opinion on this, i will put some more thought on the multi* cases first.

@jpcbertoldo
Copy link
Author

Have you made any thoughts on extending this to multiclass / multilabel?

Just as much as in my reply to your comment 😬.

I'll be putting some effort on that this weekend :)

@SkafteNicki SkafteNicki modified the milestones: v0.11, v0.12 Nov 18, 2022
@jpcbertoldo
Copy link
Author

@SkafteNicki i'm having a hard time to make my self available to invest some time on this.
Maybe could you give me a hand and tell me what/where I have to change things to adapt it for the multilabel case?

@Borda
Copy link
Member

Borda commented Dec 23, 2022

@stancld could you help here to finish this PR? 🦦

from typing import Any, List, Optional, Tuple, Union

import humanfriendly
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets have this as optional

Comment on lines +132 to +137
class _ComputationMode(Enum):
"""Internal state of the dynamic mode."""

BINNED = "binned"
NON_BINNED = "non-binned"
NON_BINNED_DYNAMIC = "non-binned-dynamic"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes please

Comment on lines +138 to +143
if isinstance(thresholds, str):
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED_DYNAMIC
elif thresholds is None:
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED
else:
return BinaryPrecisionRecallCurve._ComputationMode.BINNED
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(thresholds, str):
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED_DYNAMIC
elif thresholds is None:
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED
else:
return BinaryPrecisionRecallCurve._ComputationMode.BINNED
if isinstance(thresholds, str):
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED_DYNAMIC
if thresholds is None:
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED
return BinaryPrecisionRecallCurve._ComputationMode.BINNED

@@ -14,6 +14,7 @@

from typing import List, Optional, Sequence, Tuple, Union

import humanfriendly
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets make it conditional, if user already have it, then use it

if module_available("humanfriendly"):
    import humanfriendly
else:
    humanfriendly = None

if isinstance(thresholds, int):
thresholds = torch.linspace(0, 1, thresholds, device=device)
if isinstance(thresholds, list):
thresholds = torch.tensor(thresholds, device=device)
if isinstance(thresholds, str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(thresholds, str):
if isinstance(thresholds, str) and humanfriendly:

@Borda
Copy link
Member

Borda commented Feb 28, 2023

@jpcbertoldo, how is it going here? I think we are on good path... :)

@jpcbertoldo
Copy link
Author

@jpcbertoldo, how is it going here? I think we are on good path... :)

Hi @Borda, I lost track of this since a while (vacation, other priorities...). I will try to come back to it soon!

@SkafteNicki SkafteNicki modified the milestones: v0.12, future Mar 31, 2023
@jpcbertoldo
Copy link
Author

@Borda I was thinking if we could do a simpler solution for this by computing the metric in 2 rounds ("epochs").

1st round: just find the min/max, then, with a given maximum number of points, linearly -- or some smarter way ? -- space the thresholds between the min/max.

2nd round: compute the actual TPR/FPRs in an online way (because now the thresholds are known in advance).


An alternative for the 1st round: instead of just min/max, keep track of unique values, which may provide information for eventually having parts of the threshold range more or less dense. The set of unique values could be in float16 instead of float32 to increase the redundancy of values (i'm guessing it shouldn't affect so much the accuracy of the AUC).

@Borda
Copy link
Member

Borda commented Aug 8, 2023

@jpcbertoldo apology for the late reply...
@SkafteNicki what do you think about the suggestion above?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
No open projects
Development

Successfully merging this pull request may close these issues.

None yet

3 participants