Nucleus sampling (top-p sampling in HF) is a dynamic sampling strategy that "truncat[es] the unreliable tail of the probability distribution, sampling from the dynamic nucleus of tokens containing the vast majority of the probability mass."
It can be easily implemented in the sample method like this:
def sample(
    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None
) -> torch.Tensor:
    logits = logits[0, -1]
    # optionally crop the logits to only the top k options
    if top_k is not None:
        v, i = torch.topk(logits, min(top_k, logits.size(-1)))
        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
        logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
    # optionally crop the logits to the smallest set of logits with a cumulative probability above top_p
    if top_p is not None:
        sorted_logits, sorted_indices = torch.sort(logits, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, float("-inf"))
    # optionally scale the logits and sample from a probability distribution
    if temperature > 0.0:
        probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
        return multinomial_num_samples_1(probs)
    return torch.argmax(logits, dim=-1, keepdim=True)
I can open a PR with this addition if it is considered useful.
In short, instead of selecting a fixed number of tokens as in top-k, it selects however many tokens are needed without exceeding a cumulative probability threshold p. This is a popular, standard technique and could be added as an option for litgpt chat, analogous and in addition to the top_k setting.
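The truncation step can be illustrated on a toy distribution. Below is a minimal, self-contained sketch of just the top-p masking logic from the snippet above (the helper name `top_p_mask` is mine, not litgpt's API); with probabilities [0.5, 0.3, 0.15, 0.05] and top_p=0.9, the nucleus keeps the first three tokens (cumulative mass 0.95) and drops the 0.05 tail token.

```python
import torch

def top_p_mask(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Mask out logits outside the nucleus: the smallest set of tokens
    whose cumulative probability reaches top_p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # tokens in the ascending tail whose cumulative mass stays <= 1 - top_p are dropped
    sorted_to_remove = cumulative_probs <= (1 - top_p)
    # map the mask back from sorted order to the original token order
    to_remove = sorted_to_remove.scatter(0, sorted_indices, sorted_to_remove)
    return logits.masked_fill(to_remove, float("-inf"))

# toy distribution over 4 tokens with probabilities [0.5, 0.3, 0.15, 0.05]
logits = torch.tensor([0.5, 0.3, 0.15, 0.05]).log()
masked = top_p_mask(logits, top_p=0.9)
kept = torch.isfinite(masked)  # → [True, True, True, False]
```

Because the mask is applied before the temperature/softmax step, the surviving logits renormalize among themselves and sampling never picks a tail token.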
It would be a nice contribution.
What do you think @awaelchli @carmocca?