Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Make saliency interpreter GPU compatible (#5656)
Browse files Browse the repository at this point in the history
* make saliency interpreter GPU compatible

* traitlets

* traitlets
  • Loading branch information
AkshitaB committed Jun 20, 2022
1 parent ea4a53c commit df9d7ca
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _aggregate_token_embeddings(
embeddings_list: List[torch.Tensor], token_offsets: List[torch.Tensor]
) -> List[numpy.ndarray]:
if len(token_offsets) == 0:
return [embeddings.numpy() for embeddings in embeddings_list]
return [embeddings.detach().cpu().numpy() for embeddings in embeddings_list]
aggregated_embeddings = []
# NOTE: This is assuming that embeddings and offsets come in the same order, which may not
# be true. But, the intersection of using multiple TextFields with mismatched indexers is
Expand All @@ -60,5 +60,5 @@ def _aggregate_token_embeddings(

# All the places where the span length is zero, write in zeros.
embeddings[(span_embeddings_len == 0).expand(embeddings.shape)] = 0
aggregated_embeddings.append(embeddings.numpy())
aggregated_embeddings.append(embeddings.detach().cpu().numpy())
return aggregated_embeddings

0 comments on commit df9d7ca

Please sign in to comment.