Skip to content

Commit

Permalink
Fix conflicting generation args (#155)
Browse files — browse the repository at this point in the history
  • Loading branch information
gsarti committed Dec 22, 2022
1 parent c2ac1e0 commit 9e88a9e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 20 deletions.
16 changes: 7 additions & 9 deletions docs/source/_static/inseq.js
Expand Up @@ -61,15 +61,11 @@ function resizeHtmlExamples() {
for (const ex of examples) {
const iframe = ex.firstElementChild;
const zoom = iframe.getAttribute("scale")
const origHeight = iframe.contentWindow.document.body.scrollHeight
const origWidth = iframe.contentWindow.document.body.scrollWidth
ex.style.height = ((origHeight * zoom) + 50) + "px";
const frameHeight = origHeight / zoom
const frameWidth = origWidth / zoom
ex.style.height = ((iframe.contentWindow.document.body.scrollHeight * zoom) + 50) + "px";
// add extra 50 pixels - in reality need just a bit more
iframe.style.height = frameHeight + "px"
iframe.style.height = (iframe.contentWindow.document.body.scrollHeight / zoom) + "px"
// set the width of the iframe as the width of the iframe content
iframe.style.width = frameWidth + 'px';
iframe.style.width = (iframe.contentWindow.document.body.scrollWidth / zoom) + 'px';
iframe.style.zoom = zoom;
iframe.style.MozTransform = `scale(${zoom})`;
iframe.style.WebkitTransform = `scale(${zoom})`;
Expand All @@ -83,12 +79,14 @@ function resizeHtmlExamples() {
function onLoad() {
addIcon();
addCustomFooter();
resizeHtmlExamples();
}

window.addEventListener("load", onLoad);
window.onresize = function() {
var wwidth = $(window).width();
if(curr_width!==wwidth){
resizeHtmlExamples();
if( curr_width !== wwidth ){
window.location.reload();
curr_width = wwidth;
}
}
12 changes: 1 addition & 11 deletions inseq/models/huggingface_model.py
Expand Up @@ -65,7 +65,6 @@ def __init__(
attribution_method: Optional[str] = None,
tokenizer: Union[str, PreTrainedTokenizer, None] = None,
device: Optional[str] = None,
model_max_length: Optional[int] = 512,
**kwargs,
) -> None:
"""
Expand All @@ -81,8 +80,6 @@ def __init__(
attribution_method (str, optional): The attribution method to use.
Passing it here reduces overhead on attribute call, since it is already
initialized.
model_max_length (int, optional): The maximum length of the model. If not provided, will be inferred from
the model config.
**kwargs: additional arguments for the model and the tokenizer.
"""
super().__init__(**kwargs)
Expand Down Expand Up @@ -110,9 +107,7 @@ def __init__(
if isinstance(tokenizer, PreTrainedTokenizer):
self.tokenizer = tokenizer
else:
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer, *tokenizer_inputs, model_max_length=model_max_length, **tokenizer_kwargs
)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, *tokenizer_inputs, **tokenizer_kwargs)
if self.model.config.pad_token_id is not None:
self.pad_token = self.tokenizer.convert_ids_to_tokens(self.model.config.pad_token_id)
self.tokenizer.pad_token = self.pad_token
Expand All @@ -122,7 +117,6 @@ def __init__(
if self.tokenizer.unk_token_id is None:
self.tokenizer.unk_token_id = self.tokenizer.pad_token_id
self.embed_scale = 1.0
self.model_max_length = model_max_length
self.encoder_int_embeds = None
self.decoder_int_embeds = None
self.is_encoder_decoder = self.model.config.is_encoder_decoder
Expand Down Expand Up @@ -167,7 +161,6 @@ def generate(
self,
inputs: Union[TextInput, BatchEncoding],
return_generation_output: bool = False,
max_new_tokens: Optional[int] = None,
**kwargs,
) -> Union[List[str], Tuple[List[str], ModelOutput]]:
"""Wrapper of model.generate to handle tokenization and decoding.
Expand All @@ -186,14 +179,11 @@ def generate(
isinstance(inputs, list) and len(inputs) > 0 and all([isinstance(x, str) for x in inputs])
):
inputs = self.encode(inputs)
if max_new_tokens is None:
max_new_tokens = self.model_max_length - inputs.input_ids.shape[-1]
inputs = inputs.to(self.device)
generation_out = self.model.generate(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
return_dict_in_generate=True,
max_new_tokens=max_new_tokens,
**kwargs,
)
texts = self.tokenizer.batch_decode(
Expand Down

0 comments on commit 9e88a9e

Please sign in to comment.