fix Bug in BertScore calculation: pred target misalignment #2347

Open
wants to merge 11 commits into
base: master

Conversation

gxy-gxy

@gxy-gxy gxy-gxy commented Feb 3, 2024

fix Bug in BertScore calculation: pred target misalignment

Fixes a bug in the BertScore calculation.
This pull request addresses a bug in the BertScore calculation that originates in the TextDataset class in src/torchmetrics/functional/text/helper_embedding_metric.py.
The class has a preprocessing function that automatically sorts the input text by length to make batch encoding more efficient. However, predictions (preds) and targets (targets) are initialized as separate datasets, so each one is sorted independently of the other. The result is a mismatched ordering of text pairs, which is a problem because BertScore is computed pairwise. To score correctly, both datasets must be re-aligned to their original order before the scores are computed. The proposed fix re-indexes the prediction and target embeddings so that the original pairing is maintained throughout the calculation.

Here is the fixed code:

preds_embeddings = preds_embeddings[preds_loader.dataset.sorting_indices]
target_embeddings = target_embeddings[target_loader.dataset.sorting_indices]

This change is essential for preserving the integrity of the BertScore evaluation, ensuring that each prediction is accurately compared against its corresponding target.
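To make the failure mode concrete, here is a minimal pure-Python sketch rather than the torchmetrics implementation. It assumes character length as a stand-in for token length, and the helpers `sort_by_length` and `restore_order` are hypothetical names. It also assumes that `sorting_indices` records, for each sorted position, the original position of that item. Under that convention the original order is restored by scattering items back. Whether plain indexing with `sorting_indices` achieves the same thing depends on which convention the real attribute uses, which this thread does not state.

```python
def sort_by_length(texts):
    """Sort texts by length; sorted_texts[k] == texts[sorting_indices[k]]."""
    sorting_indices = sorted(range(len(texts)), key=lambda i: len(texts[i]))
    return [texts[i] for i in sorting_indices], sorting_indices


def restore_order(sorted_items, sorting_indices):
    """Undo sort_by_length by putting the k-th sorted item back at its original position."""
    restored = [None] * len(sorted_items)
    for k, original_position in enumerate(sorting_indices):
        restored[original_position] = sorted_items[k]
    return restored


preds = ["a long prediction sentence", "short", "medium length"]
target = ["tiny", "a much longer target sentence", "mid-sized one"]

# Each list is sorted on its own, as happens with two separate datasets.
sorted_preds, preds_indices = sort_by_length(preds)
sorted_target, target_indices = sort_by_length(target)

# Pairing the sorted lists directly no longer matches the input pairing.
misaligned_pairs = list(zip(sorted_preds, sorted_target))

# Restoring each list with its own indices recovers the input pairing.
realigned_pairs = list(
    zip(
        restore_order(sorted_preds, preds_indices),
        restore_order(sorted_target, target_indices),
    )
)

print(misaligned_pairs != list(zip(preds, target)))
print(realigned_pairs == list(zip(preds, target)))
```

The same idea carries over to tensors: whatever is computed per sample in sorted order has to be mapped back with the permutation of the dataset it came from.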


📚 Documentation preview 📚: https://torchmetrics--2347.org.readthedocs.build/en/2347/

@gxy-gxy
Author

gxy-gxy commented Feb 3, 2024

Here is the test code:

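from torchmetrics.text.bert import BERTScore

score_model = BERTScore(model_name_or_path='roberta-large', batch_size=2)

# Pair 0 is identical; pair 1 differs, and its target is shorter than the
# target of pair 0, so sorting by length can change the order of the targets.
text1 = ["Claim A from machine", "Claim A from machine"]
text2 = ["Claim A from machine", "Claim B"]
similarities = score_model(text1, text2)
print(similarities)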

@Borda Borda added the bug / fix Something isn't working label Feb 14, 2024
@Borda Borda added this to the v1.3.x milestone Feb 19, 2024
@Borda
Member

Borda commented Mar 15, 2024

@stancld could you help here, pls?

Contributor

@baskrahmer baskrahmer left a comment


Good catch! I can reproduce the bug using your test code snippet. Although the actual metric values seem to be correct, the ordering is not always valid.

It might be nice to integrate this case into the current test suite, e.g. with an assertion that reversing the targets/preds also reverses the scores.
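A reversal check of this kind could look roughly like the sketch below. It is not taken from the torchmetrics test suite. `toy_pairwise_score` is a hypothetical stand-in metric (token-set overlap) so the example runs without downloading a model, `buggy_pairwise_score` imitates a scorer that sorts the targets by length before pairing, and `reversal_reverses_scores` is an assumed helper name.

```python
def toy_pairwise_score(preds, target):
    """Stand-in pairwise metric: Jaccard overlap of whitespace tokens per pair."""
    scores = []
    for pred_text, target_text in zip(preds, target):
        pred_tokens = set(pred_text.split())
        target_tokens = set(target_text.split())
        scores.append(len(pred_tokens & target_tokens) / len(pred_tokens | target_tokens))
    return scores


def buggy_pairwise_score(preds, target):
    """Imitates the misalignment: targets are sorted by length before pairing."""
    return toy_pairwise_score(preds, sorted(target, key=len))


def reversal_reverses_scores(score_fn, preds, target, tol=1e-6):
    """Return True if scoring the reversed inputs yields the reversed scores."""
    forward = score_fn(preds, target)
    backward = score_fn(preds[::-1], target[::-1])
    return all(abs(f - b) <= tol for f, b in zip(forward, backward[::-1]))


preds = ["Claim A from machine", "Claim A from machine"]
target = ["Claim A from machine", "Claim B"]

print(reversal_reverses_scores(toy_pairwise_score, preds, target))
print(reversal_reverses_scores(buggy_pairwise_score, preds, target))
```

In a real test the stand-in metric would be replaced by the BERTScore call, with a tolerance suited to floating-point model outputs.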

@@ -419,18 +419,12 @@ def bert_score(
preds_loader, preds_dataset.max_length, model, device, num_layers, all_layers, idf, verbose, user_forward_fn
)

preds_embeddings = preds_embeddings[preds_loader.dataset.sorting_indices]
target_embeddings = target_embeddings[target_loader.dataset.sorting_indices]
Contributor


I'm pretty sure the IDF factors need to be unsorted too

Member


cc: @stancld

Author


I'm pretty sure the IDF factors need to be unsorted too

Thanks for your suggestion, I will fix it soon.
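The review point is that every per-sample quantity produced in sorted order needs the same restoration, not only the embeddings. The following is a rough sketch under stated assumptions and not the eventual patch. `restore_all` is a hypothetical helper, the embeddings and IDF weights are toy lists, and `sorting_indices[k]` is assumed to hold the original position of the item at sorted position `k`.

```python
def restore_all(sorting_indices, *per_sample_sequences):
    """Restore several parallel sequences to input order with one shared permutation."""
    restored_sequences = []
    for sequence in per_sample_sequences:
        restored = [None] * len(sequence)
        for k, original_position in enumerate(sorting_indices):
            restored[original_position] = sequence[k]
        restored_sequences.append(restored)
    return restored_sequences


# Toy values in sorted order: one embedding row and one IDF weight row per sample.
sorting_indices = [2, 0, 1]
sorted_embeddings = [["e2"], ["e0"], ["e1"]]
sorted_idf_scale = [[0.2], [0.0], [0.1]]

embeddings, idf_scale = restore_all(sorting_indices, sorted_embeddings, sorted_idf_scale)
print(embeddings)
print(idf_scale)
```

Restoring both sequences with the same indices keeps each embedding row paired with its own IDF weights.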

Labels
bug / fix Something isn't working topic: Text

4 participants