Nucleus (top-p) sampling #1347

Open
belerico opened this issue Apr 24, 2024 · 2 comments
Labels: enhancement, generation

Comments

@belerico
Contributor

Nucleus sampling (top-p sampling in HF) is a dynamic sampling strategy that "truncat[es] the unreliable tail of the probability distribution, sampling from the dynamic nucleus of tokens containing the vast majority of the probability mass."
It can be implemented in the existing sample method like this:

from typing import Optional

import torch


def sample(
    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None
) -> torch.Tensor:
    # keep only the logits for the last position: (batch, seq, vocab) -> (vocab,)
    logits = logits[0, -1]
    # optionally crop the logits to only the top k options
    if top_k is not None:
        v, i = torch.topk(logits, min(top_k, logits.size(-1)))
        # do not use `torch.where` as in nanoGPT because it will repeat top-k collisions
        logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
    # optionally crop the logits to the smallest set with a cumulative probability above top_p
    if top_p is not None:
        # sort ascending so the low-probability tail accumulates first
        sorted_logits, sorted_indices = torch.sort(logits, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # mask tokens whose cumulative mass stays at or below 1 - top_p
        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        # map the mask from sorted order back to vocabulary order
        indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, float("-inf"))
    # optionally scale the logits and sample from a probability distribution
    if temperature > 0.0:
        probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
        return multinomial_num_samples_1(probs)
    return torch.argmax(logits, dim=-1, keepdim=True)
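
A minimal usage sketch, assuming litgpt's existing multinomial_num_samples_1 helper is in scope; the tensor below is a random stand-in for a model's output, purely for illustration:

logits = torch.randn(1, 8, 32000)  # (batch, sequence, vocab)
next_token = sample(logits, temperature=0.8, top_k=50, top_p=0.9)  # the two filters compose
print(next_token)  # a single sampled token id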

I can open a PR with this addition if it's considered useful.

@rasbt
Collaborator

rasbt commented Apr 24, 2024

Thanks for suggesting this and offering to contribute!

In short, instead of selecting a fixed number of tokens like in top-k, it selects the smallest set of tokens whose cumulative probability exceeds a threshold p. I think this is a popular, standard technique and could be added as an option for litgpt chat, analogous and in addition to the top_k setting.
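
For concreteness, here is a toy sketch of that selection rule; the distribution below is made up for illustration:

import torch

# toy distribution over a 5-token vocabulary (made-up numbers)
probs = torch.tensor([0.50, 0.20, 0.15, 0.10, 0.05])
top_p = 0.8

sorted_probs, sorted_ids = torch.sort(probs, descending=True)
cumulative = sorted_probs.cumsum(dim=0)    # [0.50, 0.70, 0.85, 0.95, 1.00]
# keep the smallest prefix whose cumulative mass reaches top_p
keep = int((cumulative < top_p).sum()) + 1  # 3 tokens here
nucleus = sorted_ids[:keep]
print(nucleus.tolist())  # [0, 1, 2]; a flatter distribution would keep more tokens

Unlike top_k, the number of kept tokens adapts to the shape of the distribution at each step.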
It would be a nice contribution.
What do you think @awaelchli @carmocca ?


@carmocca
Contributor

I agree

@carmocca added the enhancement and generation labels on Apr 24, 2024