Add array protocol dispatch methods to top-level Funsor class #546

Open
eb8680 opened this issue Aug 3, 2021 · 2 comments

@eb8680 (Member) commented Aug 3, 2021

Now that PyTorch supports tensor subtyping and function overloading with __torch_function__, should we add __array_function__ and __torch_function__ methods to funsor.terms.Funsor to allow evaluation of (some) PyTorch/Numpy code on Funsors?

Here is the meat of a Funsor.__torch_function__ implementation, modulo handling of edge cases; __array_function__ for the Numpy backend would be very similar:

import funsor.ops

class Funsor:
    ...
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # exploit our op registry: ops should know how to handle and convert their arguments
        kwargs = {} if kwargs is None else kwargs
        try:
            op = getattr(funsor.ops, func.__name__)
        except AttributeError:
            op = funsor.ops.make_op(func)  # handle e.g. nn.Module or dist.Transform instances
        return op(*args, **kwargs)
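
For concreteness, here is a hedged sketch of what dispatch from user code might look like if the method above were added. As far as I can tell, the funsor calls below (set_backend, Variable, Real, funsor.ops.exp) exist today; the torch.exp line routing through Funsor.__torch_function__ is the proposed behavior, not something that works currently:

import torch
import funsor
import funsor.ops
from funsor.domains import Real
from funsor.terms import Variable

funsor.set_backend("torch")

x = Variable("x", Real)  # a lazy funsor variable

# Today: array-style math on funsors goes through funsor.ops explicitly,
# building the lazy expression exp(x).
y1 = funsor.ops.exp(x)

# Proposed: with Funsor.__torch_function__ as sketched above, ordinary
# torch code would dispatch to the same op and build the same lazy term.
y2 = torch.exp(x)  # would route to funsor.ops.exp(x)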

The motivating application is a much simpler and more general alternative to the somewhat confusing dimension tracking done via the effectful to_data/to_funsor primitives in pyro.contrib.funsor. It would also simplify @ordabayevy's work in #543 and elsewhere by removing the need for special torch.Tensor subclasses that duplicate Funsor broadcasting semantics.

@ordabayevy (Member) commented:

I like the idea quite a lot! It might simplify things in funsor and make the code look cleaner. My current understanding is that __torch_function__ would replace all Funsor ops (such as Funsor.__add__, Funsor.sum, etc.)? And contrib.funsor would compute everything as Funsors during model execution instead of delegating it to TraceMessenger and converting at the last moment?

@eb8680 (Member, Author) commented Aug 4, 2021

I don't think it would replace the basic Python operator overloads, but array-specific methods like sum() could probably be removed in favor of these generic methods.
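
For illustration, a minimal sketch of the kind of call-site change this would allow, assuming f is some Funsor and the __torch_function__ sketch above is in place (the exact routing shown in the comments is hypothetical):

# Python operator overloads like f + g would stay as they are.
# Array-specific methods could then be dropped in favor of generic dispatch:
total = f.sum()       # current: Funsor method
total = torch.sum(f)  # proposed: dispatches via Funsor.__torch_function__ to funsor.ops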
