
Text2Topic : a new loss function ? #2605

Open
azaismarc opened this issue Apr 23, 2024 · 5 comments

@azaismarc

Paper: https://aclanthology.org/2023.emnlp-industry.10.pdf

Hi

Recently, Booking.com shared a new architecture called Text2Topic. This architecture takes as input a text and a topic with its description, and outputs a score between 0 and 1.

To achieve this, the architecture introduces a new loss function, which resembles the SoftmaxLoss but for binary classification.
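
As I understand it, the matching score is trained with a standard per-pair binary cross-entropy (my notation, not the paper's):

$$\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N}\Big[y_i \log \hat{p}_i + (1 - y_i)\log\big(1 - \hat{p}_i\big)\Big], \qquad \hat{p}_i = \sigma\big(f(u_i, v_i)\big)$$

where $u_i$ and $v_i$ are the text and topic embeddings, $f$ is the classification head, and $y_i \in \{0, 1\}$ indicates whether the topic applies to the text.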

In my own experiments, this loss also improves the embeddings of a pre-trained model, acting as a kind of domain adaptation.

I would like to hear opinions on the relevance of their approach and whether implementing this loss function here would be feasible.

Thanks

@ir2718
Contributor

ir2718 commented Apr 23, 2024

Hi,

correct me if I'm wrong, but I think this is already implemented in the SoftmaxLoss class; you just have to set all the concatenation_sent parameters to True:

class SoftmaxLoss(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        sentence_embedding_dimension: int,
        num_labels: int,                                 # set to 2 if you want binary classification
        concatenation_sent_rep: bool = True,             # set to True if you want u and v
        concatenation_sent_difference: bool = True,      # set to True if you want abs(u-v)
        concatenation_sent_multiplication: bool = False, # set to True if you want (u*v)
        loss_fct: Callable = nn.CrossEntropyLoss(),
    ):
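
For the binary text-topic setup that would look roughly like this (untested sketch):

from torch import nn
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import SoftmaxLoss

model = SentenceTransformer("bert-base-cased")
loss = SoftmaxLoss(
    model=model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    num_labels=2,                            # two classes: topic matches / does not match
    concatenation_sent_rep=True,             # u and v
    concatenation_sent_difference=True,      # |u - v|
    concatenation_sent_multiplication=True,  # u * v
    loss_fct=nn.CrossEntropyLoss(),
)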

Implementation aside, the paper reminds me a lot of approaches like SphereFace2 and of work exploiting hyperspherical embeddings for OOD detection, which are pretty cool.

@azaismarc
Author

Yeah, the loss is very similar to SoftmaxLoss, but they add an FFN with ReLU activation before the classifier.


@ir2718
Contributor

ir2718 commented Apr 23, 2024

Oh, I see. The simplest way to do this is a bit hacky: you can just use the SoftmaxLoss and then change the classifier attribute. Something along these lines (haven't tested this, though):

from torch import nn
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import SoftmaxLoss

model = SentenceTransformer("bert-base-cased")
loss = SoftmaxLoss(
    model=model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    num_labels=1,                            # single logit for the matching score
    concatenation_sent_rep=True,
    concatenation_sent_difference=True,
    concatenation_sent_multiplication=True,
    loss_fct=nn.BCEWithLogitsLoss(),
)
# Replace the single Linear classifier with an FFN (Linear -> ReLU -> Dropout -> Linear),
# as described in the paper.
loss.classifier = nn.Sequential(
    nn.Linear(
        in_features=loss.classifier.in_features,  # 4 * embedding dim: u, v, |u-v|, u*v
        out_features=model.get_sentence_embedding_dimension(),
        bias=True
    ),
    nn.ReLU(),
    nn.Dropout(p=0.3), # not sure what value is used in the paper
    nn.Linear(
        in_features=model.get_sentence_embedding_dimension(),
        out_features=1,
        bias=True
    ),
)
loss.classifier.to(model.device)
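
For training, something along these lines should work with the classic fit API (untested sketch; the example texts and topic descriptions below are made up, and BCEWithLogitsLoss wants logits and targets of the same shape, hence the small wrapper):

from torch.utils.data import DataLoader
import torch.nn.functional as F
from sentence_transformers import InputExample

# Flatten the (batch, 1) logits so they match the flattened float labels.
loss.loss_fct = lambda logits, labels: F.binary_cross_entropy_with_logits(logits.view(-1), labels.float())

# Hypothetical (text, topic description) pairs with binary labels.
train_examples = [
    InputExample(texts=["The pool was great and the staff were friendly.",
                        "Facilities: the text mentions hotel amenities such as the pool, gym or spa."], label=1.0),
    InputExample(texts=["The pool was great and the staff were friendly.",
                        "Location: the text mentions how close the hotel is to attractions."], label=0.0),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

model.fit(train_objectives=[(train_dataloader, loss)], epochs=1, warmup_steps=10)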

@azaismarc
Author

Thanks, I implemented it the same way!

However, I find the idea interesting of computing the matching score with a classifier, as a cross-encoder does, but with a bi-encoder so that embeddings can be cached. IMO, the classifier can capture more information than cosine similarity.
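
For example, something along these lines (untested sketch; the topic descriptions are made up) scores a text against cached topic embeddings with the trained head instead of cosine similarity:

import torch

topics = [
    "Facilities: the text mentions hotel amenities such as the pool, gym or spa.",
    "Location: the text mentions how close the hotel is to attractions.",
]
# Encode and cache the topic embeddings once.
topic_emb = torch.tensor(model.encode(topics), device=model.device)

def score(text):
    # Encode the text and pair it with every cached topic embedding.
    u = torch.tensor(model.encode([text]), device=model.device).repeat(len(topics), 1)
    v = topic_emb
    features = torch.cat([u, v, torch.abs(u - v), u * v], dim=1)  # same order as SoftmaxLoss
    with torch.no_grad():
        logits = loss.classifier(features)
    return torch.sigmoid(logits).view(-1)  # one probability per topic

print(score("The pool was great and the staff were friendly."))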

Maybe a new feature for SetFit or Sentence Transformers?

@ir2718
Contributor

ir2718 commented Apr 25, 2024

Since it's such a small change to the SoftmaxLoss, maybe it's better to just expose classifier as a parameter? I can create a PR if needed
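
With such a parameter the hack above would collapse to something like this (just a sketch of the proposal; classifier is not an existing argument of SoftmaxLoss):

dim = model.get_sentence_embedding_dimension()
loss = SoftmaxLoss(
    model=model,
    sentence_embedding_dimension=dim,
    num_labels=1,
    concatenation_sent_rep=True,
    concatenation_sent_difference=True,
    concatenation_sent_multiplication=True,
    loss_fct=nn.BCEWithLogitsLoss(),
    # "classifier" is a hypothetical parameter that does not exist yet;
    # it would default to the current single Linear layer.
    classifier=nn.Sequential(
        nn.Linear(4 * dim, dim),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(dim, 1),
    ),
)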

cc @tomaarsen
