assert key_padding_mask.size(1) == src_len in 350M model #89

Open
Mrs-Hudson opened this issue May 10, 2022 · 0 comments

As I am unable to run the CLI with the 350M model (#86), I am running the generation task with the ad-hoc script shared below.
However, it fails with an assertion error; the same data works for next-token prediction (#73).

Code

```python
import os

from transformers import GPT2Tokenizer
from metaseq import checkpoint_utils
import torch
import queue
import pkg_resources
import random
import shutil
import threading

from metaseq import options
from metaseq.dataclass.configs import MetaseqConfig
from metaseq.dataclass.utils import convert_namespace_to_omegaconf
from metaseq.distributed import utils as dist_utils
from metaseq.hub_utils import GeneratorInterface
from metaseq.service.queue import PriorityQueueRingShard
from metaseq.service.workers import WorkItem
from metaseq.service.constants import (
    MAX_SEQ_LEN,
    MAX_BATCH_TOKENS,
    DEFAULT_PORT,
    TOTAL_WORLD_SIZE,
    CHECKPOINT_LOCAL,
    CHECKPOINT_FOLDER,
    LAUNCH_ARGS,
)
from metaseq.service.utils import get_my_ip, encode_fn, build_logger
from metaseq.service.responses import OAIResponse

logger = build_logger()
path = "/home/azureuser/350_model_info"

"""
$ ls path
vocab.json
merges.txt
reshard.pt
"""

tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
tokenizer.save_pretrained(path)

paths = [os.path.join(path, "reshard.pt")]

checkpoint = checkpoint_utils.load_model_ensemble_and_task(
    paths,
    arg_overrides={
        "vocab_filename": os.path.join(path, "vocab.json"),
        "merges_filename": os.path.join(path, "merges.txt"),
    }
)

model = checkpoint[0][0].eval()

# forward passes
def single_batch_forward_logits(prompts):
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids
    input_ids = torch.cat([torch.tensor([[2]]), input_ids], dim=-1)
    logits = model(input_ids)[0]
    return logits

def _copy_checkpoint_cache():
    if CHECKPOINT_LOCAL == CHECKPOINT_FOLDER:
        # user didn't have a local SSD
        return
    if os.path.exists(os.path.dirname(CHECKPOINT_LOCAL)):
        logger.info("Local checkpoint copy already exists, skipping copy")
    else:
        logger.info(
            f"Making a local copy of the checkpoint. {CHECKPOINT_FOLDER} -> {CHECKPOINT_LOCAL}"
        )
        shutil.copytree(CHECKPOINT_FOLDER, os.path.dirname(CHECKPOINT_LOCAL))

def worker_main(cfg1: MetaseqConfig, namespace_args=None):
    shutil.copytree(CHECKPOINT_FOLDER, os.path.dirname(CHECKPOINT_LOCAL))
    # disable multithreading in tokenizers and torch, as different Flask threads
    # may then fight for resources.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch.set_num_threads(1)
    global generator
    global MODE
    # make sure generations are stochastic since we have many workers
    torch.manual_seed(random.randint(1, 20000))
    torch.cuda.manual_seed(random.randint(1, 20000))
    MODE = "worker"
    cfg = cfg1
    print("In worker main")
    generator = GeneratorInterface(cfg)
    models = generator.load_model()  # noqa: F841
    print("\n Model loaded \n")
    logger.info(f"loaded model {cfg.distributed_training.distributed_rank}")

    prompts = [
        "Today is a beautiful day and I want to",
        "In the city of",
        "Paris is the capital of France and",
        "Computers and mobile phones have taken",
        "LinkedIn is a great company and I am ",
    ]
    inputs = [encode_fn(generator, p) for p in prompts]
    min_tokens = [5, 5, 5, 5, 5]
    max_tokens = [32, 32, 32, 32, 32]
    print(inputs)
    retval = generator.generate(inputs, min_tokens, max_tokens)
    print(retval)

def cli_main():
    """
    Hosted version of the web UI for generation.
    """
    _copy_checkpoint_cache()

    global port, MODE, cfg
    parser = options.get_generation_parser()

    # dumb defaults overriding
    parser.set_defaults(lr_scheduler=None, criterion=None)
    flat_launch_args = []
    for s in LAUNCH_ARGS:
        flat_launch_args += s.split()
    args = options.parse_args_and_arch(parser, input_args=flat_launch_args)
    args.data = os.path.dirname(args.path)  # hardcode the data arg
    port = DEFAULT_PORT
    cfg = convert_namespace_to_omegaconf(args)
    cfg.distributed_training.distributed_world_size = TOTAL_WORLD_SIZE
    print("Calling main\n")
    dist_utils.call_main(cfg, worker_main, namespace_args=args)

if __name__ == "__main__":
    cli_main()
```
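As a side check (a debugging sketch, not part of the script above), the same `generator.generate` call can be driven one prompt at a time using only the helpers already defined, so that every batch contains a single sequence length:

```python
# Debugging sketch (assumption): call generate() on one prompt per batch to see
# whether the assertion only fires when prompts of different lengths are batched.
# Uses only the names defined in worker_main above (generator, encode_fn, prompts).
for p in prompts:
    single_input = [encode_fn(generator, p)]
    out = generator.generate(single_input, [5], [32])
    print(out)
```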

Stacktrace:
```
Traceback (most recent call last):
  File "inf_generate.py", line 134, in <module>
    cli_main()
  File "inf_generate.py", line 131, in cli_main
    dist_utils.call_main(cfg, worker_main, namespace_args=args)
  File "/home/azureuser/metaseq/metaseq/distributed/utils.py", line 256, in call_main
    return _spawn_helper(main, cfg, kwargs)
  File "/home/azureuser/metaseq/metaseq/distributed/utils.py", line 234, in _spawn_helper
    retval = distributed_main(-1, main, cfg, kwargs)
  File "/home/azureuser/metaseq/metaseq/distributed/utils.py", line 203, in distributed_main
    main(cfg, **kwargs)
  File "inf_generate.py", line 108, in worker_main
    retval = generator.generate(inputs, min_tokens, max_tokens)
  File "/home/azureuser/metaseq/metaseq/hub_utils.py", line 604, in generate
    translations = self.task.inference_step(generator, self.models, batch)
  File "/home/azureuser/metaseq/metaseq/tasks/language_modeling.py", line 326, in inference_step
    return generator.generate(
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/azureuser/metaseq/metaseq/sequence_generator.py", line 93, in generate
    return self._generate(sample, **kwargs)
  File "/home/azureuser/metaseq/metaseq/sequence_generator.py", line 286, in _generate
    model_out = self.model.decoder(
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/azureuser/metaseq/metaseq/models/transformer.py", line 639, in forward
    x, extra = self.extract_features(
  File "/home/azureuser/metaseq/metaseq/models/transformer.py", line 664, in extract_features
    return self.extract_features_scriptable(
  File "/home/azureuser/metaseq/metaseq/models/transformer.py", line 728, in extract_features_scriptable
    x, layer_attn, _, l_aux_i = layer(
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/azureuser/metaseq/metaseq/modules/transformer_layer.py", line 509, in forward
    x, attn = self.forward_attention(
  File "/home/azureuser/metaseq/metaseq/modules/transformer_layer.py", line 422, in forward_attention
    x, attn = self.self_attn(
  File "/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/azureuser/metaseq/metaseq/modules/multihead_attention.py", line 331, in forward
    assert key_padding_mask.size(1) == src_len
AssertionError
```
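The failing assertion in `metaseq/modules/multihead_attention.py` checks that the padding mask's time dimension matches the key/value length, i.e. `key_padding_mask` must have shape `(batch, src_len)`. As an illustration of the same shape contract (using plain PyTorch's `nn.MultiheadAttention`, not the metaseq module), a mask narrower than the keys fails in the same way:

```python
# Illustration only: the (batch, src_len) contract for key_padding_mask.
# Depending on the PyTorch version this surfaces as an AssertionError or RuntimeError.
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=16, num_heads=2)
q = k = v = torch.randn(10, 1, 16)               # (src_len=10, batch=1, embed_dim=16)
bad_mask = torch.zeros(1, 5, dtype=torch.bool)   # width 5 != src_len 10

try:
    attn(q, k, v, key_padding_mask=bad_mask)
except (AssertionError, RuntimeError) as e:
    print("shape mismatch:", e)
```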

I also printed the key_padding_mask sizes and the src_len
```
key padding size is:
2
5
10
10
5

[... the same "2 5 10 10 5" block is printed repeatedly until the final one ...]

key padding size is:
2
5
21
11
5
```
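The final block is the only one where the printed sizes (21 and 11) disagree, which looks consistent with prompts of different encoded lengths ending up in one batch whose mask width no longer matches the key length. A quick way to see the per-prompt lengths (assuming, as in the script above, that `encode_fn` returns one sequence of token ids per prompt) is:

```python
# Diagnostic sketch: print each prompt's encoded length to see which lengths get batched together.
for p in prompts:
    print(len(encode_fn(generator, p)), repr(p))
```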

metaseq Version (e.g., 1.0 or master): 0.0.1

PyTorch Version (e.g., 1.0): '1.10.1+cu113'

OS (e.g., Linux): Linux, Ubuntu 18.04.6 LTS (Bionic Beaver)

How you installed metaseq (pip, source): Same as setup instructions

Build command you used (if compiling from source): Same as setup instructions

Python version: 3.8.5

CUDA/cuDNN version: CUDA 11.2 (release 11.2, V11.2.152, per `nvcc --version`)

GPU models and configuration: Azure compute node, Standard_ND40rs_v2 VM (40 cores, 672 GB RAM, 2900 GB disk) with 8 x NVIDIA Tesla V100

@Mrs-Hudson Mrs-Hudson added the question Further information is requested label May 10, 2022
@stephenroller stephenroller self-assigned this May 10, 2022