[BUG?] tuple + jaxtyping.Num + torch.Tensor = _BeartypeCallHintPepRaiseDesynchronizationException = 🤮 #312

Open
MaximilienLC opened this issue Dec 1, 2023 · 7 comments


@MaximilienLC

Hey Cecil, hope all is well!

Not sure what happened, but your brainchild asked me to post it, so here I am.

tuple + jaxtyping.Num on a torch.Tensor.

This is on a remote machine, and I cannot reproduce the error on my local machine.

File "/scratch/mleclei/Dropbox/cneuromax/cneuromax/fitting/deeplearning/litmodule/base.py", line 76, in stage_step
    tupled_batch: tuple[Num[Tensor, " ..."], ...] = tuple(batch)
  File "/scratch/mleclei/Dropbox/cneuromax/cneuromax/fitting/deeplearning/litmodule/base.py", line 76, in <resume in stage_step>
    tupled_batch: tuple[Num[Tensor, " ..."], ...] = tuple(batch)
  File "/usr/local/lib/python3.10/dist-packages/beartype/door/_doorcheck.py", line 108, in die_if_unbearable
    _check_object = _get_object_checker(hint, conf)
  File "/usr/local/lib/python3.10/dist-packages/beartype/door/_doorcheck.py", line 113, in <resume in die_if_unbearable>
    _check_object(obj)
  File "<@beartype(beartype.door._doorcheck._get_object_checker._die_if_unbearable) at 0x2afa9c735ea0>", line 12, in _die_if_unbearable
  File "<@beartype(beartype.door._doorcheck._get_object_checker._die_if_unbearable) at 0x2afa9c735ea0>", line 28, in <resume in _die_if_unbearable>
  File "/usr/local/lib/python3.10/dist-packages/beartype/_decor/error/errormain.py", line 263, in get_beartype_violation
  File "/usr/local/lib/python3.10/dist-packages/beartype/_decor/error/errormain.py", line 333, in <resume in get_beartype_violation>
    raise _BeartypeCallHintPepRaiseDesynchronizationException(
beartype.roar._BeartypeCallHintPepRaiseDesynchronizationException: Function beartype.door._doorcheck._get_object_checker._die_if_unbearable() return (tensor([[[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.424...')) violates type hint tuple[jaxtyping.Num[Tensor, ' ...'], ...], but utility function get_beartype_violation() erroneously suggests this object satisfies this hint. Please report this desynchronization failure to the beartype issue tracker (https://github.com/beartype/beartype/issues) with the accompanying exception traceback and the representation of this object:
(tensor([[[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          ...,
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242]]],

        ...,

        [[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          ...,
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242]]]],
       device='cuda:0'), tensor([5, 6, 9, 6, 1, 7, 8, 5, 5, 7, 7, 9, 8, 6, 7, 1, 9, 8, 7, 8, 5, 0, 4, 3,
        ...,
        1, 9, 9, 8, 6, 6, 7, 6, 2, 0, 3, 0], device='cuda:0')

Have fun (or not)!
Max

@leycec
Member

leycec commented Dec 1, 2023

tuple + jaxtyping.Num on a torch.Tensor.

ohohgodsNOOO

This is on a remote machine and I can not reproduce the error on my local machine.

usaywut

In short, you give us nothing but suffering, salty tears, and unreadable exception messages that yield neither consolation nor explication. The above images depict my feelings.

@beartype wants to help you – but @beartype has no idea how to help you. Although @beartype is publicly BFFLs with both jaxtyping and torch, @beartype is also scared speechless by anything involving either of those two packages. I have no paws and I must scream.

@patrick-kidger: Do any of the above moon runes mean anything to you? I assume: "No. Please stop pinging me. Your moon runes are your own problem. jaxtyping cannot be held responsible for jaxtyping."

@MaximilienLC: Would you be able to provide a minimal-reproducible example (MRE) exhibiting this issue? @patrick-kidger wants to help you – but @patrick-kidger has no idea how to help you. The above images depict his feelings, too.

leycec changed the title from "[BUG?] _BeartypeCallHintPepRaiseDesynchronizationException" to "[BUG?] tuple + jaxtyping.Num + torch.Tensor = _BeartypeCallHintPepRaiseDesynchronizationException = 🤮" on Dec 1, 2023
@patrick-kidger
Contributor

Hmm, this is an odd one!
Yeah, can you put together an MWE? In particular, include the precise versions of beartype and jaxtyping that you're using.

@MaximilienLC
Author

Hey both, thanks for the quick replies. I will provide more info in the coming day(s), but days might turn into week(s) due to current time constraints 😅

@MaximilienLC
Author

Soooooo, I tried reproducing the error a few times in the same environment but it was never raised again 👻 .
jaxtyping==0.2.24, beartype==0.16.4

Anyway, here is a type hint that was checked a few lines prior and passed (a method argument):

batch: Num[Tensor, " ..."] | tuple[Num[Tensor, " ..."], ...] | list[Num[Tensor, " ..."]]

and here's the one that did not:

if isinstance(batch, list):
    tupled_batch: tuple[Num[Tensor, " ..."], ...] = tuple(batch)

In jaxtyping notation, the input (batch) was of this format:

list[Float[Tensor, " batch_size 1 28 28"], Int[Tensor, " batch_size"]]
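
In case it helps, here is a minimal sketch of the check that blew up. The batch size is made up, and the explicit die_if_unbearable() call is just my stand-in for whatever the jaxtyping/beartype import hook generated for the annotated assignment in the traceback above:

import torch
from beartype.door import die_if_unbearable
from jaxtyping import Num
from torch import Tensor

# An MNIST-like batch: normalized images plus integer labels, matching the
# shapes in the tensor dump above (the batch size is hypothetical).
batch = [torch.zeros(480, 1, 28, 28), torch.zeros(480, dtype=torch.int64)]

if isinstance(batch, list):
    # The annotated assignment that raised the desynchronization exception.
    tupled_batch: tuple[Num[Tensor, " ..."], ...] = tuple(batch)
    die_if_unbearable(tupled_batch, tuple[Num[Tensor, " ..."], ...])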

@leycec
Member

leycec commented Dec 5, 2023

Soooooo, I tried reproducing the error a few times in the same environment but it was never raised again 👻 .

ohnoes

Anyways here is a type hint...

Yeah. Nothing suspicious there, I'm afraid. @beartype's support for PEP 484- and 585-compliant list[...] and tuple[...] is rock-solid and battle-hardened against billions of lines of other people's code. So, the issue probably isn't there.

Likewise, I trust @patrick-kidger. @patrick-kidger knows things I have only dreamed of in my darkest and most eldritch nightmares. So, the issue probably isn't in jaxtyping either.

You are now thinking to yourself: "@leycec's about to say that I am the problem, isn't he? But I don't wanna be the problem."

Parallelism: Python's Bad Boy Rears Its Ugly Head Again

Well, not quite. Instead, I believe that non-deterministic parallelism is the problem – by which I mean some combination of low-level preemptive multi-threading and/or multi-processing. Although the Python Global Interpreter Lock (GIL) infamously prohibits preemptive multi-threading in pure Python, that same constraint doesn't apply to low-level code outside Python like C, C++, CUDA, or Rust extensions.

Clearly, @beartype and jaxtyping are both pure-Python. At least, I think jaxtyping is pure-Python. Actually, I have no idea whether jaxtyping is pure-Python or not. Let's pretend it is. Then, by definition, neither @beartype nor jaxtyping can be performing preemptive multi-threading. @beartype could be performing preemptive multi-processing – but it doesn't, because that would be insane. jaxtyping is likely similar.

That only leaves something in your full stack as the culprit, @MaximilienLC. Something that is neither @beartype nor jaxtyping. This something is possibly PyTorch, possibly your own codebase, or possibly something else entirely. Whatever this something is, this something is probably mutating (i.e., modifying, changing) the contents of your tensors in another thread or process while @beartype and jaxtyping are attempting to concurrently type-check those same tensors.
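
To make that concrete, here's a contrived sketch – emphatically not your code, @MaximilienLC, just a hypothetical toy – showing how two back-to-back checks of the same object can disagree when another thread mutates that object in between. A desynchronization exception is @beartype noticing exactly this sort of disagreement between its fast checker and its slow violation renderer:

import threading
import time
from beartype.door import is_bearable

# A container mutated by a background thread while we type-check it.
shared: list[object] = ['not an int']

def mutate_soon() -> None:
    time.sleep(0.01)
    shared[0] = 0  # silently "fix" the violation mid-flight

threading.Thread(target=mutate_soon).start()

print(is_bearable(shared, list[int]))  # False: the fast check sees a str...
time.sleep(0.02)
print(is_bearable(shared, list[int]))  # True: the later re-check sees an int!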

Let's See If @leycec Knows What He Is Talking About

Thankfully, this is easy to validate on your end. @MaximilienLC, would you mind temporarily guarding all calls to your problematic function with a standard non-reentrant thread lock? e.g.:

from threading import Lock

# One module-level lock shared by every caller of the problematic function.
muh_lock = Lock()

# Serialize all calls to the problematic function behind that lock.
with muh_lock:
    call_muh_problem_func()

If you get deadlocks (i.e., your app halts at the with muh_lock: statement), then multi-threading is almost certainly the issue. If you get no deadlocks (i.e., your app still runs, just a bit slower), leave the thread locks in for a while and see whether the issue re-arises. If it does not, then multi-threading was, indeed, the issue. Thread locks are the solution.

If the issue still re-arises, however, either:

  • Multi-processing is the issue. So, still not @beartype or jaxtyping. Thank Guido!
  • @beartype is the issue. I have no idea how or why, but it's @beartype. I'd definitely need a minimal reproducible example (MRE) to tackle this, however.
  • jaxtyping is the issue. Again, @patrick-kidger will need a MRE.

tl;dr

First, try threading.Lock. If the problem persists, please provide an example. Then panic.

ohnoes

@MaximilienLC
Author

Thanks for the detailed answer, @leycec. Unfortunately, this time I really won't have the bandwidth to test and report back in the coming days. I'll edit my answer when I do 😅

@leycec
Member

leycec commented Dec 7, 2023

Totally. I'm in a similar leaky boat. The waters are rising and all I have is this blue bucket. 🪣 🛥️

On a more pragmatic note, simply disabling @beartype on problematic callables and classes might be the sanest short-term "solution" for your team: e.g.,

# Import the requisite machinery.
from beartype import beartype, BeartypeConf, BeartypeStrategy

# Dynamically create a new @nobeartype decorator disabling type-checking.
nobeartype = beartype(conf=BeartypeConf(strategy=BeartypeStrategy.O0))

# Avoid type-checking *ANY* methods or attributes of this class.
@nobeartype
class UncheckedDangerClassIsDangerous(object):
    # This method raises *NO* type-checking violation despite returning a
    # non-"None" value.
    def unchecked_danger_method_is_dangerous(self) -> None:
        return 'This string is not "None". Sadly, nobody cares anymore.'
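
The same decorator works per-callable, too, if you'd rather keep type-checking everywhere else. The class and method names below are hypothetical, loosely mirroring your traceback:

# Avoid type-checking *ONLY* this method; everything else stays checked.
class MyLitModule(object):
    @nobeartype
    def stage_step(self, batch):
        return tuple(batch)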

When times are desperate, even @beartype lies down.
