
ENH: add von Mises-Fisher distribution #17624

Merged · 78 commits · Apr 17, 2023

Conversation

dschmitz89 (Contributor) commented Dec 17, 2022

Reference issue

Closes gh-17463

What does this implement/fix?

Adds the von Mises-Fisher distribution to scipy.stats. This distribution is the most common analogue of the normal distribution on the unit sphere, likely the most important directional distribution. It is for example available in geomstats and tensorflow.

The random variate generation is basically a numpyfied version of the geomstats implementation.

Additional information

This still has some sharp edges:

  • for some reason the doc formatting via doccer does not work; that's why it is commented out at the moment. This would be the most urgent thing to look at first.

I implemented a fit method for the distribution. So far, no other multivariate distribution has such a method so this should be reviewed carefully.

For a rough idea what this distribution looks like: these are plots of the PDF (top row) and 20 random variates (bottom) for three different concentrations.

VMF_Plots

CC @nguigs: since you authored this distribution for geomstats and I heavily reused parts of it, I thought I would ask if you would be willing to help out with review here. It's quite a big PR, so every bit would help. Thanks in advance.

@dschmitz89 dschmitz89 changed the title Vonmisesfisher ENH: add von Mises-Fisher distribution Dec 17, 2022
@j-bowhay j-bowhay added scipy.stats enhancement A new feature or improvement labels Dec 18, 2022
tupui (Member) commented Mar 22, 2023

Thanks for linking the discussion. I don't see any replies for the API discussion, but it's ok per our process.

@tupui (Member) left a comment

@dschmitz89 Sorry for the delay here. I had another look pulling the changes and playing with it. I think it's in good shape and did not notice strange things. I took the liberty to push some cleanup.

In terms of tests, I think we are missing the 2D case against a reference. Assuming mpmath is ok, then I would suggest adding a test for 2D here.

Something else to consider are testing extreme cases of kappa to have either a concentration around a single point or a perfectly uniform sample. Here we still need to give some guidelines for at least the max value. e.g. I can easily use 1e10 but 1e20 hangs (consider erroring out as well).

On that note, kappa=0 returns NaN with some warnings. I don't think we want that. I would either error out with a nice message, or fallback to a perfectly uniform sample on the hypersphere.

Last point for me would be to document the maths. The 3D and rejection sampling code are a bit obscure as they are. Adding some reference and explanation would help maintain this in the future.

tupui (Member) commented Apr 6, 2023

@chrisb83 I think you worked on vonmises_gen, can I interest you in having a look at this PR?

I think it's in good shape and could be merged once my points above have been addressed. So I am not asking you to review line by line, more check that we did not miss anything obvious with the infra, API or the maths.

dschmitz89 (Contributor, Author) commented:

> @dschmitz89 Sorry for the delay here. I had another look pulling the changes and playing with it. I think it's in good shape and did not notice strange things. I took the liberty to push some cleanup.

Thanks!

> In terms of tests, I think we are missing the 2D case against a reference. Assuming mpmath is ok, then I would suggest adding a test for 2D here.

This should be covered by the test against the 2D von Mises distribution: see here.

> Something else to consider are testing extreme cases of kappa to have either a concentration around a single point or a perfectly uniform sample. Here we still need to give some guidelines for at least the max value. e.g. I can easily use 1e10 but 1e20 hangs (consider erroring out as well).

Thanks for the hint. I was able to squeeze out a higher range of possible values for $\kappa$ when sampling with $d>3$. For even higher values, what is needed is an accurate method to evaluate $\frac{1-x}{1+x}$ at $x\approx 0$ (occurrence in the code here). If anyone has an idea, I would gladly implement it.

> On that note, kappa=0 returns NaN with some warnings. I don't think we want that. I would either error out with a nice message, or fallback to a perfectly uniform sample on the hypersphere.

The case $\kappa=0$ is now handled separately: it errors out with a custom message.

> Last point for me would be to document the maths. The 3D and rejection sampling code are a bit obscure as they are. Adding some reference and explanation would help maintain this in the future.

I added some more comments. The rejection sampler is rather involved, though, and I did not fully dig through the reference.

WarrenWeckesser (Member) commented:

> For even higher values, what is needed is an accurate method to evaluate $\frac{1-x}{1+x}$ at $x\approx 0$ (occurrence in the code here)

I suspect that particular ratio is not the problem. But in the line that follows, you compute np.log(1. - node ** 2). That subtraction will lose precision when node is near 1, which corresponds to envelop_param (i.e. $x$ in the comment above) being small. You can rewrite that expression to maintain precision:

$$\log\left(1 - \left(\frac{1-x}{1+x}\right)^2\right) = \log\left(\frac{4x}{(1 + x)^2}\right) = \log(4) + \log(x) - 2\textrm{log1p}(x)$$
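A quick numerical check of that rewrite (variable names from the discussion above; the value of x is illustrative):

```python
import numpy as np

x = 1e-12                      # envelop_param near 0, i.e. extreme kappa
node = (1.0 - x) / (1.0 + x)   # very close to 1

naive = np.log(1.0 - node**2)                         # loses digits: 1 - node**2 cancels
stable = np.log(4.0) + np.log(x) - 2.0 * np.log1p(x)  # reformulated, no cancellation

print(naive, stable)  # both near log(4e-12) ~ -26.25, but only `stable` keeps full precision
```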

dschmitz89 (Contributor, Author) commented:

> > For even higher values, what is needed is an accurate method to evaluate $\frac{1-x}{1+x}$ at $x\approx 0$ (occurrence in the code here)
>
> I suspect that particular ratio is not the problem. But in the line that follows, you compute np.log(1. - node ** 2). That subtraction will lose precision when node is near 1, which corresponds to envelop_param (i.e. $x$ in the comment above) being small. You can rewrite that expression to maintain precision:
>
> $$\log\left(1 - \left(\frac{1-x}{1+x}\right)^2\right) = \log\left(\frac{4x}{(1 + x)^2}\right) = \log(4) + \log(x) - 2\,\textrm{log1p}(x)$$

Thanks, this helped a lot. I also reformulated the other logarithmic term containing $\frac{1-x}{1+x}$, now sampling is easily possible for $\kappa=10^{30}$. The fitting method does not have such a high range due to the issues with special.ive from #18088 though.

@chrisb83 (Member) left a comment

@tupui I did not work with this distribution before, so I cannot help much. I just did a quick review of the documentation, which looks very nice. I have left a few minor comments.

This is definitely a nice addition to SciPy.

----------
mu : array_like
    Mean direction of the distribution. Must be a one-dimensional unit
    vector of norm 1.

chrisb83 (Member): Euclidean norm?

dschmitz89 (Contributor, Author): Yes. Do you think this should be added to the docstrings?

vector of norm 1.
kappa : float
    Concentration parameter. Must be positive.
seed : {None, int, np.random.RandomState, np.random.Generator}, optional

chrisb83 (Member): do we use the keyword seed for multivariate dists? rvs has random_state.

dschmitz89 (Contributor, Author): Indeed, throughout multivariate, it is called seed, unlike for univariate distributions.

samples = np.squeeze(samples)
return samples

def _rejection_sampling(self, dim, kappa, size, random_state):
chrisb83 (Member): Out of curiosity: do you know how well rejection sampling works for this distribution? I guess it gets slower as the number of dimensions increases?

dschmitz89 (Contributor, Author): Yup, sampling time increases approximately linearly with the number of dimensions. Generating 10,000 samples in 20 dimensions still only takes about 7 ms on my machine.

dschmitz89 (Contributor, Author) commented:

Short summary of the state of this PR from my side: implementation and documentation are done in my opinion. One limitation is that only random variate generation works for $\kappa > 10^9$; the other methods fail due to a problem with ive (see #18088).

@tupui (Member) left a comment

Thank you everyone, approving on my side.

I would propose to ask one last time on the mailing list and if there are no further comments, we would merge one week later.

dschmitz89 (Contributor, Author) commented:

> Thank you everyone, approving on my side.
>
> I would propose to ask one last time on the mailing list and if there are no further comments, we would merge one week later.

Sent: https://mail.python.org/archives/list/scipy-dev@python.org/message/HE737ZZUWJYRBJ7KQEYOBHJFNHQQH7KY/

@stefanv (Member) left a comment

Nice contribution, thank you!

I noticed that one comment from another reviewer on the pdf method was resolved without the text changing; just making double sure that's correct, since the docstring still mentions "log".

Review threads on scipy/stats/_multivariate.py: resolved.
@tupui tupui merged commit 2242bf8 into scipy:main Apr 17, 2023
26 checks passed
tupui (Member) commented Apr 17, 2023

Alright, a week it is and comments have been addressed. Thank you again @dschmitz89 for the PR and for pushing this through the finish line. And thanks everyone else for helping out 🎉

Labels
enhancement A new feature or improvement scipy.stats
Development

Successfully merging this pull request may close these issues.

ENH: Add von Mises–Fisher distribution to scipy.stats
8 participants