Display funsor terms nicely when it breaks into multiple lines #512

Open
fehiepsi opened this issue Mar 31, 2021 · 2 comments
Labels
enhancement New feature or request

Comments

@fehiepsi
Member

No description provided.

@fehiepsi fehiepsi added the enhancement New feature or request label Mar 31, 2021
@fehiepsi fehiepsi self-assigned this Mar 31, 2021
@fritzo
Member

fritzo commented Mar 31, 2021

How about reformatting the string with black?

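import black

class Foo:
    def __str__(self):
        ugly = ...  # placeholder: the unformatted, single-line expression string for this term
        # black re-indents and wraps any string that parses as Python source
        return black.format_str(ugly, mode=black.FileMode())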
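A slightly more defensive variant of this idea, sketched under assumptions: `pretty_repr` and the toy `Foo` class below are hypothetical names, not funsor API. It treats black as an optional dependency, returns the input unchanged when black is missing or cannot parse the string (e.g. a default `<Foo object at 0x...>` repr), and strips the trailing newline that `black.format_str` always appends so `print(foo)` does not emit an extra blank line.

```python
def pretty_repr(expr: str, line_length: int = 88) -> str:
    """Reformat a one-line expression string with black, if black is available.

    Hypothetical helper: falls back to returning ``expr`` unchanged when black
    is not installed or cannot parse the string.
    """
    try:
        import black
    except ImportError:
        return expr  # black is optional here
    try:
        formatted = black.format_str(expr, mode=black.FileMode(line_length=line_length))
    except Exception:  # e.g. black's InvalidInput for text that is not valid Python
        return expr
    # black always ends its output with a newline; drop it for use in __str__.
    return formatted.rstrip("\n")


class Foo:
    """Hypothetical term class whose repr is a valid Python call expression."""

    def __init__(self, *args):
        self.args = args

    def __repr__(self):
        return "Foo({})".format(", ".join(map(repr, self.args)))

    def __str__(self):
        return pretty_repr(repr(self))


print(Foo(*range(40)))  # wrapped across lines only if black is installed
```

Falling back silently keeps `str()` from ever raising, at the cost of hiding formatting failures.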

For example:

>>> import black, funsor, torch
>>> funsor.set_backend("torch")
>>> from funsor.torch.distributions import Multinomial
>>> x = Multinomial(10, torch.tensor([[0.2, 0.8], [0.3, 0.7]]))
>>> print(x)
Multinomial(tensor(10.), tensor([[0.2000, 0.8000],
        [0.3000, 0.7000]]), value)
>>> print(black.format_str(repr(x), mode=black.FileMode()))
Multinomial(
    total_count=tensor(10.0),
    probs=tensor([[0.2000, 0.8000], [0.3000, 0.7000]]),
    value=value,
)

... though I guess this won't work on large tensors that abbreviate like [1, 2, ..., 999].
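One reason abbreviation may not break this: `...` is valid Python on its own (the `Ellipsis` literal), so a summarized tensor repr still parses as an expression. A quick standard-library check with `ast` (the repr string below is made up for illustration):

```python
import ast

# A made-up repr shaped like torch's summarized output for a large tensor.
abbreviated = "Tensor(tensor([[1.1031, 1.4902, ..., 3.2117], ..., [1.8849, 1.2876, ..., 0.9956]]))"

tree = ast.parse(abbreviated, mode="eval")  # raises SyntaxError if not valid Python

# Each "..." placeholder parses into an Ellipsis constant node.
ellipses = [
    node
    for node in ast.walk(tree)
    if isinstance(node, ast.Constant) and node.value is Ellipsis
]
print(len(ellipses))  # 3
```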

@fehiepsi
Member Author

Nice solution! I tested it on abbreviated tensors and it seems to work. Running yapf on top of black's output gives a nice, more compact result.

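import funsor
from funsor import Reals, Tensor, Variable, ops, testing
from funsor.delta import Delta
import black
from yapf.yapflib.yapf_api import FormatCode

funsor.set_backend("torch")

# Note: torch only abbreviates tensors with more than 1000 elements (its default
# print threshold), so a larger shape is needed to get the "..." rows shown below.
shape = (3, 2)
point = Tensor(testing.randn(shape))
x = Variable("x", Reals[shape])
actual = Delta("y", point)(y=ops.log(x))

# First pass: black re-indents the one-line repr.
bx = black.format_str(repr(actual), mode=black.FileMode())
print(bx)

# Second pass: yapf reflows black's output; FormatCode returns a
# (formatted_source, changed) tuple.
print(FormatCode(bx)[0])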

This prints the black output first, then the same string after yapf:

Delta(
    (
        (
            "x",
            (
                Tensor(
                    tensor(
                        [
                            [1.1031, 1.4902, 1.0148, ..., 0.1421, 2.2505, 3.2117],
                            [2.6469, 2.3936, 0.5754, ..., 1.1973, 0.2250, 2.7428],
                            [1.7854, 2.3582, 0.1898, ..., 0.4366, 0.3680, 0.4215],
                            ...,
                            [2.0760, 1.0118, 4.1144, ..., 0.7696, 0.2702, 4.7514],
                            [1.7888, 1.4948, 1.5240, ..., 0.9670, 0.4326, 1.2505],
                            [1.8849, 1.2876, 0.7254, ..., 1.6765, 0.3074, 0.9956],
                        ]
                    )
                ),
                Tensor(tensor(30.6611)),
            ),
        ),
    )
)

Delta(((
    "x",
    (
        Tensor(
            tensor([
                [1.1031, 1.4902, 1.0148, ..., 0.1421, 2.2505, 3.2117],
                [2.6469, 2.3936, 0.5754, ..., 1.1973, 0.2250, 2.7428],
                [1.7854, 2.3582, 0.1898, ..., 0.4366, 0.3680, 0.4215],
                ...,
                [2.0760, 1.0118, 4.1144, ..., 0.7696, 0.2702, 4.7514],
                [1.7888, 1.4948, 1.5240, ..., 0.9670, 0.4326, 1.2505],
                [1.8849, 1.2876, 0.7254, ..., 1.6765, 0.3074, 0.9956],
            ])),
        Tensor(tensor(30.6611)),
    ),
),))

This could be the right path for us. I'll see whether we can avoid the extra dependency...
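One possible dependency-free direction, sketched here as an assumption rather than anything funsor does: parse the repr with the standard-library `ast` module and re-indent it in a black-like style, exploding a call, tuple, or list onto one item per line only when it does not fit in `width` columns. `pretty_expr` and its helpers are hypothetical names. `ast.get_source_segment` (Python 3.8+) keeps leaf text such as `0.2000` or `10.` exactly as the repr printed it, which `ast.unparse` would not.

```python
import ast

_SPLITTABLE = (ast.Call, ast.Tuple, ast.List)


def _flat(node: ast.AST, src: str) -> str:
    """Render ``node`` on one line, keeping each leaf's text exactly as written."""
    if isinstance(node, ast.Call):
        parts = [_flat(a, src) for a in node.args]
        for k in node.keywords:
            prefix = "**" if k.arg is None else k.arg + "="
            parts.append(prefix + _flat(k.value, src))
        return _flat(node.func, src) + "(" + ", ".join(parts) + ")"
    if isinstance(node, ast.List):
        return "[" + ", ".join(_flat(e, src) for e in node.elts) + "]"
    if isinstance(node, ast.Tuple):
        items = [_flat(e, src) for e in node.elts]
        # A one-element tuple needs its trailing comma to stay a tuple.
        if len(items) == 1:
            return "(" + items[0] + ",)"
        return "(" + ", ".join(items) + ")"
    # Leaves (names, numbers, strings, "..."): copy the original source text.
    return ast.get_source_segment(src, node)


def _pretty(node: ast.AST, src: str, indent: int, width: int, lead: int = 0) -> str:
    """Render ``node`` whose first line starts at column ``indent + lead``."""
    flat = _flat(node, src)
    if not isinstance(node, _SPLITTABLE) or indent + lead + len(flat) <= width:
        return flat
    inner = indent + 4
    if isinstance(node, ast.Call):
        opener, closer = _flat(node.func, src) + "(", ")"
        items = [_pretty(a, src, inner, width) for a in node.args]
        for k in node.keywords:
            prefix = "**" if k.arg is None else k.arg + "="
            items.append(prefix + _pretty(k.value, src, inner, width, len(prefix)))
    else:
        opener, closer = ("[", "]") if isinstance(node, ast.List) else ("(", ")")
        items = [_pretty(e, src, inner, width) for e in node.elts]
    # One item per line, each followed by a trailing comma, as black does.
    body = "".join(" " * inner + item + ",\n" for item in items)
    return opener + "\n" + body + " " * indent + closer


def pretty_expr(src: str, width: int = 88) -> str:
    """Re-indent a repr string that happens to be a valid Python expression."""
    tree = ast.parse(src, mode="eval")  # raises SyntaxError otherwise
    return _pretty(tree.body, src, 0, width)


print(pretty_expr(
    "Multinomial(total_count=tensor(10.), "
    "probs=tensor([[0.2000, 0.8000], [0.3000, 0.7000]]), value=value)",
    width=60,
))
```

This lays the call out one keyword per line, much like black's `Multinomial(...)` output above, but keeps `10.` as written. It only indents by 4 columns per level and never splits inside a leaf, so it is a starting point rather than a drop-in replacement for black.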
