diff --git a/README.md b/README.md index 416702f20b072..fe1d0a990d117 100644 --- a/README.md +++ b/README.md @@ -320,6 +320,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. 1. **[I-BERT](https://huggingface.co/docs/transformers/model_doc/ibert)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer. 1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever. +1. **[Jukebox](https://huggingface.co/docs/transformers/main/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever. 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. 1. **[LayoutLMv3](https://huggingface.co/docs/transformers/model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. diff --git a/README_es.md b/README_es.md index 7aa72c2f3e343..2d4028ff91e92 100644 --- a/README_es.md +++ b/README_es.md @@ -320,6 +320,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. 1. **[I-BERT](https://huggingface.co/docs/transformers/model_doc/ibert)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer. 1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever. +1. **[Jukebox](https://huggingface.co/docs/transformers/main/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever. 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. 1. **[LayoutLMv3](https://huggingface.co/docs/transformers/model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. diff --git a/README_ja.md b/README_ja.md index 6f5a8d94a200b..9627b9fb1546d 100644 --- a/README_ja.md +++ b/README_ja.md @@ -355,6 +355,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. 1. **[I-BERT](https://huggingface.co/docs/transformers/model_doc/ibert)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer. 1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever. +1. **[Jukebox](https://huggingface.co/docs/transformers/main/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever. 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. 1. **[LayoutLMv3](https://huggingface.co/docs/transformers/model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. diff --git a/README_ko.md b/README_ko.md index 8d65e33a744f9..e01f05a28b66c 100644 --- a/README_ko.md +++ b/README_ko.md @@ -270,6 +270,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. 1. **[I-BERT](https://huggingface.co/docs/transformers/model_doc/ibert)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer. 1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever. +1. **[Jukebox](https://huggingface.co/docs/transformers/main/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever. 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. 1. **[LayoutLMv3](https://huggingface.co/docs/transformers/model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. diff --git a/README_zh-hans.md b/README_zh-hans.md index 6b053d7bfe8a2..36860c9598d48 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -294,6 +294,7 @@ conda install -c huggingface transformers 1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (来自 Facebook) 伴随论文 [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) 由 Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed 发布。 1. **[I-BERT](https://huggingface.co/docs/transformers/model_doc/ibert)** (来自 Berkeley) 伴随论文 [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) 由 Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer 发布。 1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (来自 OpenAI) 伴随论文 [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) 由 Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever 发布。 +1. **[Jukebox](https://huggingface.co/docs/transformers/main/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever. 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (来自 Microsoft Research Asia) 伴随论文 [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) 由 Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou 发布。 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (来自 Microsoft Research Asia) 伴随论文 [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) 由 Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou 发布。 1. **[LayoutLMv3](https://huggingface.co/docs/transformers/model_doc/layoutlmv3)** (来自 Microsoft Research Asia) 伴随论文 [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) 由 Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 7c1a31eb7fa6f..d5a965f5b9167 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -306,6 +306,7 @@ conda install -c huggingface transformers 1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. 1. **[I-BERT](https://huggingface.co/docs/transformers/model_doc/ibert)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer. 1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever. +1. **[Jukebox](https://huggingface.co/docs/transformers/main/model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever. 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. 1. **[LayoutLMv3](https://huggingface.co/docs/transformers/model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f18e54eea0d75..814a570d88976 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -275,6 +275,8 @@ title: HerBERT - local: model_doc/ibert title: I-BERT + - local: model_doc/jukebox + title: Jukebox - local: model_doc/layoutlm title: LayoutLM - local: model_doc/led diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 14856033c85c5..fa6ecbfc3bf2c 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -108,6 +108,7 @@ The documentation is organized into five sections: 1. **[Hubert](model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. 1. **[I-BERT](model_doc/ibert)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer. 1. **[ImageGPT](model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever. +1. **[Jukebox](model_doc/jukebox)** (from OpenAI) released with the paper [Jukebox: A Generative Model for Music](https://arxiv.org/pdf/2005.00341.pdf) by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, Ilya Sutskever. 1. **[LayoutLM](model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. 1. **[LayoutLMv2](model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. 1. **[LayoutLMv3](model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. @@ -263,6 +264,7 @@ Flax), PyTorch, and/or TensorFlow. | Hubert | ❌ | ❌ | ✅ | ✅ | ❌ | | I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ | | ImageGPT | ❌ | ❌ | ✅ | ❌ | ❌ | +| Jukebox | ✅ | ❌ | ✅ | ❌ | ❌ | | LayoutLM | ✅ | ✅ | ✅ | ✅ | ❌ | | LayoutLMv2 | ✅ | ✅ | ✅ | ❌ | ❌ | | LayoutLMv3 | ✅ | ✅ | ✅ | ✅ | ❌ | diff --git a/docs/source/en/model_doc/jukebox.mdx b/docs/source/en/model_doc/jukebox.mdx new file mode 100644 index 0000000000000..860fb8fc3f67b --- /dev/null +++ b/docs/source/en/model_doc/jukebox.mdx @@ -0,0 +1,79 @@ + +# Jukebox + +## Overview + +The Jukebox model was proposed in [Jukebox: A generative model for music](https://arxiv.org/pdf/2005.00341.pdf) +by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, +Ilya Sutskever. It introduces a generative music model which can produce minute long samples that can be conditionned on +an artist, genres and lyrics. + +The abstract from the paper is the following: + +*We introduce Jukebox, a model that generates music with singing in the raw audio domain. We tackle the long context of raw audio using a multiscale VQ-VAE to compress it to discrete codes, and modeling those using autoregressive Transformers. We show that the combined model at scale can generate high-fidelity and diverse songs with coherence up to multiple minutes. We can condition on artist and genre to steer the musical and vocal style, and on unaligned lyrics to make the singing more controllable. We are releasing thousands of non cherry-picked samples, along with model weights and code.* + +As shown on the following figure, Jukebox is made of 3 `priors` which are decoder only models. They follow the architecture described in [Generating Long Sequences with Sparse Transformers](https://arxiv.org/abs/1904.10509), modified to support longer context length. +First, a autoencoder is used to encode the text lyrics. Next, the first (also called `top_prior`) prior attends to the last hidden states extracted from the lyrics encoder. The priors are linked to the previous priors respectively via an `AudioConditionner` module. The`AudioConditioner` upsamples the outputs of the previous prior to raw tokens at a certain audio frame per second resolution. +The metadata such as *artist, genre and timing* are passed to each prior, in the form of a start token and positionnal embedding for the timing data. The hidden states are mapped to the closest codebook vector from the VQVAE in order to convert them to raw audio. + +![JukeboxModel](https://gist.githubusercontent.com/ArthurZucker/92c1acaae62ebf1b6a951710bdd8b6af/raw/c9c517bf4eff61393f6c7dec9366ef02bdd059a3/jukebox.svg) + +Tips: +- This model only supports inference. This is for a few reasons, mostly because it requires a crazy amount of memory to train. Feel free to open a PR and add what's missing to have a full integration with the hugging face traineer! +- This model is very slow, and takes 8h to generate a minute long audio using the 5b top prior on a V100 GPU. In order automaticallay handle the device on which the model should execute, use `accelerate`. +- Contrary to the paper, the order of the priors goes from `0` to `1` as it felt more intuitive : we sample starting from `0`. +- Primed sampling (conditionning the sampling on raw audio) requires more memory than ancestral sampling and should be used with `fp16` set to `True`. + +This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). +The original code can be found [here](https://github.com/openai/jukebox). + +## JukeboxConfig + +[[autodoc]] JukeboxConfig + +## JukeboxPriorConfig + +[[autodoc]] JukeboxPriorConfig + +## JukeboxVQVAEConfig + +[[autodoc]] JukeboxVQVAEConfig + +## JukeboxTokenizer + +[[autodoc]] JukeboxTokenizer + - save_vocabulary + +## JukeboxModel + +[[autodoc]] JukeboxModel + - ancestral_sample + - primed_sample + - continue_sample + - upsample + - _sample + + +## JukeboxPrior + +[[autodoc]] JukeboxPrior + - sample + - forward + + +## JukeboxVQVAE + +[[autodoc]] JukeboxVQVAE + - forward + - encode + - decode diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c07836075b3f2..6dac658d4faa6 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -248,6 +248,13 @@ "models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"], "models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"], "models.imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"], + "models.jukebox": [ + "JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP", + "JukeboxConfig", + "JukeboxPriorConfig", + "JukeboxTokenizer", + "JukeboxVQVAEConfig", + ], "models.layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMTokenizer"], "models.layoutlmv2": [ "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -1480,6 +1487,15 @@ "load_tf_weights_in_imagegpt", ] ) + _import_structure["models.jukebox"].extend( + [ + "JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST", + "JukeboxModel", + "JukeboxPreTrainedModel", + "JukeboxVQVAE", + "JukeboxPrior", + ] + ) _import_structure["models.layoutlm"].extend( [ "LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -3343,6 +3359,13 @@ from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig from .models.imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig + from .models.jukebox import ( + JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP, + JukeboxConfig, + JukeboxPriorConfig, + JukeboxTokenizer, + JukeboxVQVAEConfig, + ) from .models.layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMTokenizer from .models.layoutlmv2 import ( LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -4350,6 +4373,13 @@ ImageGPTPreTrainedModel, load_tf_weights_in_imagegpt, ) + from .models.jukebox import ( + JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST, + JukeboxModel, + JukeboxPreTrainedModel, + JukeboxPrior, + JukeboxVQVAE, + ) from .models.layoutlm import ( LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST, LayoutLMForMaskedLM, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 1c5919cfeed90..a325e95c11014 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -78,6 +78,7 @@ hubert, ibert, imagegpt, + jukebox, layoutlm, layoutlmv2, layoutlmv3, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index e92faa4d041a5..68f6bdc7d54d5 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -81,6 +81,7 @@ ("hubert", "HubertConfig"), ("ibert", "IBertConfig"), ("imagegpt", "ImageGPTConfig"), + ("jukebox", "JukeboxConfig"), ("layoutlm", "LayoutLMConfig"), ("layoutlmv2", "LayoutLMv2Config"), ("layoutlmv3", "LayoutLMv3Config"), @@ -221,6 +222,7 @@ ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("jukebox", "JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("layoutlmv3", "LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -363,6 +365,7 @@ ("hubert", "Hubert"), ("ibert", "I-BERT"), ("imagegpt", "ImageGPT"), + ("jukebox", "Jukebox"), ("layoutlm", "LayoutLM"), ("layoutlmv2", "LayoutLMv2"), ("layoutlmv3", "LayoutLMv3"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 93355ddb56497..caaaeeeb22c5e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -80,6 +80,7 @@ ("hubert", "HubertModel"), ("ibert", "IBertModel"), ("imagegpt", "ImageGPTModel"), + ("jukebox", "JukeboxModel"), ("layoutlm", "LayoutLMModel"), ("layoutlmv2", "LayoutLMv2Model"), ("layoutlmv3", "LayoutLMv3Model"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index f71366d8d1ed1..571f83b330212 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -143,6 +143,7 @@ ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)), ("hubert", ("Wav2Vec2CTCTokenizer", None)), ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), + ("jukebox", ("JukeboxTokenizer", None)), ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)), ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)), ("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/jukebox/__init__.py b/src/transformers/models/jukebox/__init__.py new file mode 100644 index 0000000000000..774e06bc3409b --- /dev/null +++ b/src/transformers/models/jukebox/__init__.py @@ -0,0 +1,74 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = { + "configuration_jukebox": [ + "JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP", + "JukeboxConfig", + "JukeboxPriorConfig", + "JukeboxVQVAEConfig", + ], + "tokenization_jukebox": ["JukeboxTokenizer"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_jukebox"] = [ + "JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST", + "JukeboxModel", + "JukeboxPreTrainedModel", + "JukeboxVQVAE", + "JukeboxPrior", + ] + +if TYPE_CHECKING: + from .configuration_jukebox import ( + JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP, + JukeboxConfig, + JukeboxPriorConfig, + JukeboxVQVAEConfig, + ) + from .tokenization_jukebox import JukeboxTokenizer + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_jukebox import ( + JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST, + JukeboxModel, + JukeboxPreTrainedModel, + JukeboxPrior, + JukeboxVQVAE, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/jukebox/configuration_jukebox.py b/src/transformers/models/jukebox/configuration_jukebox.py new file mode 100644 index 0000000000000..6ce345a8578e2 --- /dev/null +++ b/src/transformers/models/jukebox/configuration_jukebox.py @@ -0,0 +1,639 @@ +# coding=utf-8 +# Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Jukebox configuration""" + +import copy +import os +from typing import List, Union + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "openai/jukebox-5b-lyrics": "https://huggingface.co/openai/jukebox-5b-lyrics/blob/main/config.json", + "openai/jukebox-1b-lyrics": "https://huggingface.co/openai/jukebox-1b-lyrics/blob/main/config.json", +} + +_LARGE_ATTENTION = [ + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", +] +_RawColumnPreviousRowAttention = ["block_attn", "transpose_block_attn", "prev_block_attn"] +_FullDenseAttention = ["dense_attention"] +_PrimePrimeDenseAttention = ["prime_attn", "prime_attn", "dense_attn"] + + +def full_dense_attention(layer): + return _FullDenseAttention[0] + + +def raw_column_previous_row_attention(layer): + return _RawColumnPreviousRowAttention[layer % 3] + + +def large_separated_enc_dec_w_lyrics(layer): + return _LARGE_ATTENTION[layer % 79] + + +def enc_dec_with_lyrics(layer): + if layer % 16 == 15: + return _PrimePrimeDenseAttention[layer % 3] + return _RawColumnPreviousRowAttention[layer % 3] + + +ATTENTION_PATTERNS = { + "full_dense_attention": full_dense_attention, + "raw_column_previous_row_attention": raw_column_previous_row_attention, # Alternate row, column and previous row attn + "large_separated_enc_dec_w_lyrics": large_separated_enc_dec_w_lyrics, # Used by large separated_enc_dec model with lyrics + "enc_dec_with_lyrics": enc_dec_with_lyrics, # Used by encoder_decoder model with lyrics +} + + +class JukeboxPriorConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`JukeboxPrior`]. It is used to instantiate a + `JukeboxPrior` according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the top level prior from the + [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox + -1b-lyrics) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + + Args: + act_fn (`str`, *optional*, defaults to `"quick_gelu"`): + Activation function. + alignment_head (`int`, *optional*, defaults to 2): + Head that is responsible of the alignment between lyrics and music. Only used to compute the lyric to audio + alignment + alignment_layer (`int`, *optional*, defaults to 68): + Index of the layer that is responsible of the alignment between lyrics and music. Only used to compute the + lyric to audio alignment + attention_multiplier (`float`, *optional*, defaults to 0.25): + Multiplier coefficient used to define the hidden dimension of the attention layers. 0.25 means that + 0.25*width of the model will be used. + attention_pattern (`str`, *optional*, defaults to `"enc_dec_with_lyrics"`): + Which attention pattern to use for the decoder/ + attn_dropout (`int`, *optional*, defaults to 0): + Dropout probability for the post-attention layer dropout in the decoder. + attn_res_scale (`bool`, *optional*, defaults to `False`): + Whether or not to scale the residuals in the attention conditioner block. + blocks (`int`, *optional*, defaults to 64): + Number of blocks used in the `block_attn`. A sequence of length seq_len is factored as `[blocks, seq_len // + blocks]` in the `JukeboxAttention` layer. + conv_res_scale (`int`, *optional*): + Whether or not to scale the residuals in the conditioner block. Since the top level prior does not have a + conditioner, the default value is to None and should not be modified. + num_layers (`int`, *optional*, defaults to 72): + Number of layers of the transformer architecture. + emb_dropout (`int`, *optional*, defaults to 0): + Embedding dropout used in the lyric decoder. + encoder_config (`JukeboxPriorConfig`, *optional*) : + Configuration of the encoder which models the prior on the lyrics. + encoder_loss_fraction (`float`, *optional*, defaults to 0.4): + Multiplication factor used in front of the lyric encoder loss. + hidden_size (`int`, *optional*, defaults to 2048): + Hidden dimension of the attention layers. + init_scale (`float`, *optional*, defaults to 0.2): + Initialization scales for the prior modules. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether or not the prior is an encoder-decoder model. In case it is not, and `nb_relevant_lyric_tokens` is + greater than 0, the `encoder` args should be specified for the lyric encoding. + mask (`bool`, *optional*, defaults to `False`): + Whether or not to mask the previous positions in the attention. + max_duration (`int`, *optional*, defaults to 600): + Maximum supported duration of the generated song in seconds. + max_nb_genres (`int`, *optional*, defaults to 1): + Maximum number of genres that can be used to condition the model. + merged_decoder (`bool`, *optional*, defaults to `True`): + Whether or not the decoder and the encoder inputs are merged. This is used for the separated + encoder-decoder architecture + metadata_conditioning (`bool`, *optional*, defaults to `True)`: + Whether or not to condition on the artist and genre metadata. + metadata_dims (`List[int]`, *optional*, defaults to `[604, 7898]`): + Number of genres and the number of artists that were used to train the embedding layers of the prior + models. + min_duration (`int`, *optional*, defaults to 0): + Minimum duration of the generated audio on which the model was trained. + mlp_multiplier (`float`, *optional*, defaults to 1.0): + Multiplier coefficient used to define the hidden dimension of the MLP layers. 0.25 means that 0.25*width of + the model will be used. + music_vocab_size (`int`, *optional*, defaults to 2048): + Number of different music tokens. Should be similar to the `JukeboxVQVAEConfig.nb_discrete_codes`. + n_ctx (`int`, *optional*, defaults to 6144): + Number of context tokens for each prior. The context tokens are the music tokens that are attended to when + generating music tokens. + n_heads (`int`, *optional*, defaults to 2): + Number of attention heads. + nb_relevant_lyric_tokens (`int`, *optional*, defaults to 384): + Number of lyric tokens that are used when sampling a single window of length `n_ctx` + res_conv_depth (`int`, *optional*, defaults to 3): + Depth of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the + `JukeboxMusicTokenConditioner`. + res_conv_width (`int`, *optional*, defaults to 128): + Width of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the + `JukeboxMusicTokenConditioner`. + res_convolution_multiplier (`int`, *optional*, defaults to 1): + Multiplier used to scale the `hidden_dim` of the `JukeboxResConv1DBlock`. + res_dilation_cycle (`int`, *optional*): + Dilation cycle used to define the `JukeboxMusicTokenConditioner`. Usually similar to the ones used in the + corresponding level of the VQVAE. The first prior does not use it as it is not conditioned on upper level + tokens. + res_dilation_growth_rate (`int`, *optional*, defaults to 1): + Dilation grow rate used between each convolutionnal block of the `JukeboxMusicTokenConditioner` + res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`): + Downsampling rates used in the audio conditioning network + res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`): + Striding used in the audio conditioning network + resid_dropout (`int`, *optional*, defaults to 0): + Residual dropout used in the attention pattern. + sampling_rate (`int`, *optional*, defaults to 44100): + Sampling rate used for training. + spread (`int`, *optional*): + Spread used in the `summary_spread_attention` pattern + timing_dims (`int`, *optional*, defaults to 64): + Dimension of the timing embedding. + zero_out (`bool`, *optional*, defaults to `False`): + Whether or not to zero out convolution weights when initializing. + """ + + model_type = "jukebox_prior" + attribute_map = { + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + } + + def __init__( + self, + act_fn="quick_gelu", + level=0, + alignment_head=2, + alignment_layer=68, + attention_multiplier=0.25, + attention_pattern="enc_dec_with_lyrics", + attn_dropout=0, + attn_res_scale=False, + blocks=64, + conv_res_scale=None, + num_layers=72, + emb_dropout=0, + encoder_config=None, + encoder_loss_fraction=0.4, + hidden_size=2048, + init_scale=0.2, + is_encoder_decoder=True, + lyric_vocab_size=80, + mask=False, + max_duration=600, + max_nb_genres=1, + merged_decoder=True, + metadata_conditioning=True, + metadata_dims=[604, 7898], + min_duration=0, + mlp_multiplier=1.0, + music_vocab_size=2048, + n_ctx=6144, + n_heads=2, + nb_relevant_lyric_tokens=384, + res_conv_depth=3, + res_conv_width=128, + res_convolution_multiplier=1, + res_dilation_cycle=None, + res_dilation_growth_rate=1, + res_downs_t=[3, 2, 2], + res_strides_t=[2, 2, 2], + resid_dropout=0, + sampling_rate=44100, + spread=None, + timing_dims=64, + zero_out=False, + **kwargs + ): + self.act_fn = act_fn + self.alignment_head = alignment_head + self.alignment_layer = alignment_layer + self.attention_multiplier = attention_multiplier + self.attention_pattern = attention_pattern + self.attn_dropout = attn_dropout + self.attn_res_scale = attn_res_scale + self.blocks = blocks + self.conv_res_scale = conv_res_scale + self.num_layers = num_layers + self.emb_dropout = emb_dropout + self.music_vocab_size = music_vocab_size + if encoder_config is not None: + self.encoder_config = JukeboxPriorConfig(**encoder_config) + else: + self.encoder_config = None + self.encoder_loss_fraction = encoder_loss_fraction + self.init_scale = init_scale + self.is_encoder_decoder = is_encoder_decoder + self.lyric_vocab_size = lyric_vocab_size + self.level = level + self.mask = mask + self.max_duration = max_duration + self.max_nb_genres = max_nb_genres + self.merged_decoder = merged_decoder + self.metadata_conditioning = metadata_conditioning + self.metadata_dims = metadata_dims + self.min_duration = min_duration + self.mlp_multiplier = mlp_multiplier + self.n_ctx = n_ctx + self.n_heads = n_heads + self.nb_relevant_lyric_tokens = nb_relevant_lyric_tokens + self.res_conv_depth = res_conv_depth + self.res_conv_width = res_conv_width + self.res_convolution_multiplier = res_convolution_multiplier + self.res_dilation_cycle = res_dilation_cycle + self.res_dilation_growth_rate = res_dilation_growth_rate + self.res_downs_t = res_downs_t + self.res_strides_t = res_strides_t + self.resid_dropout = resid_dropout + self.sampling_rate = sampling_rate + self.spread = spread + self.timing_dims = timing_dims + self.hidden_size = hidden_size + self.zero_out = zero_out + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs + ) -> "PretrainedConfig": + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the prior config dict if we are loading from JukeboxConfig + if config_dict.get("model_type") == "jukebox": + config_dict = config_dict[f"prior_{level}"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["encoder_config"] = self.encoder_config.to_dict() if self.encoder_config is not None else None + output["model_type"] = self.__class__.model_type + return output + + +class JukeboxVQVAEConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`JukeboxVQVAE`]. It is used to instantiate a + `JukeboxVQVAE` according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the VQVAE from + [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + act_fn (`str`, *optional*, defaults to `"relu"`): + Activation function of the model. + nb_discrete_codes (`int`, *optional*, defaults to 2048): + Number of codes of the VQVAE. + commit (`float`, *optional*, defaults to 0.02): + Commit loss multiplier. + conv_input_shape (`int`, *optional*, defaults to 1): + Number of audio channels. + conv_res_scale (`bool`, *optional*, defaults to `False`): + Whether or not to scale the residuals of the `JukeboxResConv1DBlock`. + embed_dim (`int`, *optional*, defaults to 64): + Embedding dimension of the codebook vectors. + hop_fraction (`List[int]`, *optional*, defaults to `[0.125, 0.5, 0.5]`): + Fraction of non-intersecting window used when continuing the sampling process. + levels (`int`, *optional*, defaults to 3): + Number of hierarchical levels that used in the VQVAE. + lmu (`float`, *optional*, defaults to 0.99): + Used in the codebook update, exponential moving average coefficient. For more detail refer to Appendix A.1 + of the original [VQVAE paper](https://arxiv.org/pdf/1711.00937v2.pdf) + multipliers (`List[int]`, *optional*, defaults to `[2, 1, 1]`): + Depth and width multipliers used for each level. Used on the `res_conv_width` and `res_conv_depth` + res_conv_depth (`int`, *optional*, defaults to 4): + Depth of the encoder and decoder block. If no `multipliers` are used, this is the same for each level. + res_conv_width (`int`, *optional*, defaults to 32): + Width of the encoder and decoder block. If no `multipliers` are used, this is the same for each level. + res_convolution_multiplier (`int`, *optional*, defaults to 1): + Scaling factor of the hidden dimension used in the `JukeboxResConv1DBlock`. + res_dilation_cycle (`int`, *optional*): + Dilation cycle value used in the `JukeboxResnet`. If an int is used, each new Conv1 block will have a depth + reduced by a power of `res_dilation_cycle`. + res_dilation_growth_rate (`int`, *optional*, defaults to 3): + Resnet dilation growth rate used in the VQVAE (dilation_growth_rate ** depth) + res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`): + Downsampling rate for each level of the hierarchical VQ-VAE. + res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`): + Stride used for each level of the hierarchical VQ-VAE. + sample_length (`int`, *optional*, defaults to 1058304): + Provides the max input shape of the VQVAE. Is used to compute the input shape of each level. + init_scale (`float`, *optional*, defaults to 0.2): + Initialization scale. + zero_out (`bool`, *optional*, defaults to `False`): + Whether or not to zero out convolution weights when initializing. + """ + + model_type = "jukebox_vqvae" + + def __init__( + self, + act_fn="relu", + nb_discrete_codes=2048, + commit=0.02, + conv_input_shape=1, + conv_res_scale=False, + embed_dim=64, + hop_fraction=[0.125, 0.5, 0.5], + levels=3, + lmu=0.99, + multipliers=[2, 1, 1], + res_conv_depth=4, + res_conv_width=32, + res_convolution_multiplier=1, + res_dilation_cycle=None, + res_dilation_growth_rate=3, + res_downs_t=[3, 2, 2], + res_strides_t=[2, 2, 2], + sample_length=1058304, + init_scale=0.2, + zero_out=False, + **kwargs + ): + self.hop_fraction = hop_fraction + self.conv_input_shape = conv_input_shape + self.sample_length = sample_length + + # VQVAE parameters (all used) + self.levels = levels + self.embed_dim = embed_dim + self.nb_discrete_codes = nb_discrete_codes + self.res_conv_width = res_conv_width + self.res_conv_depth = res_conv_depth + self.res_convolution_multiplier = res_convolution_multiplier + self.res_dilation_growth_rate = res_dilation_growth_rate + self.res_dilation_cycle = res_dilation_cycle + self.multipliers = multipliers + self.res_downs_t = res_downs_t + self.res_strides_t = res_strides_t + self.lmu = lmu + self.commit = commit + self.conv_res_scale = conv_res_scale + self.act_fn = act_fn + self.init_scale = init_scale + self.zero_out = zero_out + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "jukebox": + config_dict = config_dict["vqvae_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class JukeboxConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`JukeboxModel`]. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the defaults will + yield a similar configuration to that of + [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture. + + + The downsampling and stride are used to determine downsampling of the input sequence. For example, downsampling = + (5,3), and strides = (2, 2) will downsample the audio by 2^5 = 32 to get the first level of codes, and 2**8 = 256 + to get the second level codes. This is mostly true for training the top level prior and the upsamplers. + + Args: + vqvae_config (`JukeboxVQVAEConfig`, *optional*): + Configuration for the `JukeboxVQVAE` model. + prior_config_list (`List[JukeboxPriorConfig]`, *optional*): + List of the configs for each of the `JukeboxPrior` of the model. The original architecture uses 3 priors. + nb_priors (`int`, *optional*, defaults to 3): + Number of prior models that will sequentially sample tokens. Each prior is conditional auto regressive + (decoder) model, apart from the top prior, which can include a lyric encoder. The available models were + trained using a top prior and 2 upsampler priors. + sampling_rate (`int`, *optional*, defaults to 44100): + Sampling rate of the raw audio. + timing_dims (`int`, *optional*, defaults to 64): + Dimensions of the JukeboxRangeEmbedding layer which is equivalent to traditional positional embedding + layer. The timing embedding layer converts the absolute and relative position in the currently sampled + audio to a tensor of length `timing_dims` that will be added to the music tokens. + min_duration (`int`, *optional*, defaults to 0): + Minimum duration of the audios to generate + max_duration (`float`, *optional*, defaults to 600.0): + Maximum duration of the audios to generate + max_nb_genres (`int`, *optional*, defaults to 5): + Maximum number of genres that can be used to condition a single sample. + metadata_conditioning (`bool`, *optional*, defaults to `True`): + Whether or not to use metadata conditioning, corresponding to the artist, the genre and the min/maximum + duration. + init_std (`float`, *optional*, defaults to 0.2): + Standard deviation used to initial the model. + + Example: + + ```python + >>> from transformers import JukeboxModel, JukeboxConfig + + >>> # Initializing a Jukebox configuration + >>> configuration = JukeboxConfig() + + >>> # Initializing a model from the configuration + >>> model = JukeboxModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "jukebox" + is_composition = True + + def __init__( + self, + vqvae_config=None, + prior_config_list=None, + nb_priors=3, + sampling_rate=44100, + timing_dims=64, + min_duration=0, + max_duration=600.0, + max_nb_genres=5, + metadata_conditioning=True, + init_std=0.2, + **kwargs, + ): + + if vqvae_config is None: + vqvae_config = {} + logger.info("vqvae_config is None. initializing the JukeboxVQVAE with default values.") + + self.vqvae_config = JukeboxVQVAEConfig(**vqvae_config) + if prior_config_list is not None: + self.prior_configs = [JukeboxPriorConfig(**prior_config) for prior_config in prior_config_list] + else: + self.prior_configs = [] + for prior_idx in range(nb_priors): + prior_config = kwargs.pop(f"prior_{prior_idx}", None) + if prior_config is None: + prior_config = {} + logger.info( + f"prior_{prior_idx}'s config is None. Initializing the JukeboxPriorConfig list with default" + " values." + ) + self.prior_configs.append(JukeboxPriorConfig(**prior_config)) + + self.hop_fraction = self.vqvae_config.hop_fraction + + self.init_std = init_std + self.nb_priors = nb_priors + + # Metadata conditioning + self.max_nb_genres = max_nb_genres + self.sampling_rate = sampling_rate + self.timing_dims = timing_dims + self.min_duration = min_duration + self.max_duration = max_duration + self.metadata_conditioning = metadata_conditioning + + super().__init__(**kwargs) + + @classmethod + def from_configs(cls, prior_configs: List[JukeboxPriorConfig], vqvae_config: JukeboxVQVAEConfig, **kwargs): + r""" + Instantiate a [`JukeboxConfig`] (or a derived class) from clip text model configuration and clip vision model + configuration. + + Returns: + [`JukeboxConfig`]: An instance of a configuration object + """ + prior_config_list = [config.to_dict() for config in prior_configs] + return cls(prior_config_list=prior_config_list, vqvae_config_dict=vqvae_config.to_dict(), **kwargs) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + for i, config in enumerate(output.pop("prior_configs")): + output[f"prior_{i}"] = config.to_dict() + + output["vqvae_config"] = self.vqvae_config.to_dict() + output["model_type"] = self.__class__.model_type + return output diff --git a/src/transformers/models/jukebox/convert_jukebox.py b/src/transformers/models/jukebox/convert_jukebox.py new file mode 100644 index 0000000000000..c8d0831e53f3d --- /dev/null +++ b/src/transformers/models/jukebox/convert_jukebox.py @@ -0,0 +1,280 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Jukebox checkpoints""" + +import argparse +import json +import os +from pathlib import Path + +import torch + +import requests +from transformers import JukeboxConfig, JukeboxModel +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +PREFIX = "https://openaipublic.azureedge.net/jukebox/models/" +MODEL_MAPPING = { + "jukebox-1b-lyrics": [ + "5b/vqvae.pth.tar", + "5b/prior_level_0.pth.tar", + "5b/prior_level_1.pth.tar", + "1b_lyrics/prior_level_2.pth.tar", + ], + "jukebox-5b-lyrics": [ + "5b/vqvae.pth.tar", + "5b/prior_level_0.pth.tar", + "5b/prior_level_1.pth.tar", + "5b_lyrics/prior_level_2.pth.tar", + ], +} + + +def replace_key(key): + if key.endswith(".model.1.bias") and len(key.split(".")) > 10: + key = key.replace(".model.1.bias", ".conv1d_1.bias") + elif key.endswith(".model.1.weight") and len(key.split(".")) > 10: + key = key.replace(".model.1.weight", ".conv1d_1.weight") + elif key.endswith(".model.3.bias") and len(key.split(".")) > 10: + key = key.replace(".model.3.bias", ".conv1d_2.bias") + elif key.endswith(".model.3.weight") and len(key.split(".")) > 10: + key = key.replace(".model.3.weight", ".conv1d_2.weight") + + if "conditioner_blocks.0." in key: + key = key.replace("conditioner_blocks.0", "conditioner_blocks") + + if "prime_prior" in key: + key = key.replace("prime_prior", "encoder") + + if ".emb." in key and "total" not in key and "absolute" not in key and "relative" not in key: + key = key.replace(".emb.", ".") + + if key.endswith("k"): # replace vqvae.X.k with vqvae.X.codebook + return key.replace(".k", ".codebook") + if "y_emb." in key: + return key.replace("y_emb.", "metadata_embedding.") + + if "x_emb.emb." in key: + key = key.replace("0.x_emb.emb", "embed_tokens") + + if "prime_state_ln" in key: + return key.replace("prime_state_ln", "encoder.final_layer_norm") + if ".ln" in key: + return key.replace(".ln", ".layer_norm") + if "_ln" in key: + return key.replace("_ln", "_layer_norm") + + if "prime_state_proj" in key: + return key.replace("prime_state_proj", "encoder.proj_in") + if "prime_x_out" in key: + return key.replace("prime_x_out", "encoder.lm_head") + if "prior.x_out" in key: + return key.replace("x_out", "fc_proj_out") + if "x_emb" in key: + return key.replace("x_emb", "embed_tokens") + + return key + + +def fix_jukebox_keys(state_dict, model_state_dict, key_prefix, mapping): + new_dict = {} + import re + + re_encoder_block_conv_in = re.compile("encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)") + re_encoder_block_resnet = re.compile( + "encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)" + ) + re_encoder_block_proj_out = re.compile("encoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)") + + re_decoder_block_conv_out = re.compile("decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)") + re_decoder_block_resnet = re.compile( + "decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)" + ) + re_decoder_block_proj_in = re.compile("decoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)") + + re_prior_cond_conv_out = re.compile("conditioner_blocks.(\d*).cond.model.(\d*).(\d).(bias|weight)") + re_prior_cond_resnet = re.compile( + "conditioner_blocks.(\d*).cond.model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)" + ) + re_prior_cond_proj_in = re.compile("conditioner_blocks.(\d*).cond.model.(\d*).(bias|weight)") + + for original_key, value in state_dict.items(): + + # rename vqvae.encoder keys + if re_encoder_block_conv_in.fullmatch(original_key): + regex_match = re_encoder_block_conv_in.match(original_key) + groups = regex_match.groups() + block_index = int(groups[2]) * 2 + int(groups[3]) + re_new_key = f"encoders.{groups[0]}.level_blocks.{groups[1]}.downsample_block.{block_index}.{groups[-1]}" + key = re_encoder_block_conv_in.sub(re_new_key, original_key) + + elif re_encoder_block_resnet.fullmatch(original_key): + regex_match = re_encoder_block_resnet.match(original_key) + groups = regex_match.groups() + block_index = int(groups[2]) * 2 + int(groups[3]) + conv_index = {"1": 1, "3": 2}[groups[-2]] + prefix = f"encoders.{groups[0]}.level_blocks.{groups[1]}.downsample_block.{block_index}." + resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}" + re_new_key = prefix + resnet_block + key = re_encoder_block_resnet.sub(re_new_key, original_key) + + elif re_encoder_block_proj_out.fullmatch(original_key): + regex_match = re_encoder_block_proj_out.match(original_key) + groups = regex_match.groups() + re_new_key = f"encoders.{groups[0]}.level_blocks.{groups[1]}.proj_out.{groups[-1]}" + key = re_encoder_block_proj_out.sub(re_new_key, original_key) + + # rename vqvae.decoder keys + elif re_decoder_block_conv_out.fullmatch(original_key): + regex_match = re_decoder_block_conv_out.match(original_key) + groups = regex_match.groups() + block_index = int(groups[2]) * 2 + int(groups[3]) - 2 + re_new_key = f"decoders.{groups[0]}.level_blocks.{groups[1]}.upsample_block.{block_index}.{groups[-1]}" + key = re_decoder_block_conv_out.sub(re_new_key, original_key) + + elif re_decoder_block_resnet.fullmatch(original_key): + regex_match = re_decoder_block_resnet.match(original_key) + groups = regex_match.groups() + block_index = int(groups[2]) * 2 + int(groups[3]) - 2 + conv_index = {"1": 1, "3": 2}[groups[-2]] + prefix = f"decoders.{groups[0]}.level_blocks.{groups[1]}.upsample_block.{block_index}." + resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}" + re_new_key = prefix + resnet_block + key = re_decoder_block_resnet.sub(re_new_key, original_key) + + elif re_decoder_block_proj_in.fullmatch(original_key): + regex_match = re_decoder_block_proj_in.match(original_key) + groups = regex_match.groups() + re_new_key = f"decoders.{groups[0]}.level_blocks.{groups[1]}.proj_in.{groups[-1]}" + key = re_decoder_block_proj_in.sub(re_new_key, original_key) + + # rename prior cond.model to upsampler.upsample_block and resnet + elif re_prior_cond_conv_out.fullmatch(original_key): + regex_match = re_prior_cond_conv_out.match(original_key) + groups = regex_match.groups() + block_index = int(groups[1]) * 2 + int(groups[2]) - 2 + re_new_key = f"conditioner_blocks.upsampler.upsample_block.{block_index}.{groups[-1]}" + key = re_prior_cond_conv_out.sub(re_new_key, original_key) + + elif re_prior_cond_resnet.fullmatch(original_key): + regex_match = re_prior_cond_resnet.match(original_key) + groups = regex_match.groups() + block_index = int(groups[1]) * 2 + int(groups[2]) - 2 + conv_index = {"1": 1, "3": 2}[groups[-2]] + prefix = f"conditioner_blocks.upsampler.upsample_block.{block_index}." + resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}" + re_new_key = prefix + resnet_block + key = re_prior_cond_resnet.sub(re_new_key, original_key) + + elif re_prior_cond_proj_in.fullmatch(original_key): + regex_match = re_prior_cond_proj_in.match(original_key) + groups = regex_match.groups() + re_new_key = f"conditioner_blocks.upsampler.proj_in.{groups[-1]}" + key = re_prior_cond_proj_in.sub(re_new_key, original_key) + + # keep original key + else: + key = original_key + + key = replace_key(key) + + if f"{key_prefix}.{key}" not in model_state_dict or key is None: + print(f"failed converting {original_key} to {key}, does not match") + + # handle missmatched shape + elif value.shape != model_state_dict[f"{key_prefix}.{key}"].shape: + val = model_state_dict[f"{key_prefix}.{key}"] + print(f"{original_key}-> {key} : \nshape {val.shape} and { value.shape}, do not match") + key = original_key + + mapping[key] = original_key + new_dict[key] = value + + return new_dict + + +@torch.no_grad() +def convert_openai_checkpoint(model_name=None, pytorch_dump_folder_path=None): + """ + Copy/paste/tweak model's weights to our Jukebox structure. + """ + for file in MODEL_MAPPING[model_name]: + if not os.path.isfile(f"{pytorch_dump_folder_path}/{file.split('/')[-1]}"): + r = requests.get(f"{PREFIX}{file}", allow_redirects=True) + os.makedirs(f"{pytorch_dump_folder_path}/", exist_ok=True) + open(f"{pytorch_dump_folder_path}/{file.split('/')[-1]}", "wb").write(r.content) + + model_to_convert = MODEL_MAPPING[model_name.split("/")[-1]] + + config = JukeboxConfig.from_pretrained(model_name) + model = JukeboxModel(config) + + weight_dict = [] + mapping = {} + for i, dict_name in enumerate(model_to_convert): + old_dic = torch.load(f"{pytorch_dump_folder_path}/{dict_name.split('/')[-1]}")["model"] + + new_dic = {} + for k in old_dic.keys(): + if k.endswith(".b"): + new_dic[k.replace("b", "bias")] = old_dic[k] + elif k.endswith(".w"): + new_dic[k.replace("w", "weight")] = old_dic[k] + elif "level_2" not in dict_name and "cond.model." in k: + new_dic[k.replace(".blocks.", ".model.")] = old_dic[k] + else: + new_dic[k] = old_dic[k] + + key_prefix = "vqvae" if i == 0 else f"priors.{3 - i}" + new_dic = fix_jukebox_keys(new_dic, model.state_dict(), key_prefix, mapping) + weight_dict.append(new_dic) + + vqvae_state_dict = weight_dict.pop(0) + model.vqvae.load_state_dict(vqvae_state_dict) + for i in range(len(weight_dict)): + model.priors[i].load_state_dict(weight_dict[2 - i]) + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + with open(f"{pytorch_dump_folder_path}/mapping.json", "w") as txtfile: + json.dump(mapping, txtfile) + + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + + return weight_dict + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="jukebox-5b-lyrics", + type=str, + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default="jukebox-5b-lyrics-converted", + type=str, + help="Path to the output PyTorch model directory.", + ) + args = parser.parse_args() + convert_openai_checkpoint(args.model_name, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/jukebox/modeling_jukebox.py b/src/transformers/models/jukebox/modeling_jukebox.py new file mode 100755 index 0000000000000..956260a25c685 --- /dev/null +++ b/src/transformers/models/jukebox/modeling_jukebox.py @@ -0,0 +1,2667 @@ +# coding=utf-8 +# Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Jukebox model.""" + +import math +import os +from typing import List + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import LayerNorm as FusedLayerNorm + +from ...activations import ACT2FN +from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, logging +from ...utils.logging import tqdm +from .configuration_jukebox import ATTENTION_PATTERNS, JukeboxConfig, JukeboxPriorConfig, JukeboxVQVAEConfig + + +logger = logging.get_logger(__name__) + +JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "openai/jukebox-1b-lyrics", + "openai/jukebox-5b-lyrics", + # See all Jukebox models at https://huggingface.co/models?filter=jukebox +] + + +def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): + """ + Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + + Args: + logits (`torch.Tensor`): + logits distribution shape (vocabulary size) + top_k (`int`, *optional*, defaults to 0): + When `top_k >0` keep only top key tokens with highest probability (top-k filtering). + top_p (`int`, *optional*, defaults to 0): + When `top_p>0.0` keep the top tokens with cumulative probability >= `top_p` (nucleus filtering). + """ + logits = logits.clone() + top_k = min(top_k, logits.size(-1)) # Safety check + + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1:] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # indices_to_remove = sorted_indices[sorted_indices_to_remove] + indices_to_remove = torch.zeros_like(logits, dtype=torch.uint8).scatter_( + dim=-1, index=sorted_indices, src=sorted_indices_to_remove + ) + logits[indices_to_remove] = filter_value + return logits + + +def get_relevant_lyric_tokens(full_tokens, max_n_lyric_tokens, total_length, offset, duration): + """ + Extract only the relevant tokens based on the character position. A total of `max_n_lyric_tokens` tokens will be + returned. If the provided token sequence is smaller, it will be padded, otherwise, only characters ranging from the + midpoint - `max_n_lyric_tokens//2` to the midpoint + `max_n_lyric_tokens//2` will be returned. This *focuses* on + the most relevant tokens (in time) for the sequence. + + Args: + full_tokens (`List[int]`): + List containing the token ids of the entire lyrics. + total_length (`int`): + Total expected length of the music (not all of it is generated, see duration), in samples. + offset (`int`): + Starting sample in the music. If the offset is greater than 0, the lyrics will be shifted take that into + account + duration (`int`): + Expected duration of the generated music, in samples. The duration has to be smaller than the total length, + which represent the overall length of the signal, + """ + full_tokens = full_tokens[0] + if len(full_tokens) < max_n_lyric_tokens: + tokens = torch.cat( + [torch.zeros(max_n_lyric_tokens - len(full_tokens), dtype=torch.long).to(full_tokens.device), full_tokens] + ) + indices = [-1] * (max_n_lyric_tokens - len(full_tokens)) + list(range(0, len(full_tokens))) + else: + midpoint = int(len(full_tokens) * (offset + duration / 2.0) / total_length) + midpoint = min(max(midpoint, max_n_lyric_tokens // 2), len(full_tokens) - max_n_lyric_tokens // 2) + tokens = full_tokens[midpoint - max_n_lyric_tokens // 2 : midpoint + max_n_lyric_tokens // 2] + indices = list(range(midpoint - max_n_lyric_tokens // 2, midpoint + max_n_lyric_tokens // 2)) + return tokens.unsqueeze(dim=0), indices + + +# Break total_length into hops/windows of size n_ctx separated by hop_length +def get_starts(total_length, n_ctx, hop_length): + starts = [] + for start in range(0, total_length - n_ctx + hop_length, hop_length): + if start + n_ctx >= total_length: + # Last hop could be smaller, we make it n_ctx to maximise context + start = total_length - n_ctx + starts.append(start) + return starts + + +def get_alignment(music_tokens, labels, prior, config): + level = prior.levels - 1 # Top level used + n_ctx = prior.n_ctx + tokens = music_tokens[level] + batch_size, total_length = tokens.shape[0], tokens.shape[1] + if total_length < n_ctx: + padding_length = n_ctx - total_length + tokens = torch.cat( + [tokens, torch.zeros(batch_size, n_ctx - total_length, dtype=tokens.dtype, device=tokens.device)], dim=1 + ) + total_length = tokens.shape[1] + else: + padding_length = 0 + + hop_length = int(config.hop_fraction[-level - 1] * prior.n_ctx) + alignment_head, alignment_layer = config.prior_alignment_head[0], config.prior_alignment_layer[0] + attn_layers = set([alignment_layer]) + alignment_hops = {} + indices_hops = {} + for start in tqdm(get_starts(total_length, n_ctx, hop_length), desc="Computing lyric to music alignment "): + end = start + n_ctx + # set metadata offset, sample_length and lyrics tokens + metadata, indices_hop = prior.get_metadata(labels, start, config.sample_length, get_indices=True, offset=0) + tokens_bs = torch.chunk(tokens, batch_size, dim=0) + metadata_bs = torch.chunk(metadata, batch_size, dim=0) + w_hops = [] + for tokens_i, metadata_i in zip(tokens_bs, metadata_bs): + w_hop = prior.forward_tokens(tokens_i[:, start:end], [], metadata_i, get_attn_weights=attn_layers) + w_hops.append(w_hop[0][:, alignment_head]) + del w_hop + weights = torch.cat(w_hops, dim=0) + del w_hops + alignment_hop = weights.float().cpu().numpy() + del weights + + # alignment_hop has shape (bs, n_ctx, nb_relevant_lyric_tokens) + # indices_hop is a list of len=bs, each entry of len hps.nb_relevant_lyric_tokens + indices_hops[start] = indices_hop + alignment_hops[start] = alignment_hop + + # Combine attn for each hop into attn for full range + # Use indices to place them into correct place for corresponding source tokens + alignments = [] + for item in range(batch_size): + # Note each item has different length lyrics + full_tokens = labels[0, 3:] + alignment = np.zeros((total_length, len(full_tokens) + 1)) + for start in reversed(get_starts(total_length, n_ctx, hop_length)): + end = start + n_ctx + alignment_hop = alignment_hops[start][item] + indices = indices_hops[start][item] + alignment[start:end, indices] = alignment_hop + alignment = alignment[: total_length - padding_length, :-1] # remove token padding, and last lyric index + alignments.append(alignment) + return alignments + + +def save_temp_audio(fname, lvl, metas, aud): + aud = torch.clamp(aud, -1, 1).cpu().numpy() + for i in list(range(aud.shape[0])): + if metas is not None: + artists, genres, lyrics = list(metas)[i].values() + path = f"{fname}/lvl_{lvl}-{artists}-{genres}-{lyrics[:5]}-{i}" + np.save(path, aud[i]) + else: + np.save(f"{fname}/lvl_{lvl}-sample-{i}", aud[i]) + + +def get_mask(mask, query_length, key_value_length, blocks, spread, device, sample, sample_t): + # returns a mask of shape 1 x 1 x query_length x key_value_length or None if masking is not needed. + if mask is None or query_length == 1: + return None + offset = sample_t - query_length if sample else max(key_value_length - query_length, 0) + if mask == "autoregressive": + # Masked dense + mask = torch.ones(query_length, key_value_length, device=device).tril(offset) + elif mask == "summary": + # Masked summary + mask = torch.ones(query_length, query_length, device=device).tril() + mask = torch.ones(query_length, query_length, device=device).tril() + mask = mask.view(query_length, blocks, query_length // blocks)[:, :-1, -key_value_length // blocks :] + mask = ( + torch.nn.functional.pad( + mask, + (0, 0, 1, 0), + value=1, + ) + .contiguous() + .view(query_length, key_value_length) + ) + elif mask == "prime": + mask = torch.ones(query_length, key_value_length, device=device).tril(offset) + return mask.view(1, 1, query_length, key_value_length) + + +class JukeboxConv1D(nn.Module): + def __init__(self, input_width, output_width): + super().__init__() + self.input_width = input_width + self.output_width = output_width + weight = torch.empty(input_width, output_width) + bias = torch.zeros(output_width) + self.weight = nn.Parameter(weight) + self.bias = nn.Parameter(bias) + + def forward(self, hidden_states): + size_out = (*hidden_states.size()[:-1], self.output_width) + hidden_states = torch.addmm( + self.bias.type_as(hidden_states), + hidden_states.view(-1, hidden_states.size(-1)), + self.weight.type_as(hidden_states), + ) + hidden_states = hidden_states.view(*size_out) + return hidden_states + + +class JukeboxResConv1DBlock(nn.Module): + def __init__(self, config, conv_width, depth=1, res_scale=1.0): + super().__init__() + hidden_dim = config.res_convolution_multiplier * conv_width + dilation = config.res_dilation_growth_rate**depth + padding = dilation + + self.res_scale = res_scale + self.activation = nn.ReLU() + self.conv1d_1 = nn.Conv1d(conv_width, hidden_dim, 3, 1, padding, dilation) + self.conv1d_2 = nn.Conv1d(hidden_dim, conv_width, 1, 1, 0) + + def forward(self, hidden_states): + residuals = hidden_states + hidden_states = self.activation(hidden_states) + hidden_states = self.conv1d_1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv1d_2(hidden_states) + return residuals + self.res_scale * hidden_states + + +class JukeboxResnet1D(nn.Module): + def __init__(self, config, conv_width, n_depth, reverse_dilation=False): + super().__init__() + self.dilation_cycle = config.res_dilation_cycle + res_scale = 1.0 if not config.conv_res_scale else 1.0 / math.sqrt(n_depth) + + blocks = [] + for depth in range(n_depth): + block_depth = depth if self.dilation_cycle is None else depth % self.dilation_cycle + blocks.append(JukeboxResConv1DBlock(config, conv_width, block_depth, res_scale)) + + if reverse_dilation: + blocks = blocks[::-1] + self.resnet_block = nn.ModuleList(blocks) + + def forward(self, hidden_states): + for block in self.resnet_block: + hidden_states = block(hidden_states) + return hidden_states + + +class JukeboxEncoderConvBlock(nn.Module): + def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t): + super().__init__() + blocks = [] + filter_t = stride_t * 2 + pad_t = stride_t // 2 + if down_t > 0: + for i in range(down_t): + blocks.append(nn.Conv1d(embed_dim if i == 0 else hidden_dim, hidden_dim, filter_t, stride_t, pad_t)) + blocks.append(JukeboxResnet1D(config, hidden_dim, depth)) + self.proj_out = nn.Conv1d(hidden_dim, config.embed_dim, 3, 1, 1) + self.downsample_block = nn.ModuleList(blocks) + + def forward(self, hidden_states): + for block in self.downsample_block: + hidden_states = block(hidden_states) + hidden_states = self.proj_out(hidden_states) + return hidden_states + + +class JukeboxEncoder(nn.Module): + def __init__(self, config, width, depth, levels, downs_t, strides_t): + super().__init__() + self.levels = levels + self.level_blocks = nn.ModuleList() + + iterator = zip(list(range(self.levels)), downs_t, strides_t) + for i, down_t, stride_t in iterator: + self.level_blocks.append( + JukeboxEncoderConvBlock( + config, config.conv_input_shape if i == 0 else config.embed_dim, width, depth, down_t, stride_t + ) + ) + + def forward(self, hidden_states): + all_hidden_states = [] + + # 64, 32, ... + for level in range(self.levels): + level_block = self.level_blocks[level] + hidden_states = level_block(hidden_states) + all_hidden_states.append(hidden_states) + + return all_hidden_states + + +class JukeboxDecoderConvBock(nn.Module): + def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t, reverse_dilation=True): + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + super().__init__() + blocks = [] + if down_t > 0: + filter_t = stride_t * 2 + pad_t = stride_t // 2 + self.proj_in = nn.Conv1d(embed_dim, hidden_dim, 3, 1, 1) + for i in range(down_t): + blocks.append(JukeboxResnet1D(config, hidden_dim, depth, reverse_dilation)) + blocks.append( + nn.ConvTranspose1d( + hidden_dim, hidden_dim if i < down_t - 1 else embed_dim, filter_t, stride_t, pad_t + ) + ) + self.upsample_block = nn.ModuleList(blocks) + + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + for block in self.upsample_block: + hidden_states = block(hidden_states) + return hidden_states + + +class JukeboxDecoder(nn.Module): + def __init__(self, config, hidden_dim, depth, levels, downs_t, strides_t): + super().__init__() + self.levels = levels + self.level_blocks = nn.ModuleList() + for level, down_t, stride_t in zip(list(range(self.levels)), downs_t, strides_t): + self.level_blocks.append( + JukeboxDecoderConvBock(config, config.embed_dim, hidden_dim, depth, down_t, stride_t) + ) + + self.out = nn.Conv1d(config.embed_dim, config.conv_input_shape, 3, 1, 1) + + def forward(self, hidden_states, all_levels=True): + hidden_state = hidden_states[-1] + + # 32, 64 ... + for level in reversed(range(self.levels)): + level_block = self.level_blocks[level] + hidden_state = level_block(hidden_state) + + if level != 0 and all_levels: + hidden_state = hidden_state + hidden_states[level - 1] + + hidden_state = self.out(hidden_state) + return hidden_state + + +class JukeboxBottleneckBlock(nn.Module): + def __init__(self, config: JukeboxVQVAEConfig): + super().__init__() + self.nb_discrete_codes = config.nb_discrete_codes + self.codebook_width = config.embed_dim + self.mu = config.lmu + self.threshold = 1.0 + self.init = False + self.codebook_sum = None + self.codebook_elem = None + self.register_buffer("codebook", torch.zeros(self.nb_discrete_codes, self.codebook_width)) + + def _tile(self, hidden_states): + dim, embed_width = hidden_states.shape + if dim < self.nb_discrete_codes: + n_repeats = (self.nb_discrete_codes + dim - 1) // dim + std = 0.01 / np.sqrt(embed_width) + hidden_states = hidden_states.repeat(n_repeats, 1) + hidden_states = hidden_states + torch.randn_like(hidden_states) * std + return hidden_states + + def init_codebook(self, hidden_states): + nb_discrete_codes = self.nb_discrete_codes + self.init = True + codes = self._tile(hidden_states) + self.codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes] + self.codebook_sum = self.codebook + self.codebook_elem = torch.ones(nb_discrete_codes, device=self.codebook.device) + + def update_codebook(self, hidden_states, latent_states): + mu, codebook_width, nb_discrete_codes = self.mu, self.codebook_width, self.nb_discrete_codes + with torch.no_grad(): + # Calculate new centres + # nb_discrete_codes, batch_size * seq_length + latent_states_onehot = torch.zeros(nb_discrete_codes, hidden_states.shape[0], device=hidden_states.device) + latent_states_onehot.scatter_(0, latent_states.view(1, hidden_states.shape[0]), 1) + + _codebook_sum = torch.matmul(latent_states_onehot, hidden_states) + _codebook_elem = latent_states_onehot.sum(dim=-1) # nb_discrete_codes + codes = self._tile(hidden_states) + _random_codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes] + + # Update centres + old_codebook = self.codebook + self.codebook_sum = mu * self.codebook_sum + (1.0 - mu) * _codebook_sum + self.codebook_elem = mu * self.codebook_elem + (1.0 - mu) * _codebook_elem # nb_discrete_codes + usage = (self.codebook_elem.view(nb_discrete_codes, 1) >= self.threshold).float() + + norm_code = self.codebook_sum.view(nb_discrete_codes, codebook_width) / self.codebook_elem.view( + nb_discrete_codes, 1 + ) + self.codebook = usage * (norm_code) + (1 - usage) * _random_codebook + _codebook_prob = _codebook_elem / torch.sum(_codebook_elem) # prob of each bin + entropy = -torch.sum(_codebook_prob * torch.log(_codebook_prob + 1e-8)) # entropy ie how diverse + used_curr = (_codebook_elem >= self.threshold).sum() + usage = torch.sum(usage) + dk = torch.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape)) + return dict(entropy=entropy, used_curr=used_curr, usage=usage, dk=dk) + + def preprocess(self, hidden_states): + hidden_states = hidden_states.permute(0, 2, 1).contiguous() + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + if hidden_states.shape[-1] == self.codebook_width: + prenorm = torch.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(np.prod(hidden_states.shape)) + elif hidden_states.shape[-1] == 2 * self.codebook_width: + x1, x2 = hidden_states[..., : self.codebook_width], hidden_states[..., self.codebook_width :] + prenorm = (torch.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + ( + torch.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape)) + ) + + # Normalise + hidden_states = x1 + x2 + + return hidden_states, prenorm + + def postprocess(self, latent_states, dequantised_states, x_shape): + batch_size, time = x_shape + dequantised_states = dequantised_states.view(batch_size, time, -1).permute(0, 2, 1).contiguous() + latent_states = latent_states.view(batch_size, time) + return latent_states, dequantised_states + + def quantise(self, latent_states): + # Calculate latent code latent_states + codebook_weights = self.codebook.t() + distance = ( + torch.sum(latent_states**2, dim=-1, keepdim=True) + - 2 * torch.matmul(latent_states, codebook_weights) + + torch.sum(codebook_weights**2, dim=0, keepdim=True) + ) # (batch_size * latent_states , codebook_weights) + min_distance, music_tokens = torch.min(distance, dim=-1) + fit = torch.mean(min_distance) + return music_tokens, fit + + def dequantise(self, music_tokens): + dequantised_states = F.embedding(music_tokens, self.codebook) + return dequantised_states + + def encode(self, latent_states): + samples, _, seq_len = latent_states.shape + + # Preprocess. + latent_states, _ = self.preprocess(latent_states) + + # Quantise + music_tokens, _ = self.quantise(latent_states) + + # Postprocess. + music_tokens = music_tokens.view(samples, seq_len) + return music_tokens + + def decode(self, music_tokens): + samples, seq_len = music_tokens.shape + + # Dequantise + dequantised_states = self.dequantise(music_tokens) + + # Postprocess + dequantised_states = ( + dequantised_states.view(samples, seq_len, self.codebook_width).permute(0, 2, 1).contiguous() + ) + return dequantised_states + + def forward(self, hidden_states, update_codebook=True): + samples, _, seq_len = hidden_states.shape + + # Preprocess + hidden_states, prenorm = self.preprocess(hidden_states) + + # Init codebook if not inited + if update_codebook and not self.init: + self.init_codebook(hidden_states) + + # Quantise and dequantise through bottleneck + music_tokens, fit = self.quantise(hidden_states) + dequantised_states = self.dequantise(music_tokens) + + # Update embeddings + if update_codebook: + update_metrics = self.update_codebook(hidden_states, music_tokens) + else: + update_metrics = {} + + # Loss + commit_loss = torch.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(hidden_states.shape) + + # Passthrough + dequantised_states = hidden_states + (dequantised_states - hidden_states).detach() + + # Postprocess + music_tokens, dequantised_states = self.postprocess(music_tokens, dequantised_states, (samples, seq_len)) + return music_tokens, dequantised_states, commit_loss, dict(fit=fit, pn=prenorm, **update_metrics) + + +class JukeboxBottleneck(nn.Module): + def __init__(self, config, levels): + super().__init__() + self.levels = levels + self.level_blocks = nn.ModuleList() + for level in range(self.levels): + self.level_blocks.append(JukeboxBottleneckBlock(config)) + + def encode(self, raw_audio): + music_tokens = [ + level_block.encode(hidden_states) for (level_block, hidden_states) in zip(self.level_blocks, raw_audio) + ] + return music_tokens + + def decode(self, music_tokens, start_level=0, end_level=None): + if end_level is None: + end_level = self.levels + quantised_audio = [ + level_block.decode(z) for (level_block, z) in zip(self.level_blocks[start_level:end_level], music_tokens) + ] + return quantised_audio + + def forward(self, input_audio): + music_tokens, quantised_states, commit_losses, metrics = [], [], [], [] + for level in range(self.levels): + level_block = self.level_blocks[-level - 1] + hidden_states = input_audio[level] + sampled_tokens, quantised_state, commit_loss, metric = level_block( + hidden_states, update_codebook=self.training + ) + music_tokens.append(sampled_tokens) + if not self.training: + # Be extra paranoid and make sure the encoder weights can't + # change from straight-through estimator + quantised_state = quantised_state.detach() + quantised_states.append(quantised_state) + commit_losses.append(commit_loss) + if self.training: + metrics.append(metric) + return music_tokens, quantised_states, commit_losses, metrics + + +JUKEBOX_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config (`JukeboxConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """The Hierarchical VQ-VAE model used in Jukebox. This model follows the Hierarchical VQVAE paper from [Will Williams, Sam +Ringer, Tom Ash, John Hughes, David MacLeod, Jamie Dougherty](https://arxiv.org/abs/2002.08111). + + """, + JUKEBOX_START_DOCSTRING, +) +class JukeboxVQVAE(PreTrainedModel): + config_class = JukeboxVQVAEConfig + base_model_prefix = "vqvae" + _keys_to_ignore_on_load_unexpected = [r"priors"] + + def _init_weights(self, module): + if isinstance(module, nn.Embedding): # embed_tokens + module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale) + elif isinstance(module, JukeboxConv1D): + if self.config.zero_out: + module.weight.data.zero_() + else: + module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale) + elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out: + module.conv1d_2.weight.data.zero_() + module.conv1d_2.bias.data.zero_() + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def __init__(self, config: JukeboxVQVAEConfig): + super().__init__(config) + downs_t = config.res_downs_t + strides_t = config.res_strides_t + if not config.sample_length: + downsamples = [stride**down for stride, down in zip(strides_t, downs_t)] + top_raw_to_tokens = np.prod(downsamples) + config.sample_length = ( + config.sample_length_in_seconds * config.sampling_rate // top_raw_to_tokens + ) * top_raw_to_tokens + config.sample_length = config.sample_length.astype(int) + + self.nb_discrete_codes = config.nb_discrete_codes + self.commit = config.commit + self.sample_length = config.sample_length + + self.downsamples = [stride**down for stride, down in zip(strides_t, downs_t)] + self.hop_lengths = np.cumprod(self.downsamples) + self.levels = levels = config.levels + self.music_tokens_shapes = [ + (int(self.sample_length // self.hop_lengths[-level - 1])) for level in range(levels) + ] + + self.multipliers = config.multipliers if config.multipliers is not None else [1] * levels + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + for level in range(levels): + width = config.res_conv_width * self.multipliers[level] + depth = config.res_conv_depth * self.multipliers[level] + self.encoders.append( + JukeboxEncoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1]) + ) + self.decoders.append( + JukeboxDecoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1]) + ) + + self.bottleneck = JukeboxBottleneck(config, levels) + + def _decode(self, music_tokens, start_level=0, end_level=None): + # Decode + if end_level is None: + end_level = self.levels + latent_states = self.bottleneck.decode(music_tokens, start_level=start_level, end_level=end_level) + # Use only lowest level + decoder, dequantised_state = self.decoders[start_level], latent_states[0:1] + dequantised_state = decoder(dequantised_state, all_levels=False) + dequantised_state = dequantised_state.permute(0, 2, 1) + return dequantised_state + + def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1) -> torch.Tensor: + """ + Transforms the input `music_tokens` to their `raw_audio` representation. + + Args: + music_tokens (`torch.LongTensor`): + Tensor of music tokens which will be decoded to raw audio by using the codebook. Each music token + should be an index to a corresponding `code` vector in the codebook. + start_level (`int`, *optional*): + Level at which the decoding process will start. Default to 0. + end_level (`int`, *optional*): + Level at which the decoding process will start. Default to None. + bs_chunks (int, *optional*): + Number of chunks to process at the same time. + """ + token_chunks = [torch.chunk(token, bs_chunks, dim=0) for token in music_tokens] + dequantised_states = [] + for i in range(bs_chunks): + music_tokens_i = [chunks[i] for chunks in token_chunks] + dequantised_state = self._decode(music_tokens_i, start_level=start_level, end_level=end_level) + dequantised_states.append(dequantised_state) + return torch.cat(dequantised_states, dim=0) + + def _encode(self, raw_audio, start_level=0, end_level=None): + # Encode + if end_level is None: + end_level = self.levels + input_audio = raw_audio.permute(0, 2, 1).float() + latent_states = [] + for level in range(self.levels): + encoder = self.encoders[level] + latent_state = encoder(input_audio) + latent_states.append(latent_state[-1]) + music_tokens = self.bottleneck.encode(latent_states) + return music_tokens[start_level:end_level] + + def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1): + """ + Transforms the `input_audio` to a discrete representation made out of `music_tokens`. + + Args: + input_audio (`torch.Tensor`): + Raw audio which will be encoded to its discrete representation using the codebook. The closest `code` + form the codebook will be computed for each sequence of samples. + start_level (`int`, *optional*, defaults to 0): + Level at which the encoding process will start. Default to 0. + end_level (`int`, *optional*): + Level at which the encoding process will start. Default to None. + bs_chunks (int, *optional*, defaults to 1): + Number of chunks of raw audio to process at the same time. + """ + audio_chunks = torch.chunk(input_audio, bs_chunks, dim=0) + music_tokens_list = [] + for chunk_i in audio_chunks: + music_tokens_i = self._encode(chunk_i, start_level=start_level, end_level=end_level) + music_tokens_list.append(music_tokens_i) + music_tokens = [torch.cat(music_tokens_level, dim=0) for music_tokens_level in zip(*music_tokens_list)] + return music_tokens + + def sample(self, n_samples): + music_tokens = [ + torch.randint(0, self.nb_discrete_codes, size=(n_samples, *music_tokens_shape), device="cpu") + for music_tokens_shape in self.music_tokens_shapes + ] + return self.decode(music_tokens) + + def forward(self, raw_audio): + """ + Forward pass of the VQ-VAE, encodes the `raw_audio` to latent states, which are then decoded for each level. + The commit loss, which ensure that the encoder's computed embeddings are close to the codebook vectors, is + computed. + + Args: + raw_audio (`torch.FloatTensor`): + Audio input which will be encoded and decoded. + + Returns: + `Tuple[torch.Tensor, torch.Tensor` + + + Example: + ```python + >>> from transformers import JukeboxVQVAE, set_seed + >>> import torch + + >>> model = JukeboxVQVAE.from_pretrained("openai/jukebox-1b-lyrics").eval() + >>> set_seed(0) + >>> zs = [torch.randint(100, (4, 1))] + >>> model.decode(zs).shape + torch.Size([4, 8, 1]) + ``` + """ + + # Encode/Decode + input_audio = raw_audio.permute(0, 2, 1).float() + latent_states = [] + for level in range(self.levels): + encoder = self.encoders[level] + latent_state = encoder(input_audio) + latent_states.append(latent_state[-1]) + + _, music_tokens, commit_losses, _ = self.bottleneck(latent_states) + dequantised_states = [] + for level in range(self.levels): + decoder = self.decoders[level] + dequantised_state = decoder(music_tokens[level : level + 1], all_levels=False) + dequantised_states.append(dequantised_state.permute(0, 2, 1)) + + commit_loss = sum(commit_losses) + loss = self.commit * commit_loss + + return dequantised_states, loss + + +class JukeboxMLP(nn.Module): + def __init__(self, config): + # a single channel is always used in original code + super().__init__() + embed_dim = config.hidden_size + hidden_dim = int(config.mlp_multiplier * embed_dim) + + self.c_fc = JukeboxConv1D(embed_dim, hidden_dim) + self.c_proj = JukeboxConv1D(hidden_dim, embed_dim) + self.act = ACT2FN[config.act_fn] + self.dropout = nn.Dropout(config.resid_dropout) + + def forward(self, hidden_states): + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class JukeboxLayerNorm(FusedLayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine) + self.width = np.prod(normalized_shape) + self.max_numel = 65535 * self.width + + def forward(self, input): + if input.numel() > self.max_numel: + return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps).type_as(input) + else: + return super().forward(input).type_as(input) + + +class JukeboxAttention(nn.Module): + def __init__(self, config, n_ctx, attn_func="dense_attn"): + super().__init__() + self.embed_dim = config.hidden_size + self.n_heads = config.n_heads + self.dropout = config.attn_dropout + hidden_dim = int(config.attention_multiplier * self.embed_dim) + + self.head_dim = hidden_dim // config.n_heads + self.n_ctx = n_ctx + self.hidden_dim = hidden_dim + self.scale = self.head_dim**-0.25 + self.mask = config.mask + + if attn_func == "cross_attention": + self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim) + self.c_enc_kv = JukeboxConv1D(self.embed_dim, hidden_dim * 2) + else: + self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim * 3) + + self.c_proj = JukeboxConv1D(hidden_dim, self.embed_dim) + self.attn_dropout = nn.Dropout(config.attn_dropout) + self.resid_dropout = nn.Dropout(config.resid_dropout) + + # Sequence of length seq_len is factored as [blocks, seq_len // blocks] + self.attn_func = attn_func + if attn_func == "cross_attention": + self.qkv = self.decode_qkv + elif attn_func == "prime_attn": + self.qkv = self.prime_qkv + else: + self.qkv = self.factored_qkv + + ATTENTION_MAP = { + "dense_attn": (self.dense_attn, "autoregressive"), + "block_attn": (self.block_attn, "autoregressive"), + "transpose_block_attn": (self.transpose_block_attn, "autoregressive"), + "prev_block_attn": (self.prev_block_attn, None), + "summary_attn": (self.summary_attn, "summary"), + "summary_spread_attn": (self.summary_spread_attn, "summary"), + "cross_attention": (self.dense_attn, None), + "prime_attn": (self.prime_attn, "prime"), + } + self.attn, self.attn_mask = ATTENTION_MAP[attn_func] + + self.blocks = config.blocks + self.spread = config.spread + if self.blocks is not None: + self.block_ctx = self.n_ctx // self.blocks + + self.sample_t = 0 + self.cache = {} + self.encoder_len = config.nb_relevant_lyric_tokens # length of the encoder input ids + self.record_attn = False + + def _attn(self, query_states, key_states, value_states, sample): + scale = self.scale + if self.training: + attention_weight = torch.matmul(query_states * scale, key_states * scale) + else: + attention_weight = torch.matmul(query_states, key_states) + attention_weight.mul_(scale * scale) + attn_weight_type = attention_weight.dtype + attention_weight = attention_weight.float() + if self.mask: + # Generate appropriate mask to mask out all positions before current + # Might take up lot of memory for dense, so can cache it + mask = get_mask( + self.attn_mask, + query_states.size(-2), + key_states.size(-1), + self.blocks, + self.spread, + attention_weight.device, + sample, + self.sample_t, + ) + if mask is not None: + attention_weight = attention_weight * mask + -1e9 * (1 - mask) + attention_prob = F.softmax(attention_weight, dim=-1).type(attn_weight_type) + if self.record_attn: + self.attention_prob = attention_prob + if self.attn_func == "prime_attn": + # only keep music queries and lyrics keys/values + self.attention_prob = self.attention_prob[:, :, self.encoder_len :, : self.encoder_len] + attention_prob = self.attn_dropout(attention_prob) + context_states = torch.matmul(attention_prob, value_states) + return context_states + + def merge_heads(self, hidden_states): + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = (*hidden_states.size()[:-2], hidden_states.size(-2) * hidden_states.size(-1)) + return hidden_states.view(*new_hidden_states_shape) # in Tensorflow implem: fct merge_states + + def split_heads(self, hidden_states, is_key=False): + new_hidden_states_shape = ( + *hidden_states.size()[:-1], + self.n_heads, + hidden_states.size(-1) // self.n_heads, + ) + hidden_states = hidden_states.view(*new_hidden_states_shape) # in Tensorflow implem: fct split_states + if is_key: + return hidden_states.permute(0, 2, 3, 1) + else: + return hidden_states.permute(0, 2, 1, 3) + + def dense_attn(self, query, key, value, sample): + query = self.split_heads(query) + key = self.split_heads(key, is_key=True) + value = self.split_heads(value) + context_states = self._attn(query, key, value, sample) + context_states = self.merge_heads(context_states) + return context_states + + def block_attn(self, query, key, value, sample): + block_ctx = self.block_ctx + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim) + else: + query_length = query.shape[1] + query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim) + if query_length < seq_len: + seq_len = query_length + key = key[:, -seq_len:].contiguous() + value = value[:, -seq_len:].contiguous() + key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim) + value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim) + return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim) + + def transpose_block_attn(self, query, key, value, sample): + block_ctx = self.block_ctx + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + block_len = (seq_len - 1) % block_ctx + key = key[:, block_len::block_ctx, :] + value = value[:, block_len::block_ctx, :] + return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim) + else: + query_length = query.shape[1] + query = query.view(batch_size, query_length // block_ctx, block_ctx, embed_dim) + query = query.transpose(1, 2).contiguous() + query = query.view(batch_size * block_ctx, query_length // block_ctx, embed_dim) + + key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim) + key = key.transpose(1, 2).contiguous() + key = key.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim) + + value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim) + value = value.transpose(1, 2).contiguous() + value = value.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim) + + block_attn = self.dense_attn(query, key, value, sample) + block_attn = block_attn.view(batch_size, block_ctx, query_length // block_ctx, embed_dim) + block_attn = block_attn.transpose(1, 2).contiguous() + block_attn = block_attn.view(batch_size, query_length, embed_dim) + + return block_attn + + def prev_block_attn(self, query, key, value, sample): + block_ctx = self.block_ctx + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + block = (seq_len - 1) // block_ctx + prev_l = (block - 1) * block_ctx + if block > 0: + key = key[:, prev_l : prev_l + block_ctx, :] + value = value[:, prev_l : prev_l + block_ctx, :] + else: + key = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype) + value = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype) + return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim) + else: + query_length = query.shape[1] + query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim) + + key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :] + key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0)) + key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim) + + value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :] + value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0)) + value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim) + + if query_length < seq_len: + nb_query_blocks = query_length // block_ctx + nb_key_blocks = seq_len // block_ctx + seq_len = query_length + key = key.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:] + key = key.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim) + + value = value.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:] + value = value.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim) + + return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim) + + def summary_attn(self, query, key, value, sample): + blocks = self.blocks + block_ctx = self.block_ctx + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + key = key[:, block_ctx - 1 : blocks * block_ctx - 1 : block_ctx, :] + key = torch.nn.functional.pad(key, (0, 0, 1, 0)) + + value = value[:, block_ctx - 1 : blocks * block_ctx - 1 : block_ctx, :] + value = torch.nn.functional.pad(value, (0, 0, 1, 0)) + return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim) + else: + key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :] + key = torch.nn.functional.pad(key, (0, 0, 1, 0)) # batch_size, blocks, embed_dim + + value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :] + value = torch.nn.functional.pad(value, (0, 0, 1, 0)) # batch_size, blocks, embed_dim + return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim) + + def summary_spread_attn(self, query, key, value, sample): + blocks = self.blocks + spread = self.spread + + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + raise NotImplementedError + else: + key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :] + key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0)).contiguous() + key = key.view(batch_size, blocks * spread, embed_dim) + + value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :] + value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0)).contiguous() + value = value.view(batch_size, blocks * spread, embed_dim) + + return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim) + + def prime_attn(self, query, key, value, sample): + encoder_len = self._encoder_len + key = key[:, :encoder_len] + value = value[:, :encoder_len] + return self.dense_attn(query, key, value, sample) + + def factored_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False): + curr_ctx = hidden_states.shape[1] + if last_encoder_hidden_states is not None: + raise TypeError("last_encoder_hidden_states should be None") + + query, key, value = hidden_states.chunk(3, dim=2) + if sample: + self.sample_t += curr_ctx + key, value = self._append_cache(key, value) + l_cache = self._suff_cache_len() + if self._cache_len() > l_cache: + self._slice_cache(-l_cache) + if curr_ctx > 1: + if self.attn_func != "dense_attn": + query = self._pad_to_block_ctx(query, query=True) + key = self._pad_to_block_ctx(key) + value = self._pad_to_block_ctx(value) + sample = False + else: + key = self.cache["key"] + value = self.cache["value"] + return query, key, value, sample + + def prime_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False): + curr_ctx = hidden_states.shape[1] + if last_encoder_hidden_states is not None: + raise TypeError("last_encoder_hidden_states should be None") + query, key, value = hidden_states.chunk(3, dim=2) + if sample: + if self._cache_len() < self._encoder_len: + self._append_cache(key, value) + if self._cache_len() > self._encoder_len: + self._slice_cache(0, self._encoder_len) + key, value = self.cache["key"], self.cache["value"] + self.sample_t += curr_ctx + return query, key, value, sample + + def decode_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False): + curr_ctx = hidden_states.shape[1] + query = hidden_states + if sample: + if self.sample_t == 0: + self.cache["key"], self.cache["value"] = self.c_enc_kv( + last_encoder_hidden_states.type_as(hidden_states) + ).chunk(2, dim=2) + key, value = self.cache["key"], self.cache["value"] + self.sample_t += curr_ctx + else: + key, value = self.c_enc_kv(last_encoder_hidden_states.type_as(hidden_states)).chunk(2, dim=2) + return query, key, value, sample + + def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False): + curr_ctx = hidden_states.shape[1] + hidden_states = self.c_attn(hidden_states) + query, key, value, sample = self.qkv( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample + ) + attention_scores = self.attn(query, key, value, sample) + if attention_scores.shape[1] != curr_ctx: + offset = self._offset(curr_ctx) + attention_scores = attention_scores[:, offset : offset + curr_ctx, :].contiguous() + attention_scores = self.c_proj(attention_scores) + return self.resid_dropout(attention_scores) + + @property + def _encoder_len(self): + encoder_len = self.encoder_len + encoder_blocks = (encoder_len // self.blocks) + 1 + return encoder_blocks * self.blocks + + def _offset(self, curr_ctx): + if self.attn_func == "dense_attn": + return 0 + return (self.sample_t - curr_ctx) % self.block_ctx + + def _pad_to_block_ctx(self, hidden_states, query=False): + seq_len = hidden_states.shape[1] + offset = self._offset(seq_len) if query else 0 + n_blocks = (seq_len + offset + self.block_ctx - 1) // self.block_ctx + pad = n_blocks * self.block_ctx - seq_len - offset + if pad == 0 and offset == 0: + return hidden_states + else: + return F.pad(hidden_states, (0, 0, offset, pad)) + + def _cache_len(self): + return 0 if "key" not in self.cache else self.cache["key"].shape[1] + + def _suff_cache_len(self): + """ + Precondition: + key and value are appended with the current context and self.sample_t reflects the 1-indexed sample + location in the context. + """ + previous_block_length = (self.sample_t - 1) % self.block_ctx + 1 + self.block_ctx + REQUIRED_CACHE_LEN = { + "dense_attn": self.sample_t, + "block_attn": (self.sample_t - 1) % self.block_ctx + 1, + "transpose_block_attn": self.sample_t, + "prev_block_attn": self.sample_t if self.sample_t <= self.block_ctx else previous_block_length, + "cross_attn": self.encoder_len, + "prime_attn": min(self.sample_t, self._encoder_len), + } + + return REQUIRED_CACHE_LEN[self.attn_func] + + def _slice_cache(self, start, end=None): + self.cache["key"] = self.cache["key"][:, start:end] + self.cache["value"] = self.cache["value"][:, start:end] + + def _append_cache(self, key, value): + if "key" not in self.cache: + self.cache["key"] = key + self.cache["value"] = value + else: + old_key, old_value = key, value + key = torch.cat([self.cache["key"], old_key], dim=1) + value = torch.cat([self.cache["value"], old_value], dim=1) + del self.cache["key"] + del self.cache["value"] + del old_key + del old_value + self.cache["key"] = key + self.cache["value"] = value + return self.cache["key"], self.cache["value"] + + def del_cache(self): + self.sample_t = 0 + if "key" in self.cache: + del self.cache["key"] + if "value" in self.cache: + del self.cache["value"] + self.cache = {} + + +class JukeboxBlock(nn.Module): + def __init__(self, config, n_ctx, attn_func="dense_attn"): + super().__init__() + self.width = config.hidden_size + self.attn = JukeboxAttention(config, n_ctx, attn_func=attn_func) + + self.layer_norm_0 = JukeboxLayerNorm(config.hidden_size) + self.mlp = JukeboxMLP(config) + self.layer_norm_1 = JukeboxLayerNorm(config.hidden_size) + self.res_scale = 1.0 / config.num_layers if config.attn_res_scale else 1.0 + self.attn_func = attn_func + + def forward(self, hidden_states, last_encoder_hidden_states, sample=False): + residuals = hidden_states + hidden_states = self.layer_norm_0(hidden_states) + hidden_states = self.attn(hidden_states, last_encoder_hidden_states, sample) + + output_states = self.layer_norm_1(residuals + hidden_states) + output_states = self.mlp(output_states) + if self.res_scale == 1.0: + output = residuals + hidden_states + output_states + else: + output = residuals + self.res_scale * (hidden_states + output_states) + return output + + +class JukeboxLayerStack(nn.Module): + def __init__(self, config, n_ctx): + super().__init__() + self.n_ctx = n_ctx + self.width = config.hidden_size + self.num_layers = config.num_layers + self.blocks = config.blocks + self.attention_pattern = config.attention_pattern + if self.blocks is not None: + self.block_ctx = n_ctx // self.blocks + self.encoder_len = config.nb_relevant_lyric_tokens + self.n_heads = config.n_heads + + # Orders of attn_func + attention_pattern = ATTENTION_PATTERNS[self.attention_pattern] + self._attn_mods = nn.ModuleList() + for depth in range(self.num_layers): + self._attn_mods.append(JukeboxBlock(config, n_ctx, attn_func=attention_pattern(depth))) + + self.saved_attn_weights = [] + + def set_record_attn(self, record_attn): + """ + Makes forward prop dump self-attention softmaxes to self.saved_attn_weights. + + Args: + record_attn (`Union[bool,set]`): + Either a set of layer indices indicating which layers to store, or a boolean value indicating Whether + to dump all. + """ + + def _should_record_attn(layer_idx): + if isinstance(record_attn, bool): + return record_attn + return layer_idx in record_attn + + for i, layer in enumerate(self._attn_mods): + layer.attn.record_attn = _should_record_attn(i) + + if not record_attn: + self.saved_attn_weights = [] + + def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False): + # Blocks + for i, attn_layer in enumerate(self._attn_mods): + if attn_layer.attn_func == "cross_attention": # attend to the lyrics + hidden_states = attn_layer( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample + ) + else: + hidden_states = attn_layer(hidden_states, last_encoder_hidden_states=None, sample=sample) + if attn_layer.attn.record_attn: + self.saved_attn_weights.append(attn_layer.attn.c_attn.weight) + return hidden_states + + def del_cache(self): + for attn_layer in self._attn_mods: + attn_layer.attn.del_cache() + + +class JukeboxPositionalEmbedding(nn.Module): + def __init__(self, embed_dim, width): + super().__init__() + self.pos_emb = nn.Parameter(torch.empty((embed_dim, width))) + + def forward(self): + pos_emb = self.pos_emb + return pos_emb + + +class JukeboxConditionalAutoregressive(nn.Module): + def __init__( + self, + config, + n_ctx=None, + embed_dim=None, + audio_conditioning=False, + metadata_conditioning=False, + is_encoder=False, + ): + """ + Autoregressive model on either lyric tokens or music tokens, or both. The attention pattern should be properly + set fro each configuration. + + Args: + config (`JukeboxPriorConfig`): + Model configuration class with all the parameters of the model. Initializing with a config file does + not load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. + n_ctx (`int`, *optional*): + Number of tokens or lyrics tokens provided in a single pass. + embed_dim (`int`, *optional*): + Either equals to the dimension of the codebook, or the sum of n_vocab (lyrics) and codeboook dimension, + if the model combines lyrics and music tokens, or simply n_vocab if the model is a seperate encoder + audio_conditioning (`bool`, *optional*, defaults to `False`): + Whether or not the prior supports conditionning on audio. + metadata_conditioning (`bool`, *optional*, defaults to `False`): + Whether or not the prior supports conditionning on artitst, genres, lyrics and timing. + is_encoder (`bool`, *optional*, defaults to `False`): + Whether the model is an encoder only model. + """ + + super().__init__() + self.width = config.hidden_size + self.num_layers = config.num_layers + self.n_ctx = n_ctx if n_ctx is not None else config.n_ctx + self.embed_dim = embed_dim if embed_dim is not None else config.music_vocab_size + self.embed_tokens = nn.Embedding(self.embed_dim, config.hidden_size) + self.embed_tokens_dropout = nn.Dropout(config.emb_dropout) + self.metadata_conditioning = metadata_conditioning + self.audio_conditioning = audio_conditioning + if not metadata_conditioning: + self.start_token = nn.Parameter(torch.empty((1, config.hidden_size))) + self.pos_emb = JukeboxPositionalEmbedding(self.n_ctx, config.hidden_size) + self.pos_emb_dropout = nn.Dropout(config.emb_dropout) + + self.transformer = JukeboxLayerStack(config, n_ctx=self.n_ctx) + self.is_encoder = is_encoder + self.encoder_len = config.nb_relevant_lyric_tokens + + if config.merged_decoder: + # Merged piped model uses this setup + self.add_cond_after_transformer = False + self.share_embed_tokens_fc_proj_out = False + else: + self.add_cond_after_transformer = True + self.share_embed_tokens_fc_proj_out = True + + if not is_encoder: + self.fc_proj_out = nn.Linear(config.hidden_size, self.embed_dim, bias=False) + if self.share_embed_tokens_fc_proj_out: + self.fc_proj_out.weight = self.embed_tokens.weight + self.loss = torch.nn.CrossEntropyLoss() + + def forward( + self, + tokens, + audio_conditioning=None, + metadata_conditioning=None, + last_encoder_hidden_states=None, + get_preds=False, + get_acts=False, + get_sep_loss=False, + ): + """ + Args: + tokens (`torch.tensor`): + Can represent music tokens, lyrics tokens or both, depending on the configuration. + """ + # Preprocess. + batch_size = tokens.shape[0] + with torch.no_grad(): + tokens = tokens.view(batch_size, -1).long() + + if not self.audio_conditioning: + audio_conditioning = torch.zeros( + (batch_size, 1, self.width), + device=tokens.device, + dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype, + ) + + target = tokens # Target + hidden_states = self.embed_tokens(tokens) + # Shift by 1, and fill in start token + hidden_states = torch.cat((hidden_states[:, -1:], hidden_states[:, :-1]), dim=1) + if self.metadata_conditioning: + hidden_states[:, 0] = metadata_conditioning.view(batch_size, self.width) + else: + hidden_states[:, 0] = self.start_token + + hidden_states = ( + self.embed_tokens_dropout(hidden_states) + self.pos_emb_dropout(self.pos_emb()) + audio_conditioning + ) # Pos emb and dropout + + hidden_states = self.transformer( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states + ) # Transformer + if self.add_cond_after_transformer: # Piped doesnt add x_cond + hidden_states = hidden_states + audio_conditioning + + activations = hidden_states + if self.is_encoder: + return hidden_states + + hidden_states = self.fc_proj_out(hidden_states) # Predictions + loss_fn = nn.CrossEntropyLoss() + if get_sep_loss: + lyric_hidden_states = hidden_states[:, : self.encoder_len].reshape(-1, self.embed_dim) + token_hidden_states = hidden_states[:, self.encoder_len :].reshape(-1, self.embed_dim) + + lyric_loss = loss_fn(lyric_hidden_states, target[:, : self.encoder_len].reshape(-1)) / np.log(2.0) + music_token_loss = loss_fn(token_hidden_states, target[:, self.encoder_len :].reshape(-1)) / np.log(2.0) + + loss = (lyric_loss, music_token_loss) # Note order! Lyric is first + else: + loss = loss_fn(hidden_states.view(-1, self.embed_dim), target.view(-1)) / np.log(2.0) # Loss + + if get_preds: + return loss, hidden_states + elif get_acts: + return loss, activations + else: + return loss, None + + def get_emb(self, sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning): + if sample_t == 0: + hidden_states = torch.empty(n_samples, 1, self.width, dtype=self.embed_tokens.weight.dtype).to( + self.embed_tokens.weight.device + ) + if self.metadata_conditioning: + hidden_states[:, 0] = metadata_conditioning.view(n_samples, self.width) + else: + hidden_states[:, 0] = self.start_token + else: + hidden_states = self.embed_tokens(tokens) + if audio_conditioning.shape == (n_samples, self.n_ctx, self.width): + cond = audio_conditioning[:, sample_t : sample_t + 1, :] + else: + cond = audio_conditioning + # Pos emb, dropout is identity at eval time + hidden_states = hidden_states + self.pos_emb()[sample_t : sample_t + 1] + cond + return hidden_states, cond + + def sample( + self, + n_samples, + audio_conditioning=None, + metadata_conditioning=None, + last_encoder_hidden_states=None, + temp=1.0, + top_k=0, + top_p=0.0, + get_preds=False, + sample_tokens=None, + ): + if sample_tokens is None: + sample_tokens = self.n_ctx + + if not self.audio_conditioning: + audio_conditioning = torch.zeros( + (n_samples, 1, self.width), dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype + ).to(self.fc_proj_out.device) + + with torch.no_grad(): + sampled_tokens = [] + tokens = None + if get_preds: + preds = [] + + iter = tqdm(range(0, sample_tokens), leave=False) + for sample_t in iter: + iter.set_description(f"Ancestral sampling {sample_tokens} music tokens", refresh=True) + hidden_states, cond = self.get_emb( + sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning + ) + + hidden_states = self.transformer( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=True + ) + if self.add_cond_after_transformer: + hidden_states = hidden_states + cond + hidden_states = self.fc_proj_out(hidden_states) # Predictions + if get_preds: + preds.append(hidden_states.clone()) + # Adjust logits + hidden_states = hidden_states / temp + hidden_states = filter_logits(hidden_states, top_k=top_k, top_p=top_p) + # Sample and replace hidden_states + tokens = torch.distributions.Categorical(logits=hidden_states).sample() + sampled_tokens.append(tokens.clone()) + + del tokens + self.transformer.del_cache() + + tokens = torch.cat(sampled_tokens, dim=1) + if get_preds: + preds = torch.cat(preds, dim=1) + if get_preds: + return tokens, preds + else: + return tokens + + def split_chunks(self, length, chunk_size): + n_passes = (length + chunk_size - 1) // chunk_size + chunk_sizes = [*[chunk_size] * (n_passes - 1), (length - 1) % chunk_size + 1] + return chunk_sizes + + def primed_sample( + self, + n_samples, + lyric_and_music_tokens, + audio_conditioning=None, + metadata_conditioning=None, + last_encoder_hidden_states=None, + temp=1.0, + top_k=0, + top_p=0.0, + get_preds=False, + chunk_size=None, + sample_tokens=None, + ): + if sample_tokens is None: + sample_tokens = self.n_ctx + # Preprocess. + batch_size = lyric_and_music_tokens.shape[0] + with torch.no_grad(): + lyric_and_music_tokens = lyric_and_music_tokens.view(batch_size, -1).long() + + sampled_audio = torch.split(lyric_and_music_tokens, 1, dim=1) + sampled_audio = list(sampled_audio) + + if not self.audio_conditioning: + audio_conditioning = torch.zeros( + (n_samples, 1, self.width), dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype + ).to(lyric_and_music_tokens.device) + + with torch.no_grad(): + if get_preds: + preds = [] + + # Fill up key/value cache for past context by runing forward pass. + # We do so in chunks instead of doing the whole past in one forward pass to reduce max memory usage. + if chunk_size is None: + chunk_size = len(sampled_audio) + chunk_sizes = self.split_chunks(len(sampled_audio), chunk_size) + x_primes = [] + start = 0 + token = None + + for current_chunk_size in tqdm(chunk_sizes, desc="Preparing past key value", leave=False): + sampled_audio_prime, conds_prime = [], [] + for sample_t in range(start, start + current_chunk_size): + x_prime, cond_prime = self.get_emb( + sample_t, n_samples, token, audio_conditioning, metadata_conditioning + ) + token = sampled_audio[sample_t] + sampled_audio_prime.append(x_prime) + conds_prime.append(cond_prime) + start = start + current_chunk_size + x_prime, cond_prime = torch.cat(sampled_audio_prime, dim=1), torch.cat(conds_prime, dim=1) + del sampled_audio_prime + del conds_prime + if not get_preds: + del cond_prime + x_prime = self.transformer(x_prime, last_encoder_hidden_states=last_encoder_hidden_states, sample=True) + + if get_preds: + if self.add_cond_after_transformer: + x_prime = x_prime + cond_prime + del cond_prime + x_primes.append(x_prime) + else: + del x_prime + + if get_preds: + x_prime = torch.cat(x_primes, dim=1) + x_prime = self.fc_proj_out(x_prime) # Predictions + preds.append(x_prime) + + # the input of the encoder and decoder can be merged into (lyrics, music tokens) + input_tokens = sampled_audio[-1] + + itererator = tqdm( + range(len(sampled_audio), sample_tokens), + desc=f"Sampling {len(range(len(sampled_audio), sample_tokens))} music tokens", + leave=False, + ) + for sample_t in itererator: + hidden_states, cond = self.get_emb( + sample_t, n_samples, input_tokens, audio_conditioning, metadata_conditioning + ) + + hidden_states = self.transformer( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=True + ) + if self.add_cond_after_transformer: + hidden_states = hidden_states + cond + hidden_states = self.fc_proj_out(hidden_states) # Predictions + if get_preds: + preds.append(hidden_states) + # Adjust logits + hidden_states = hidden_states / temp + hidden_states = filter_logits(hidden_states, top_k=top_k, top_p=top_p) + # only music tokens are sampled + music_tokens = torch.distributions.Categorical(logits=hidden_states).sample() + sampled_audio.append(music_tokens.clone()) + input_tokens = music_tokens + + del input_tokens, music_tokens + self.transformer.del_cache() + + music_tokens = torch.cat(sampled_audio, dim=1) + if get_preds: + preds = torch.cat(preds, dim=1) + if get_preds: + return music_tokens, preds + else: + return music_tokens + + +class JukeboxMusicTokenConditioner(nn.Module): + """ + The `JukeboxMusicTokenConditioner` takes music tokens as an input (coresponding to the codes of the VQVAE's + codebook) and upsamples it using a single layer of decoder convolution block (the same is used in the VQVAE). + """ + + def __init__(self, config, level): + + super().__init__() + self.embed_tokens = nn.Embedding(config.music_vocab_size, config.hidden_size) + config.embed_dim = config.music_vocab_size # setting correct argument for the `JukeboxDecoder` + + self.upsampler = JukeboxDecoderConvBock( + config, + config.hidden_size, + config.res_conv_width, + config.res_conv_depth, + config.res_downs_t[level], + config.res_strides_t[level], + reverse_dilation=False, + ) + self.layer_norm = JukeboxLayerNorm(config.hidden_size) + + def forward(self, music_tokens, raw_audio_conditionning=None): + """ + Args: + music_tokens (`torch.LongTensor`): + Music tokens form the uper level in range(nb_discrete_codes) + raw_audio_conditionning (`torch.LongTensor`, *optional*): + Audio used when primed sampling, raw audio information that conditions the generation + """ + if raw_audio_conditionning is None: + raw_audio_conditionning = 0.0 + # Embed music_tokens + music_tokens = music_tokens.long() + hidden_states = self.embed_tokens(music_tokens) + hidden_states = hidden_states + raw_audio_conditionning + + # Run conditioner + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class JukeboxRangeEmbedding(nn.Module): + """ + The `JukeboxRangeEmbedding` interpolate the given [pos_start, pos_end] to obtain an equivalent of time positional + embedding of length `n_ctx`. + + Binning process : For each pos in position tensor, find its bin [start,end) mapped to [0,1,...,bins-1] [start,end) + -> [0,1) -> [0, bins) -> floor -> [0,...,bins-1] NOTE: Open ended interval on right, so start <= pos < end, not <= + end + """ + + def __init__(self, n_time, embed_dim, range, out_width, clamp=False): + super().__init__() + self.n_time = n_time + self.embed_dim = embed_dim + self.emb = nn.Embedding(embed_dim, out_width) + self.pos_min, self.pos_max = range + self.clamp = clamp + + def forward(self, pos_start, pos_end=None): + # Check if [pos_start,pos_end] in [pos_min, pos_max) + if not len(pos_start.shape) == 2: + raise TypeError(f"Expected shape with 2 dims, got {pos_start.shape}") + if not (self.pos_min <= pos_start).all() and (pos_start < self.pos_max).all(): + raise TypeError(f"Range is [{self.pos_min},{self.pos_max}), got {pos_start}") + + pos_start = pos_start.float() + if pos_end is not None: + if self.clamp: + pos_end = pos_end.clamp(self.pos_min, self.pos_max) + + pos_end = pos_end.float() + # Interpolate so that [pos_start, ..., pos_end] <-> position tensor of length n_ctx + n_time = self.n_time + if n_time != 1: + interpolation = ( + torch.arange(0, n_time, dtype=torch.float, device=pos_start.device).view(1, n_time) / n_time + ) + position = pos_start + (pos_end - pos_start) * interpolation + else: + position = pos_start + + # Bin each value to bins_ + # [0,1) -> [0,1..,embed_dim) -> [0,1...,embed_dim-1 + normalised_position = (position - self.pos_min) / (self.pos_max - self.pos_min) + bins_ = (self.embed_dim * normalised_position).floor().long().detach() + return self.emb(bins_) + + +class JukeboxLabelConditioner(nn.Module): + def __init__(self, config, include_time_signal): + super().__init__() + + embed_dim = config.hidden_size + timing_dims = config.timing_dims + sampling_rate = config.sampling_rate + nb_genres, nb_artists = config.metadata_dims + music_tokens_shape = config.n_ctx + + self.max_nb_genres = config.max_nb_genres + self.bow_genre_emb = nn.Embedding(nb_genres, embed_dim) + self.artist_emb = nn.Embedding(nb_artists, embed_dim) + self.include_time_signal = include_time_signal + if self.include_time_signal: + total_length_range = (config.min_duration * sampling_rate, config.max_duration * sampling_rate) + absolute_pos_range = (0.0, config.max_duration * sampling_rate) + relative_pos_range = (0.0, 1.0) + self.total_length_emb = JukeboxRangeEmbedding(1, timing_dims, total_length_range, embed_dim) + self.absolute_pos_emb = JukeboxRangeEmbedding( + music_tokens_shape, timing_dims, absolute_pos_range, embed_dim + ) + self.relative_pos_emb = JukeboxRangeEmbedding( + music_tokens_shape, timing_dims, relative_pos_range, embed_dim, clamp=True + ) + + def forward(self, metadata): + total_length = metadata[:, 0:1] + offset = metadata[:, 1:2] + length = metadata[:, 2:3] + artist = metadata[:, 3:4] + genre = metadata[:, 4:] + + # Start embedding of length 1 + artist_emb = self.artist_emb(artist) + # Empty genre slots are denoted by -1. We mask these out. + mask = (genre >= 0).float().unsqueeze(2) + genre_emb = (self.bow_genre_emb(genre.clamp(0)) * mask).sum(dim=1, keepdim=True) + start_emb = genre_emb + artist_emb + + # Pos embedding of length n_ctx + if self.include_time_signal: + start, end = offset, offset + length + total_length = total_length.float() + start = start.float() + end = end.float() + pos_emb = ( + self.total_length_emb(total_length) + + self.absolute_pos_emb(start, end) + + self.relative_pos_emb(start / total_length, end / total_length) + ) + else: + pos_emb = None + return start_emb, pos_emb + + +class JukeboxPrior(PreTrainedModel): + """ + The JukeboxPrior class, which is a wrapper around the various conditioning and the transformer. JukeboxPrior can be + seen as language models trained on music. They model the next `music token` prediction task. If a (lyric) `encoderù + is defined, it also models the `next character` prediction on the lyrics. Can be conditionned on timing, artist, + genre, lyrics and codes from lower-levels Priors. + + Args: + config (`JukeboxPriorConfig`): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. + level (`int`, *optional*): + Current level of the Prior. Should be in range `[0,nb_priors]`. + nb_priors (`int`, *optional*, defaults to 3): + Total number of priors. + vqvae_encoder (`Callable`, *optional*): + Encoding method of the VQVAE encoder used in the forward pass of the model. Passing functions instead of + the vqvae module to avoid getting the parameters. + vqvae_decoder (`Callable`, *optional*): + Decoding method of the VQVAE decoder used in the forward pass of the model. Passing functions instead of + the vqvae module to avoid getting the parameters. + """ + + config_class = JukeboxPriorConfig + _keys_to_ignore_on_load_unexpected = ["vqvae"] + + def _init_weights(self, module): + init_scale = self.config.init_scale + + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=0.02 * init_scale) + elif isinstance(module, JukeboxConv1D): + if self.config.zero_out: + module.weight.data.zero_() + else: + module.weight.data.normal_(mean=0.0, std=0.02 * init_scale) + elif isinstance(module, JukeboxPositionalEmbedding): + module.pos_emb.data.normal_(mean=0.0, std=0.01 * init_scale) + elif isinstance(module, JukeboxRangeEmbedding): + module.emb.weight.data.normal_(mean=0.0, std=0.01 * init_scale) + elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "lm_head"): + module.lm_head.weight.data.normal_(mean=0.0, std=0.02 * init_scale) + elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "start_token"): + module.start_token.data.normal_(mean=0.0, std=0.01 * init_scale) + elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out: + module.conv1d_2.weigth.data.zero_() + module.conv1d_2.bias.data.zero_() + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def __init__(self, config: JukeboxPriorConfig, level=None, nb_priors=3, vqvae_encoder=None, vqvae_decoder=None): + super().__init__(config) + # Passing functions instead of the vqvae module to avoid getting params, only used in the + # forward loop + self.vqvae_encoder = vqvae_encoder + self.vqvae_decoder = vqvae_decoder + + self.levels = nb_priors + self.level = level if level is not None else config.level + + self.base_model_prefix = f"priors.{self.level}" + self._keys_to_ignore_on_load_unexpected += [r"priors.[^%d]." % self.level] + + self.n_ctx = config.n_ctx + + self.lyric_conditioning = config.nb_relevant_lyric_tokens > 0 + self.nb_relevant_lyric_tokens = config.nb_relevant_lyric_tokens + self.encoder_loss_fraction = config.encoder_loss_fraction + + # Audio conditioning : conditioning on music tokens (either from audio or from previous levels or both) + self.audio_conditioning = self.level != 0 + self.cond_level = self.level - 1 + if self.audio_conditioning: + self.conditioner_blocks = JukeboxMusicTokenConditioner(config, self.level) + + # metadata conditioning : contioning on timing, genres, and artist + self.metadata_conditioning = config.metadata_conditioning + if self.metadata_conditioning: + self.metadata_embedding = JukeboxLabelConditioner(config, include_time_signal=not self.audio_conditioning) + + # define encoder-decoder or encoder and decoder + self.is_encoder_decoder = config.is_encoder_decoder + if config.is_encoder_decoder: + # encoder-decoder transformer + self.input_shapes = [config.nb_relevant_lyric_tokens, config.n_ctx] + self.embed_dim_shift = [0, config.lyric_vocab_size] + self.width = config.hidden_size + + self.nb_relevant_lyric_tokens = config.nb_relevant_lyric_tokens + + self.prior = JukeboxConditionalAutoregressive( + config, + n_ctx=config.nb_relevant_lyric_tokens + config.n_ctx, + embed_dim=config.lyric_vocab_size + config.music_vocab_size, + audio_conditioning=(self.audio_conditioning or self.metadata_conditioning), + metadata_conditioning=True, + ) + + else: + # Separate encoder-decoder transformer + encoder_config = config.encoder_config + + if self.nb_relevant_lyric_tokens != 0 and self.lyric_conditioning: + self.lyric_acts_width = encoder_config.hidden_size + self.encoder_width = config.hidden_size + self.encoder_dim = config.lyric_vocab_size + self.encoder = JukeboxConditionalAutoregressive( + encoder_config, + n_ctx=self.nb_relevant_lyric_tokens, + embed_dim=self.encoder_dim, + audio_conditioning=False, + metadata_conditioning=False, + is_encoder=True, + ) + self.encoder.proj_in = JukeboxConv1D(encoder_config.hidden_size, config.hidden_size) + self.encoder.final_layer_norm = JukeboxLayerNorm(config.hidden_size) + self.encoder.lm_head = nn.Linear(config.hidden_size, config.lyric_vocab_size, bias=False) + else: + self.nb_relevant_lyric_tokens = 0 + + # decoder model on the tokens + self.prior = JukeboxConditionalAutoregressive( + config, + audio_conditioning=(self.audio_conditioning or self.metadata_conditioning), + metadata_conditioning=self.metadata_conditioning, + ) + + self.next_token_prediction_loss_dims = config.n_ctx + self.total_loss_dims = self.nb_relevant_lyric_tokens + self.next_token_prediction_loss_dims + + self.downsamples = [stride**down for stride, down in zip(config.res_strides_t, config.res_downs_t)] + self.cond_downsample = self.downsamples[self.level] if self.level != 0 else None + self.raw_to_tokens = np.prod(self.downsamples[: nb_priors - self.level]) + self.sample_length = self.n_ctx * self.raw_to_tokens + + logger.info( + f"Level:{self.level}, Cond downsample:{self.cond_downsample}, Raw to tokens:{self.raw_to_tokens}, Sample" + f" length:{self.sample_length}" + ) + + def get_metadata(self, labels, start, total_length, offset, get_indices=False): + metadata = labels.clone() + metadata[:, 0] = total_length + # Set sample_length to match this level + metadata[:, 2] = int(self.sample_length) + + # Set offset + metadata[:, 1:2] = int(offset * self.raw_to_tokens) + int(start * self.raw_to_tokens) + # here since metadata has the full token_list, we just need to selected the ones that are relevant + + # Set lyric tokens + metadata, indices = self.set_metadata_lyric_tokens(metadata) + if get_indices: + return metadata, indices + else: + return metadata + + def set_metadata_lyric_tokens(self, labels): + """ + Processes the full labels to only retreive the relevant lyric tokens and keep the metadata conditioning tokens. + """ + if self.nb_relevant_lyric_tokens > 0: + tokens_list = torch.zeros( + (labels.shape[0], self.nb_relevant_lyric_tokens), dtype=torch.long, device=labels.device + ) + indices_list = [] # whats the index of each current character in original array + for idx in range(labels.shape[0]): + full_tokens = labels.clone()[:, 4 + self.metadata_embedding.max_nb_genres :] + total_length, offset, duration = labels[idx, 0], labels[idx, 1], labels[idx, 2] + tokens, indices = get_relevant_lyric_tokens( + full_tokens, self.nb_relevant_lyric_tokens, total_length, offset, duration + ) + tokens_list[idx, :] = tokens + indices_list.append(indices) + + return ( + torch.cat((labels[:, : 4 + self.metadata_embedding.max_nb_genres], tokens_list), dim=-1), + indices_list, + ) + else: + return labels, None + + def get_music_tokens_conds(self, music_tokens, start, end): + """ + Extracts current level's conditioning music tokens. + """ + if self.level != 0: + music_tokens_cond = music_tokens[self.level - 1] + music_tokens = music_tokens_cond[:, start // self.cond_downsample : end // self.cond_downsample] + missing_cond_len = self.n_ctx // self.cond_downsample - music_tokens_cond[-1].shape[-1] + if missing_cond_len > 0: + init_cond = torch.zeros(1, missing_cond_len).to(music_tokens_cond.device) + music_tokens_cond = torch.cat((music_tokens_cond, init_cond), dim=-1).long() + music_tokens_conds = [music_tokens_cond] + else: + music_tokens_conds = None + return music_tokens_conds + + def prior_preprocess(self, tokens, conds): + """ + Shifts the input tokens to account for the dictionnary merge. The embed_dim_shift give by how much the music + tokens should be shifted by. It is equal to `lyric_vocab_size`. + """ + batch_size = tokens[0].shape[0] + for i in range(len(tokens)): + tokens[i] = (tokens[i] + int(self.embed_dim_shift[i])).view(batch_size, -1) + + for i in range(len(conds)): + if conds[i] is None: + conds[i] = torch.zeros( + (batch_size, self.input_shapes[i], self.width), dtype=tokens[0].dtype, device=tokens[0].device + ) + + return torch.cat(tokens, dim=1), torch.cat(conds, dim=1) + + def prior_postprocess(self, tokens): + """ + Shifts back the input tokens if the model uses an encoder decoder architecture. As the embedding layer is + shared, `prior_embed_dim_shift` shifts the music token ids by `lyric_vocab_size`. Only returns the music + tokens. + """ + batch_size = tokens.shape[0] + dims = (self.input_shapes[0], tokens.shape[1] - self.input_shapes[0]) + tokens = list(torch.split(tokens, dims, dim=1)) + + # Some of the input tokens might be shifted to take into account the voccabulary fusion + for i in range(len(tokens)): + bins_shift = int(self.embed_dim_shift[i]) + tokens[i] = (tokens[i] - bins_shift).view(batch_size, -1) + tokens[i] = torch.clamp(tokens[i], min=0) + # If not masking loss, model may have generated lyric/midi tokens which are now shifted <0 by bin_shift + return tokens[-1] + + def embed_tokens(self, music_tokens_conds): + """ + Embeds the upper level music tokens and upsamples them to provide as audio conditioning. + """ + music_tokens_conds = music_tokens_conds[: self.cond_level + 1] + audio_conditioning = None + for music_tokens_cond, conditioner_block in reversed(list(zip(music_tokens_conds, [self.conditioner_blocks]))): + audio_conditioning = conditioner_block(music_tokens_cond, audio_conditioning) + return audio_conditioning + + def encode(self, hidden_states, start_level=None, end_level=None, bs_chunks=1): + """ + Encodes the hidden states (raw audio) using the VQVAE's encoder. Returns latent_states. + """ + if start_level is None: + start_level = self.level + if end_level is None: + end_level = self.levels + # Get latents + with torch.no_grad(): + latent_states = self.vqvae_encoder( + hidden_states, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks + ) + return latent_states + + def decode(self, music_tokens, start_level=None, end_level=None, bs_chunks=1): + """ + Usamples the sequence of codebook vectors to a raw audio. + """ + if start_level is None: + start_level = self.level + if end_level is None: + end_level = self.levels + with torch.no_grad(): + output = self.vqvae_decoder( + music_tokens, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks + ) + return output + + def get_cond(self, music_tokens_conds, metadata): + """ + Converts the input tokens to input_embeddings. Splits the lyrics form the rest of the metadata. Lyric tokens + can be None. + """ + if metadata is not None: + n_labels = metadata.shape[1] - self.nb_relevant_lyric_tokens + metadata, lyric_tokens = metadata[:, :n_labels], metadata[:, n_labels:] + else: + metadata, lyric_tokens = None, None + metadata_conditioning, metadata_pos = ( + self.metadata_embedding(metadata) if self.metadata_conditioning else (None, None) + ) + audio_conditioning = self.embed_tokens(music_tokens_conds) if self.audio_conditioning else metadata_pos + return audio_conditioning, metadata_conditioning, lyric_tokens + + def sample( + self, + n_samples, + music_tokens=None, + music_tokens_conds=None, + metadata=None, + temp=1.0, + top_k=0, + top_p=0.0, + chunk_size=None, + sample_tokens=None, + ): + """ + Ancestral/Prime sampling a window of tokens using the provided conditioning and metadatas. + + Args: + n_samples (`int`): + Number of samples to generate. + music_tokens (`List[torch.LongTensor]`, *optional*): + Previously gemerated tokens at the current level. Used as context for the generation. + music_tokens_conds (`List[torch.FloatTensor]`, *optional*): + Upper-level music tokens generated by the previous prior model. Is `None` if the generation is not + conditionned on the upper-level tokens. + metadata (`List[torch.LongTensor]`, *optional*): + List containing the metatdata tensor with the artist, genre and the lyric tokens. + temp (`float`, *optional*, defaults to 1.0): + Sampling temperature. + top_k (`int`, *optional*, defaults to 0): + Top k probabilities used for filtering. + top_p (`float`, *optional*, defaults to 0.0): + Top p probabilities used for filtering. + chunk_size (`int`, *optional*): + Size of the chunks used to prepare the cache of the transformer. + sample_tokens (`int`, *optional*): + Number of tokens to sample. + + """ + no_past_context = music_tokens is None or music_tokens.shape[1] == 0 + name = {True: "Ancestral", False: "Primed"}[no_past_context] + logger.info(f"{name} sampling {n_samples} samples with temp={temp}, top_k={top_k}, top_p={top_p}") + + with torch.no_grad(): + # Currently audio_conditioning only uses immediately above layer + audio_conditioning, metadata_conditioning, lyric_tokens = self.get_cond(music_tokens_conds, metadata) + if self.is_encoder_decoder: + if no_past_context: # the prime_sample function will be used with music_tokens set to None + lyric_and_music_tokens, audio_conditioning = self.prior_preprocess( + [lyric_tokens], [None, audio_conditioning] + ) + else: + lyric_and_music_tokens, audio_conditioning = self.prior_preprocess( + [lyric_tokens, music_tokens], [None, audio_conditioning] + ) + if sample_tokens is not None: + sample_tokens += self.nb_relevant_lyric_tokens + music_tokens = self.prior.primed_sample( + n_samples, + lyric_and_music_tokens, + audio_conditioning, + metadata_conditioning, + temp=temp, + top_k=top_k, + top_p=top_p, + chunk_size=chunk_size, + sample_tokens=sample_tokens, + ) + music_tokens = self.prior_postprocess(music_tokens) + else: + last_encoder_hidden_states = self.get_encoder_states(lyric_tokens, sample=True) + if no_past_context: + music_tokens = self.prior.sample( + n_samples, + audio_conditioning, + metadata_conditioning, + last_encoder_hidden_states, + temp=temp, + top_k=top_k, + top_p=top_p, + sample_tokens=sample_tokens, + ) + else: + music_tokens = self.prior.primed_sample( + n_samples, + music_tokens, + audio_conditioning, + metadata_conditioning, + last_encoder_hidden_states, + temp=temp, + top_k=top_k, + top_p=top_p, + chunk_size=chunk_size, + sample_tokens=sample_tokens, + ) + return music_tokens + + def get_encoder_states(self, lyric_tokens, sample=False): + """ + Retreive the last hidden_states of the lyric encoder that will be attended to by the decoder. Forwards through + the lyric encoder. + """ + if self.nb_relevant_lyric_tokens != 0 and self.lyric_conditioning: + if sample: + self.encoder = self.encoder.to(lyric_tokens.device) + lyric_acts = self.encoder(lyric_tokens, None, None, None) + lyric_acts = self.encoder.proj_in(lyric_acts) + last_encoder_hidden_states = self.encoder.final_layer_norm(lyric_acts) + else: + last_encoder_hidden_states = None + return last_encoder_hidden_states + + def get_encoder_loss(self, last_encoder_hidden_states, target_lyrics): + """ + Computes the loss for the lyric encoder: next lyric token prediction. + """ + if self.lyric_conditioning: + last_encoder_hidden_states = self.encoder.lm_head(last_encoder_hidden_states) + encoder_loss = nn.functional.cross_entropy( + last_encoder_hidden_states.view(-1, self.encoder_dim), target_lyrics.view(-1) + ) / np.log(2.0) + else: + encoder_loss = torch.tensor(0.0, device=last_encoder_hidden_states.device) + return encoder_loss + + def forward_tokens( + self, music_tokens, music_tokens_conds=[], metadata=None, get_preds=False, get_attn_weights=False + ): + """ + Applies a forward pass using the conditioning tokens. Different from the classic forward as it does not use the + vqvae's encoding layers. + """ + if get_attn_weights: + self.prior.transformer.set_record_attn(get_attn_weights) + audio_conditioning, metadata_conditioning, lyric_tokens = self.get_cond(music_tokens_conds, metadata) + + if self.is_encoder_decoder: # the preprocess returns the full tokens (Lyrics and Music tokens), shifted + tokens, audio_conditioning = self.prior_preprocess( + [lyric_tokens, music_tokens], [None, audio_conditioning] + ) + (encoder_loss, next_token_prediction_loss), preds = self.prior( + tokens, audio_conditioning, metadata_conditioning, get_sep_loss=True, get_preds=get_preds + ) + else: + last_encoder_hidden_states = self.get_encoder_states(lyric_tokens) + encoder_loss = self.get_encoder_loss(last_encoder_hidden_states, lyric_tokens) + next_token_prediction_loss, preds = self.prior( + music_tokens, + audio_conditioning, + metadata_conditioning, + last_encoder_hidden_states, + get_preds=get_preds, + ) + loss = self.encoder_loss_fraction * encoder_loss * self.nb_relevant_lyric_tokens / self.total_loss_dims + loss += next_token_prediction_loss * self.next_token_prediction_loss_dims / self.total_loss_dims + + metrics = dict( + bpd=next_token_prediction_loss.clone().detach(), + encoder_loss=encoder_loss.clone().detach(), + next_token_prediction_loss=next_token_prediction_loss.clone().detach(), + ) + if get_preds: + metrics["preds"] = preds.clone().detach() + if get_attn_weights: + saved_attn_weights = self.prior.transformer.saved_attn_weights + self.prior.transformer.set_record_attn(False) + return saved_attn_weights + else: + return loss, metrics + + def forward(self, hidden_states, metadata=None, decode=False, get_preds=False): + """ + Encode the hidden states using the `vqvae` encoder, and then predicts the next token in the `forward_tokens` + function. The loss is the sum of the `encoder` loss and the `decoder` loss. + + Args: + hidden_states (`torch.Tensor`): + Hidden states which should be raw audio + metadata (`List[torch.LongTensor]`, *optional*): + List containing the metadata conditioning tensorwith the lyric and the metadata tokens. + decode (`bool`, *optional*, defaults to `False`): + Whether or not to decode the encoded to tokens. + get_preds (`bool`, *optional*, defaults to `False`): + Whether or not to return the actual predicitons of the model. + """ + batch_size = hidden_states.shape[0] + music_tokens, *music_tokens_conds = self.encode(hidden_states, bs_chunks=batch_size) + loss, metrics = self.forward_tokens( + music_tokens=music_tokens, + music_tokens_conds=music_tokens_conds, + metadata=metadata, + get_preds=get_preds, + ) + if decode: + dequantised_states = self.decode([music_tokens, *music_tokens_conds]) + else: + dequantised_states = None + return dequantised_states, loss, metrics + + +class JukeboxPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = JukeboxConfig + base_model_prefix = "jukebox" + supports_gradient_checkpointing = False + + def _init_weights(self, module): + if isinstance(module, JukeboxPrior) or isinstance(module, JukeboxVQVAE): + module.apply(module._init_weights) + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + +JUKEBOX_SAMPLING_INPUT_DOCSTRING = r""" + labels (`List[torch.LongTensor]` of length `n_sample`, and shape `(self.levels, self.config.max_nb_genre + lyric_sequence_length)` : + List of metadata such as `artist_id`, `genre_id` and the full list of lyric tokens which are used to + condition the generation. + sampling_kwargs (`Dict[Any]`): + Various additional sampling arguments that are used by the `_sample` function. A detail list of the + arguments can bee seen in the [`_sample`] function documentation. +""" + + +@add_start_docstrings( + """The bare JUKEBOX Model used for music generation. 4 sampling techniques are supported : `primed_sample`, `upsample`, + `continue_sample` and `ancestral_sample`. It does not have a `forward` method as the training is not end to end. If + you want to fine-tune the model, it is recommended to use the `JukeboxPrior` class and train each prior + individually. + """, + JUKEBOX_START_DOCSTRING, +) +class JukeboxModel(JukeboxPreTrainedModel): + _no_split_modules = ["JukeboxBlock"] + + def __init__(self, config): + super().__init__(config) + vqvae_config = config.vqvae_config + self.vqvae = JukeboxVQVAE(vqvae_config) + self.set_shared_params(config) + self.priors = nn.ModuleList( + [JukeboxPrior(config.prior_configs[level], level) for level in range(config.nb_priors)] + ) + + def set_shared_params(self, model_config): + """ + Initialises the parameters that are shared. This has to be done here because the list of `JukeboxPriorConfig` + is nest, and is thus unreachable in the `from_dict` function + """ + for config in model_config.prior_configs: + config.sampling_rate = model_config.sampling_rate + config.timing_dims = model_config.timing_dims + config.min_duration = model_config.min_duration + config.max_duration = model_config.max_duration + config.max_nb_genres = model_config.max_nb_genres + config.metadata_conditioning = model_config.metadata_conditioning + + def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1): + return self.vqvae.decode(music_tokens, start_level, end_level, bs_chunks) + + def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1): + return self.vqvae.encode(input_audio, start_level, end_level, bs_chunks) + + def split_batch(self, obj, n_samples, split_size): + n_passes = (n_samples + split_size - 1) // split_size + if isinstance(obj, torch.Tensor): + return torch.split(obj, split_size, dim=0) + elif isinstance(obj, list): + return list(zip(*[torch.split(item, split_size, dim=0) for item in obj])) + elif obj is None: + return [None] * n_passes + else: + raise TypeError("Unknown input type") + + # Sample a partial window of length= self.priors[level].n_ctx: + iterator = get_starts(total_length, self.priors[level].n_ctx, hop_length) + for start in iterator: + music_tokens = self.sample_single_window( + music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size + ) + + else: + music_tokens = self.sample_partial_window( + music_tokens, labels, offset, sampling_kwargs, level, total_length, max_batch_size + ) + return music_tokens + + @torch.no_grad() + def _sample( + self, + music_tokens, + labels, + sample_levels, + metas=None, + chunk_size=32, + sampling_temperature=0.98, + lower_batch_size=16, + max_batch_size=16, + sample_length_in_seconds=24, + compute_alignments=False, + sample_tokens=None, + offset=0, + save_results=True, + sample_length=None, + ) -> List[torch.LongTensor]: + """ + Core sampling function used to generate music tokens. Iterates over the provided list of levels, while saving + the generated raw audio at each step. + + Args: + music_tokens (`List[torch.LongTensor]`): + A sequence of music tokens of length `self.levels` which will be used as context to continue the + sampling process. Should have `self.levels` tensors, each corresponding to the generation at a certain + level. + labels (`List[torch.LongTensor]`): + List of length `n_sample`, and shape `(self.levels, 4 + self.config.max_nb_genre + + lyric_sequence_length)` metadata such as `artist_id`, `genre_id` and the full list of lyric tokens + which are used to condition the generation. + sample_levels (`List[int]`): + List of the desired levels at which the sampling will be done. A level is equivalent to the index of + the prior in the list of priors + metas (`List[Any]`, *optional*): + Metadatas used to generate the `labels` + chunk_size (`int`, *optional*, defaults to 32): + Size of a chunk of audio, used to fill up the memory in chuncks to prevent OOM erros. Bigger chunks + means faster memory filling but more consumption. + sampling_temperature (`float`, *optional*, defaults to 0.98): + Temperature used to ajust the randomness of the sampling. + lower_batch_size (`int`, *optional*, defaults to 16): + Maximum batch size for the lower level priors + max_batch_size (`int`, *optional*, defaults to 16): + Maximum batch size for the top level priors + sample_length_in_seconds (`int`, *optional*, defaults to 24): + Desired length of the generation in seconds + compute_alignments (`bool`, *optional*, defaults to `False`): + Whether or not to compute the alignment between the lyrics and the audio using the top_prior + sample_tokens (`int`, *optional*): + Precise number of tokens that should be sampled at each level. This is mostly useful for running dummy + experiments + offset (`int`, *optional*, defaults to 0): + Audio offset used as conditioning, corresponds to the starting sample in the music. If the offset is + greater than 0, the lyrics will be shifted take that intoaccount + save_results (`bool`, *optional*, defaults to `True`): + Whether or not to save the intermediate results. If `True`, will generate a folder named with the start + time. + sample_length (`int`, *optional*): + Desired length of the generation in samples. + + Returns: torch.Tensor + + Example: + + ```python + >>> from transformers import JukeboxTokenizer, JukeboxModel, set_seed + >>> import torch + + >>> metas = dict(artist="Zac Brown Band", genres="Country", lyrics="I met a traveller from an antique land") + >>> tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics") + >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval() + + >>> labels = tokenizer(**metas)["input_ids"] + >>> set_seed(0) + >>> zs = [torch.zeros(1, 0, dtype=torch.long) for _ in range(3)] + >>> zs = model._sample(zs, labels, [0], sample_length=40 * model.priors[0].raw_to_tokens, save_results=False) + >>> zs[0] + tensor([[1853, 1369, 1150, 1869, 1379, 1789, 519, 710, 1306, 1100, 1229, 519, + 353, 1306, 1379, 1053, 519, 653, 1631, 1467, 1229, 1229, 10, 1647, + 1254, 1229, 1306, 1528, 1789, 216, 1631, 1434, 653, 475, 1150, 1528, + 1804, 541, 1804, 1434]]) + ``` + """ + + top_prior = self.priors[0] + if sample_length is not None: + total_length = sample_length + else: + total_length = ( + int(sample_length_in_seconds * self.config.sampling_rate) // top_prior.raw_to_tokens + ) * top_prior.raw_to_tokens + + if sample_levels is None: + sample_levels = range(len(self.priors)) + + # total length of the signal, might be bit different from the actual generated length + self.total_length = total_length + for level in sample_levels: + sampling_kwargs = dict( + temp=0.99 if level == len(self.priors) - 1 else sampling_temperature, + chunk_size=chunk_size, + sample_tokens=sample_tokens, + ) + # Set correct total_length, hop_length, labels and sampling_kwargs for level + + total_token_to_sample = total_length // self.priors[level].raw_to_tokens + hop_length = int(self.config.hop_fraction[level] * self.priors[level].n_ctx) + max_batch_size = lower_batch_size if level != sample_levels else max_batch_size + music_tokens = self.sample_level( + music_tokens, + labels[level], + offset, + sampling_kwargs, + level, + total_token_to_sample, + hop_length, + max_batch_size, + ) + + if save_results: + self.vqvae.to(music_tokens[level].device) + # Decode sample + with torch.no_grad(): + start_level = len(self.priors) - level - 1 # vqvae levels are reversed + raw_audio = self.vqvae.decode( + music_tokens[: level + 1], start_level=start_level, bs_chunks=music_tokens[level].shape[0] + ) + logdir = f"jukebox/level_{level}" + if not os.path.exists(logdir): + os.makedirs(logdir) + save_temp_audio(logdir, level, metas=metas, aud=raw_audio.float()) + if compute_alignments and self.priors[0] is not None and self.priors[0].nb_relevant_lyric_tokens > 0: + with torch.no_grad(): + alignments = get_alignment(music_tokens, labels[0], self.priors[0], self.config) + torch.save({"alignments": alignments}, f"{logdir}/lyric_alignments.pt") + + return music_tokens + + @add_start_docstrings( + """ + Generates music tokens based on the provided `labels. Will start at the desired prior level and automatically + upsample the sequence. If you want to create the audio, you should call `model.decode(tokens)`, which will use + the VQ-VAE decoder to convert the music tokens to raw audio. + + Args: + labels (`List[torch.LongTensor]`) : + List of length `n_sample`, and shape `(self.levels, 4 + self.config.max_nb_genre + + lyric_sequence_length)` metadata such as `artist_id`, `genre_id` and the full list of lyric tokens + which are used to condition the generation. + n_samples (`int`, *optional*, default to 1) : + Number of samples to be generated in parallel. + """, + ) + def ancestral_sample(self, labels, n_samples=1, **sampling_kwargs) -> List[torch.LongTensor]: + """ + Example: + + ```python + >>> from transformers import JukeboxTokenizer, JukeboxModel, set_seed + + >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval() + >>> tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics") + + >>> lyrics = "Hey, are you awake? Can you talk to me?" + >>> artist = "Zac Brown Band" + >>> genre = "Country" + >>> metas = tokenizer(artist=artist, genres=genre, lyrics=lyrics) + >>> set_seed(0) + >>> music_tokens = model.ancestral_sample(metas.input_ids, sample_length=400) + + >>> with torch.no_grad(): + ... model.decode(music_tokens)[:, :10].squeeze(-1) + tensor([[-0.0219, -0.0679, -0.1050, -0.1203, -0.1271, -0.0936, -0.0396, -0.0405, + -0.0818, -0.0697]]) + ``` + """ + + sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors)))) + music_tokens = [ + torch.zeros(n_samples, 0, dtype=torch.long, device=labels[0].device) for _ in range(len(self.priors)) + ] + music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs) + return music_tokens + + @add_start_docstrings( + """Generates a continuation of the previously generated tokens. + + Args: + music_tokens (`List[torch.LongTensor]` of length `self.levels` ) : + A sequence of music tokens which will be used as context to continue the sampling process. Should have + `self.levels` tensors, each corresponding to the generation at a certain level. + """, + JUKEBOX_SAMPLING_INPUT_DOCSTRING, + ) + def continue_sample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]: + sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors)))) + music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs) + return music_tokens + + @add_start_docstrings( + """Upsamples a sequence of music tokens using the prior at level `level`. + + Args: + music_tokens (`List[torch.LongTensor]` of length `self.levels` ) : + A sequence of music tokens which will be used as context to continue the sampling process. Should have + `self.levels` tensors, each corresponding to the generation at a certain level. + """, + JUKEBOX_SAMPLING_INPUT_DOCSTRING, + ) + def upsample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]: + sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors) - 1))) + music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs) + return music_tokens + + @add_start_docstrings( + """Generate a raw audio conditioned on the provided `raw_audio` which is used as conditioning at each of the + generation levels. The audio is encoded to music tokens using the 3 levels of the VQ-VAE. These tokens are + used: as conditioning for each level, which means that no ancestral sampling is required. + + Args: + raw_audio (`List[torch.Tensor]` of length `n_samples` ) : + A list of raw audio that will be used as conditioning information for each samples that will be + generated. + """, + JUKEBOX_SAMPLING_INPUT_DOCSTRING, + ) + def primed_sample(self, raw_audio, labels, **sampling_kwargs) -> List[torch.LongTensor]: + sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors)))) + self.vqvae.to(raw_audio.device).float() + with torch.no_grad(): + music_tokens = self.vqvae.encode( + raw_audio, start_level=0, end_level=len(self.priors), bs_chunks=raw_audio.shape[0] + ) + music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs) + return music_tokens diff --git a/src/transformers/models/jukebox/tokenization_jukebox.py b/src/transformers/models/jukebox/tokenization_jukebox.py new file mode 100644 index 0000000000000..01bada0e0806b --- /dev/null +++ b/src/transformers/models/jukebox/tokenization_jukebox.py @@ -0,0 +1,424 @@ +# coding=utf-8 +# Copyright 2022 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OpenAI Jukebox.""" + + +import json +import os +import re +import unicodedata +from json.encoder import INFINITY +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np + +import regex +from transformers.utils.generic import _is_jax, _is_numpy + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...tokenization_utils_base import BatchEncoding +from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available, logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "artists_file": "artists.json", + "lyrics_file": "lyrics.json", + "genres_file": "genres.json", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "artists_file": { + "jukebox": "https://huggingface.co/ArthurZ/jukebox/blob/main/artists.json", + }, + "genres_file": { + "jukebox": "https://huggingface.co/ArthurZ/jukebox/blob/main/genres.json", + }, + "lyrics_file": { + "jukebox": "https://huggingface.co/ArthurZ/jukebox/blob/main/lyrics.json", + }, +} + +PRETRAINED_LYRIC_TOKENS_SIZES = { + "jukebox": 512, +} + + +class JukeboxTokenizer(PreTrainedTokenizer): + """ + Constructs a Jukebox tokenizer. Jukebox can be conditioned on 3 different inputs : + - Artists, unique ids are associated to each artist from the provided dictionary. + - Genres, unique ids are associated to each genre from the provided dictionary. + - Lyrics, character based tokenization. Must be initialized with the list of characters that are inside the + vocabulary. + + This tokenizer does not require training. It should be able to process a different number of inputs: + as the conditioning of the model can be done on the three different queries. If None is provided, defaults values will be used.: + + Depending on the number of genres on which the model should be conditioned (`n_genres`). + ``` + >>> from transformers import JukeboxTokenizer + >>> tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics") + >>> tokenizer("Alan Jackson", "Country Rock", "old town road")['input_ids'] + [tensor([[ 0, 0, 0, 6785, 546, 41, 38, 30, 76, 46, 41, 49, + 40, 76, 44, 41, 27, 30]]), tensor([[ 0, 0, 0, 145, 0]]), tensor([[ 0, 0, 0, 145, 0]])] + + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you + call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. + + + + If nothing is provided, the genres and the artist will either be selected randomly or set to None + + + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to: + this superclass for more information regarding those methods. + + However the code does not allow that and only supports composing from various genres. + + Args: + artists_file (`str`): + Path to the vocabulary file which contains a mapping between artists and ids. The default file supports + both "v2" and "v3" + genres_file (`str`): + Path to the vocabulary file which contain a mapping between genres and ids. + lyrics_file (`str`): + Path to the vocabulary file which contains the accepted characters for the lyrics tokenization. + version (`List[str]`, `optional`, default to `["v3", "v2", "v2"]`) : + List of the tokenizer versions. The `5b-lyrics`'s top level prior model was trained using `v3` instead of + `v2`. + n_genres (`int`, `optional`, defaults to 1): + Maximum number of genres to use for composition. + max_n_lyric_tokens (`int`, `optional`, defaults to 512): + Maximum number of lyric tokens to keep. + unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_lyric_input_size = PRETRAINED_LYRIC_TOKENS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + artists_file, + genres_file, + lyrics_file, + version=["v3", "v2", "v2"], + max_n_lyric_tokens=512, + n_genres=5, + unk_token="<|endoftext|>", + **kwargs + ): + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + super().__init__( + unk_token=unk_token, + n_genres=n_genres, + version=version, + max_n_lyric_tokens=max_n_lyric_tokens, + **kwargs, + ) + self.version = version + self.max_n_lyric_tokens = max_n_lyric_tokens + self.n_genres = n_genres + + with open(artists_file, encoding="utf-8") as vocab_handle: + self.artists_encoder = json.load(vocab_handle) + + with open(genres_file, encoding="utf-8") as vocab_handle: + self.genres_encoder = json.load(vocab_handle) + + with open(lyrics_file, encoding="utf-8") as vocab_handle: + self.lyrics_encoder = json.load(vocab_handle) + + oov = "[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+" + # In v2, we had a n_vocab=80 and in v3 we missed + and so n_vocab=79 of characters. + if len(self.lyrics_encoder) == 79: + oov = oov.replace("\-'", "\-+'") + + self.out_of_vocab = regex.compile(oov) + self.artists_decoder = {v: k for k, v in self.artists_encoder.items()} + self.genres_decoder = {v: k for k, v in self.genres_encoder.items()} + self.lyrics_decoder = {v: k for k, v in self.lyrics_encoder.items()} + + @property + def vocab_size(self): + return len(self.artists_encoder) + len(self.genres_encoder) + len(self.lyrics_encoder) + + def get_vocab(self): + return dict(self.artists_encoder, self.genres_encoder, self.lyrics_encoder) + + def _convert_token_to_id(self, list_artists, list_genres, list_lyrics): + """Converts the artist, genre and lyrics tokens to their index using the vocabulary. + The total_length, offset and duration have to be provided in order to select relevant lyrics and add padding to + the lyrics token sequence. + """ + artists_id = [self.artists_encoder.get(artist, 0) for artist in list_artists] + for genres in range(len(list_genres)): + list_genres[genres] = [self.genres_encoder.get(genre, 0) for genre in list_genres[genres]] + list_genres[genres] = list_genres[genres] + [-1] * (self.n_genres - len(list_genres[genres])) + + lyric_ids = [[self.lyrics_encoder.get(character, 0) for character in list_lyrics[0]], [], []] + return artists_id, list_genres, lyric_ids + + def _tokenize(self, lyrics): + """ + Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based + vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). + + Do NOT take care of added tokens. Only the lyrics are split into character for the character-based vocabulary. + """ + # only lyrics are not tokenized, but character based is easily handled + return [character for character in lyrics] + + def tokenize(self, artist, genre, lyrics, **kwargs): + """ + Converts three strings in a 3 sequence of tokens using the tokenizer + """ + artist, genre, lyrics = self.prepare_for_tokenization(artist, genre, lyrics) + lyrics = self._tokenize(lyrics) + return artist, genre, lyrics + + def prepare_for_tokenization( + self, artists: str, genres: str, lyrics: str, is_split_into_words: bool = False + ) -> Tuple[str, str, str, Dict[str, Any]]: + """ + Performs any necessary transformations before tokenization. + + This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the + `kwargs` at the end of the encoding process to be sure all the arguments have been used. + + Args: + artist (`str`): + The artist name to prepare. This will mostly lower the string + genres (`str`): + The genre name to prepare. This will mostly lower the string. + lyrics (`str`): + The lyrics to prepare. + is_split_into_words (`bool`, *optional*, defaults to `False`): + Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the + tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace) + which it will tokenize. This is useful for NER or token classification. + kwargs: + Keyword arguments to use for the tokenization. + """ + for idx in range(len(self.version)): + if self.version[idx] == "v3": + artists[idx] = artists[idx].lower() + genres[idx] = [genres[idx].lower()] + else: + artists[idx] = self._normalize(artists[idx]) + ".v2" + genres[idx] = [ + self._normalize(genre) + ".v2" for genre in genres[idx].split("_") + ] # split is for the full dictionary with combined genres + + if self.version[0] == "v2": + self.out_of_vocab = regex.compile("[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+") + vocab = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,:;!?-+'\"()[] \t\n" + self.vocab = {vocab[index]: index + 1 for index in range(len(vocab))} + self.vocab[""] = 0 + self.n_vocab = len(vocab) + 1 + self.lyrics_encoder = self.vocab + self.lyrics_decoder = {v: k for k, v in self.vocab.items()} + self.lyrics_decoder[0] = "" + else: + self.out_of_vocab = regex.compile("[^A-Za-z0-9.,:;!?\-+'\"()\[\] \t\n]+") + + lyrics = self._run_strip_accents(lyrics) + lyrics = lyrics.replace("\\", "\n") + lyrics = self.out_of_vocab.sub("", lyrics), [], [] + return artists, genres, lyrics + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _normalize(self, text: str) -> str: + """ + Normalizes the input text. This process is for the genres and the artist + + Args: + text (`str`): + Artist or Genre string to normalize + """ + + accepted = ( + [chr(i) for i in range(ord("a"), ord("z") + 1)] + + [chr(i) for i in range(ord("A"), ord("Z") + 1)] + + [chr(i) for i in range(ord("0"), ord("9") + 1)] + + ["."] + ) + accepted = frozenset(accepted) + pattern = re.compile(r"_+") + text = "".join([c if c in accepted else "_" for c in text.lower()]) + text = pattern.sub("_", text).strip("_") + return text + + def convert_lyric_tokens_to_string(self, lyrics: List[str]) -> str: + return " ".join(lyrics) + + def convert_to_tensors( + self, inputs, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False + ): + """ + Convert the inner content to tensors. + + Args: + tensor_type (`str` or [`~utils.TensorType`], *optional*): + The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If + unset, no modification is done. + prepend_batch_axis (`int`, *optional*, defaults to `False`): + Whether or not to add the batch dimension during the conversion. + """ + # Convert to TensorType + if not isinstance(tensor_type, TensorType): + tensor_type = TensorType(tensor_type) + + # Get a function reference for the correct framework + if tensor_type == TensorType.TENSORFLOW: + if not is_tf_available(): + raise ImportError( + "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." + ) + import tensorflow as tf + + as_tensor = tf.constant + is_tensor = tf.is_tensor + elif tensor_type == TensorType.PYTORCH: + if not is_torch_available(): + raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") + import torch + + as_tensor = torch.tensor + is_tensor = torch.is_tensor + elif tensor_type == TensorType.JAX: + if not is_flax_available(): + raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") + import jax.numpy as jnp # noqa: F811 + + as_tensor = jnp.array + is_tensor = _is_jax + else: + as_tensor = np.asarray + is_tensor = _is_numpy + + # Do the tensor conversion in batch + + try: + if prepend_batch_axis: + inputs = [inputs] + + if not is_tensor(inputs): + inputs = as_tensor(inputs) + except: # noqa E722 + raise ValueError( + "Unable to create tensor, you should probably activate truncation and/or padding " + "with 'padding=True' 'truncation=True' to have batched tensors with the same length." + ) + + return inputs + + def __call__(self, artist, genres, lyrics="", return_tensors="pt") -> BatchEncoding: + """Convert the raw string to a list of token ids + + Args: + artist (`str`): + Name of the artist. + genres (`str`): + List of genres that will be mixed to condition the audio + lyrics (`str`, *optional*, defaults to `""`): + Lyrics used to condition the generation + """ + input_ids = [0, 0, 0] + artist = [artist] * len(self.version) + genres = [genres] * len(self.version) + + artists_tokens, genres_tokens, lyrics_tokens = self.tokenize(artist, genres, lyrics) + artists_id, genres_ids, full_tokens = self._convert_token_to_id(artists_tokens, genres_tokens, lyrics_tokens) + + attention_masks = [-INFINITY] * len(full_tokens[-1]) + input_ids = [ + self.convert_to_tensors( + [input_ids + [artists_id[i]] + genres_ids[i] + full_tokens[i]], tensor_type=return_tensors + ) + for i in range(len(self.version)) + ] + return BatchEncoding({"input_ids": input_ids, "attention_masks": attention_masks}) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Saves the tokenizer's vocabulary dictionary to the provided save_directory. + + Args: + save_directory (`str`): + A path to the directory where to saved. It will be created if it doesn't exist. + + filename_prefix (`Optional[str]`, *optional*): + A prefix to add to the names of the files saved by the tokenizer. + + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + + artists_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["artists_file"] + ) + with open(artists_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.artists_encoder, ensure_ascii=False)) + + genres_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["genres_file"] + ) + with open(genres_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.genres_encoder, ensure_ascii=False)) + + lyrics_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["lyrics_file"] + ) + with open(lyrics_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.lyrics_encoder, ensure_ascii=False)) + + return (artists_file, genres_file, lyrics_file) + + def _convert_id_to_token(self, artists_index, genres_index, lyric_index): + """ + Converts an index (integer) in a token (str) using the vocab. + + Args: + artists_index (`int`): + Index of the artist in its corresponding dictionary. + genres_index (`Union[List[int], int]`): + Index of the genre in its corresponding dictionary. + lyric_index (`List[int]`): + List of character indices, which each correspond to a character. + """ + artist = self.artists_decoder.get(artists_index) + genres = [self.genres_decoder.get(genre) for genre in genres_index] + lyrics = [self.lyrics_decoder.get(character) for character in lyric_index] + return artist, genres, lyrics diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 08db45a62ab0b..d8702ef8d403c 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2704,6 +2704,37 @@ def load_tf_weights_in_imagegpt(*args, **kwargs): requires_backends(load_tf_weights_in_imagegpt, ["torch"]) +JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class JukeboxModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class JukeboxPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class JukeboxPrior(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class JukeboxVQVAE(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/jukebox/__init__.py b/tests/models/jukebox/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/jukebox/test_modeling_jukebox.py b/tests/models/jukebox/test_modeling_jukebox.py new file mode 100644 index 0000000000000..9232119432f5a --- /dev/null +++ b/tests/models/jukebox/test_modeling_jukebox.py @@ -0,0 +1,344 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import require_torch, slow +from transformers.trainer_utils import set_seed + + +if is_torch_available(): + import torch + + from transformers import JukeboxModel, JukeboxTokenizer + + +@require_torch +class Jukebox1bModelTester(unittest.TestCase): + all_model_classes = (JukeboxModel,) if is_torch_available() else () + model_id = "openai/jukebox-1b-lyrics" + metas = dict( + artist="Zac Brown Band", + genres="Country", + lyrics="""I met a traveller from an antique land, + Who said "Two vast and trunkless legs of stone + Stand in the desert. . . . Near them, on the sand, + Half sunk a shattered visage lies, whose frown, + And wrinkled lip, and sneer of cold command, + Tell that its sculptor well those passions read + Which yet survive, stamped on these lifeless things, + The hand that mocked them, and the heart that fed; + And on the pedestal, these words appear: + My name is Ozymandias, King of Kings; + Look on my Works, ye Mighty, and despair! + Nothing beside remains. Round the decay + Of that colossal Wreck, boundless and bare + The lone and level sands stretch far away + """, + ) + # fmt: off + EXPECTED_OUTPUT_2 = [ + 1864, 1536, 1213, 1870, 1357, 1536, 519, 880, 1323, 789, 1082, 534, + 1000, 1445, 1105, 1130, 967, 515, 1434, 1620, 534, 1495, 283, 1445, + 333, 1307, 539, 1631, 1528, 375, 1434, 673, 627, 710, 778, 1883, + 1405, 1276, 1455, 1228 + ] + + EXPECTED_OUTPUT_1 = [ + 1125, 1751, 697, 1776, 1141, 1476, 391, 697, 1125, 684, 867, 416, + 844, 1372, 1274, 717, 1274, 844, 1299, 1419, 697, 1370, 317, 1125, + 191, 1440, 1370, 1440, 1370, 282, 1621, 1370, 368, 349, 867, 1872, + 1262, 869, 1728, 747 + ] + + EXPECTED_OUTPUT_0 = [ + 1755, 842, 307, 1843, 1022, 1395, 234, 1554, 806, 739, 1022, 442, + 616, 556, 268, 1499, 933, 457, 1440, 1837, 755, 985, 308, 902, + 293, 1443, 1671, 1141, 1533, 555, 1562, 1061, 287, 417, 1022, 2008, + 1186, 1015, 1777, 268 + ] + + EXPECTED_Y_COND = [1058304, 0, 786432, 7169, 507, 76, 27, 40, 30, 76] + + EXPECTED_PRIMED_0 = [ + 390, 1160, 1002, 1907, 1788, 1788, 1788, 1907, 1002, 1002, 1854, 1002, + 1002, 1002, 1002, 1002, 1002, 1160, 1160, 1606, 596, 596, 1160, 1002, + 1516, 596, 1002, 1002, 1002, 1907, 1788, 1788, 1788, 1854, 1788, 1907, + 1907, 1788, 596, 1626 + ] + EXPECTED_PRIMED_1 = [ + 1236, 1668, 1484, 1920, 1848, 1409, 139, 864, 1828, 1272, 1599, 824, + 1672, 139, 555, 1484, 824, 1920, 555, 596, 1579, 1599, 1231, 1599, + 1637, 1407, 212, 824, 1599, 116, 1433, 824, 258, 1599, 1433, 1895, + 1063, 1433, 1433, 1599 + ] + EXPECTED_PRIMED_2 = [ + 1684, 1873, 1119, 1189, 395, 611, 1901, 972, 890, 1337, 1392, 1927, + 96, 972, 672, 780, 1119, 890, 158, 771, 1073, 1927, 353, 1331, + 1269, 1459, 1333, 1645, 812, 1577, 1337, 606, 353, 981, 1466, 619, + 197, 391, 302, 1930 + ] + EXPECTED_VQVAE_ENCODE = [ + 390, 1160, 1002, 1907, 1788, 1788, 1788, 1907, 1002, 1002, 1854, 1002, + 1002, 1002, 1002, 1002, 1002, 1160, 1160, 1606, 596, 596, 1160, 1002, + 1516, 596, 1002, 1002, 1002, 1907, 1788, 1788, 1788, 1854, 1788, 1907, + 1907, 1788, 596, 1626 + ] + EXPECTED_VQVAE_DECODE = [ + -0.0492, -0.0524, -0.0565, -0.0640, -0.0686, -0.0684, -0.0677, -0.0664, + -0.0605, -0.0490, -0.0330, -0.0168, -0.0083, -0.0075, -0.0051, 0.0025, + 0.0136, 0.0261, 0.0386, 0.0497, 0.0580, 0.0599, 0.0583, 0.0614, + 0.0740, 0.0889, 0.1023, 0.1162, 0.1211, 0.1212, 0.1251, 0.1336, + 0.1502, 0.1686, 0.1883, 0.2148, 0.2363, 0.2458, 0.2507, 0.2531 + ] + EXPECTED_AUDIO_COND = [ + 0.0256, -0.0544, 0.1600, -0.0032, 0.1066, 0.0825, -0.0013, 0.3440, + 0.0210, 0.0412, -0.1777, -0.0892, -0.0164, 0.0285, -0.0613, -0.0617, + -0.0137, -0.0201, -0.0175, 0.0215, -0.0627, 0.0520, -0.0730, 0.0970, + -0.0100, 0.0442, -0.0586, 0.0207, -0.0015, -0.0082 + ] + EXPECTED_META_COND = [ + 0.0415, 0.0877, 0.0022, -0.0055, 0.0751, 0.0334, 0.0324, -0.0068, + 0.0011, 0.0017, -0.0676, 0.0655, -0.0143, 0.0399, 0.0303, 0.0743, + -0.0168, -0.0394, -0.1113, 0.0124, 0.0442, 0.0267, -0.0003, -0.1536, + -0.0116, -0.1837, -0.0180, -0.1026, -0.0777, -0.0456 + ] + EXPECTED_LYRIC_COND = [ + 76, 27, 40, 30, 76, 46, 44, 47, 40, 37, 38, 31, 45, 45, 76, 38, 31, 33, + 45, 76, 41, 32, 76, 45, 46, 41, 40, 31, 78, 76 + ] + # fmt: on + + def prepare_inputs(self): + tokenizer = JukeboxTokenizer.from_pretrained(self.model_id) + tokens = tokenizer(**self.metas)["input_ids"] + return tokens + + @slow + def test_sampling(self): + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval() + labels = self.prepare_inputs() + + set_seed(0) + zs = [torch.zeros(1, 0, dtype=torch.long).cpu() for _ in range(3)] + zs = model._sample(zs, labels, [0], sample_length=40 * model.priors[0].raw_to_tokens, save_results=False) + torch.testing.assert_allclose(zs[0][0], torch.tensor(self.EXPECTED_OUTPUT_2)) + + set_seed(0) + zs = model._sample(zs, labels, [1], sample_length=40 * model.priors[1].raw_to_tokens, save_results=False) + torch.testing.assert_allclose(zs[1][0], torch.tensor(self.EXPECTED_OUTPUT_1)) + + set_seed(0) + zs = model._sample(zs, labels, [2], sample_length=40 * model.priors[2].raw_to_tokens, save_results=False) + torch.testing.assert_allclose(zs[2][0], torch.tensor(self.EXPECTED_OUTPUT_0)) + + @slow + def test_conditioning(self): + torch.backends.cuda.matmul.allow_tf32 = False + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval() + + labels = self.prepare_inputs() + set_seed(0) + zs = [torch.zeros(1, 0, dtype=torch.long) for _ in range(3)] + + top_prior = model.priors[0] + start = 0 + music_token_conds = top_prior.get_music_tokens_conds(zs, start=start, end=start + top_prior.n_ctx) + metadata = top_prior.get_metadata(labels[0].clone(), start, 1058304, 0) + + self.assertIsNone(music_token_conds) + self.assertListEqual(metadata.numpy()[0][:10].tolist(), self.EXPECTED_Y_COND) + + audio_conditioning, metadata_conditioning, lyric_tokens = top_prior.get_cond(music_token_conds, metadata) + torch.testing.assert_allclose( + audio_conditioning[0][0][:30].detach(), torch.tensor(self.EXPECTED_AUDIO_COND), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_allclose( + metadata_conditioning[0][0][:30].detach(), torch.tensor(self.EXPECTED_META_COND), atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_allclose( + lyric_tokens[0, :30].detach(), torch.tensor(self.EXPECTED_LYRIC_COND), atol=1e-4, rtol=1e-4 + ) + + @slow + def test_primed_sampling(self): + torch.backends.cuda.matmul.allow_tf32 = False + + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval() + set_seed(0) + waveform = torch.rand((1, 5120, 1)) + tokens = [i for i in self.prepare_inputs()] + + zs = [model.vqvae.encode(waveform, start_level=2, bs_chunks=waveform.shape[0])[0], None, None] + zs = model._sample( + zs, tokens, sample_levels=[0], save_results=False, sample_length=40 * model.priors[0].raw_to_tokens + ) + torch.testing.assert_allclose(zs[0][0][:40], torch.tensor(self.EXPECTED_PRIMED_0)) + + upper_2 = torch.cat((zs[0], torch.zeros(1, 2048 - zs[0].shape[-1])), dim=-1).long() + zs = [upper_2, model.vqvae.encode(waveform, start_level=1, bs_chunks=waveform.shape[0])[0], None] + zs = model._sample( + zs, tokens, sample_levels=[1], save_results=False, sample_length=40 * model.priors[1].raw_to_tokens + ) + torch.testing.assert_allclose(zs[1][0][:40], torch.tensor(self.EXPECTED_PRIMED_1)) + + upper_1 = torch.cat((zs[1], torch.zeros(1, 2048 - zs[1].shape[-1])), dim=-1).long() + zs = [upper_2, upper_1, model.vqvae.encode(waveform, start_level=0, bs_chunks=waveform.shape[0])[0]] + zs = model._sample( + zs, tokens, sample_levels=[2], save_results=False, sample_length=40 * model.priors[2].raw_to_tokens + ) + torch.testing.assert_allclose(zs[2][0][:40].cpu(), torch.tensor(self.EXPECTED_PRIMED_2)) + + @slow + def test_vqvae(self): + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval() + set_seed(0) + x = torch.rand((1, 5120, 1)) + with torch.no_grad(): + zs = model.vqvae.encode(x, start_level=2, bs_chunks=x.shape[0]) + torch.testing.assert_allclose(zs[0][0], torch.tensor(self.EXPECTED_VQVAE_ENCODE)) + + with torch.no_grad(): + x = model.vqvae.decode(zs, start_level=2, bs_chunks=x.shape[0]) + torch.testing.assert_allclose(x[0, :40, 0], torch.tensor(self.EXPECTED_VQVAE_DECODE), atol=1e-4, rtol=1e-4) + + +@require_torch +class Jukebox5bModelTester(unittest.TestCase): + all_model_classes = (JukeboxModel,) if is_torch_available() else () + model_id = "openai/jukebox-5b-lyrics" + metas = dict( + artist="Zac Brown Band", + genres="Country", + lyrics="""I met a traveller from an antique land, + Who said "Two vast and trunkless legs of stone + Stand in the desert. . . . Near them, on the sand, + Half sunk a shattered visage lies, whose frown, + And wrinkled lip, and sneer of cold command, + Tell that its sculptor well those passions read + Which yet survive, stamped on these lifeless things, + The hand that mocked them, and the heart that fed; + And on the pedestal, these words appear: + My name is Ozymandias, King of Kings; + Look on my Works, ye Mighty, and despair! + Nothing beside remains. Round the decay + Of that colossal Wreck, boundless and bare + The lone and level sands stretch far away + """, + ) + + # fmt: off + EXPECTED_OUTPUT_2 = [ + 1489, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 1489, 1489, 1489, 1489, 1150, 1853, 1509, 1150, 1357, 1509, 6, 1272 + ] + + EXPECTED_OUTPUT_1 = [ + 1125, 416, 1125, 1125, 1125, 1125, 1125, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416 + ] + + EXPECTED_OUTPUT_0 = [ + 1755, 1061, 234, 1755, 1061, 1755, 185, 290, 307, 307, 616, 616, + 616, 616, 616, 616, 307, 290, 417, 1755, 234, 1755, 185, 290, + 290, 290, 307, 616, 616, 616, 616, 616, 290, 234, 234, 1755, + 234, 234, 1755, 234, 185, 185, 307, 616, 616, 616, 616, 290, + 1755, 1755, 1755, 234, 234, 1755, 1572, 290, 307, 616, 34, 616 + ] + + EXPECTED_GPU_OUTPUTS_2 = [ + 1489, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, + 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653 + ] + EXPECTED_GPU_OUTPUTS_1 = [ + 1125, 1125, 416, 1125, 1125, 416, 1125, 1125, 416, 416, 1125, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, + 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416 + ] + EXPECTED_GPU_OUTPUTS_0 = [ + 491, 1755, 34, 1613, 1755, 417, 992, 1613, 222, 842, 1353, 1613, + 844, 632, 185, 1613, 844, 632, 185, 1613, 185, 842, 677, 1613, + 185, 114, 1353, 1613, 307, 89, 844, 1613, 307, 1332, 234, 1979, + 307, 89, 1353, 616, 34, 842, 185, 842, 34, 842, 185, 842, + 307, 114, 185, 89, 34, 1268, 185, 89, 34, 842, 185, 89 + ] + # fmt: on + + def prepare_inputs(self, model_id): + tokenizer = JukeboxTokenizer.from_pretrained(model_id) + tokens = tokenizer(**self.metas)["input_ids"] + return tokens + + @slow + def test_sampling(self): + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval() + labels = self.prepare_inputs(self.model_id) + + set_seed(0) + zs = [torch.zeros(1, 0, dtype=torch.long).cpu() for _ in range(3)] + zs = model._sample(zs, labels, [0], sample_length=60 * model.priors[0].raw_to_tokens, save_results=False) + torch.testing.assert_allclose(zs[0][0], torch.tensor(self.EXPECTED_OUTPUT_2)) + + set_seed(0) + zs = model._sample(zs, labels, [1], sample_length=60 * model.priors[1].raw_to_tokens, save_results=False) + torch.testing.assert_allclose(zs[1][0], torch.tensor(self.EXPECTED_OUTPUT_1)) + + set_seed(0) + zs = model._sample(zs, labels, [2], sample_length=60 * model.priors[2].raw_to_tokens, save_results=False) + torch.testing.assert_allclose(zs[2][0], torch.tensor(self.EXPECTED_OUTPUT_0)) + + @slow + def test_slow_sampling(self): + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval().to("cuda") + labels = [i.cuda() for i in self.prepare_inputs(self.model_id)] + + set_seed(0) + model.priors[0].cuda() + zs = [torch.zeros(1, 0, dtype=torch.long).cuda() for _ in range(3)] + zs = model._sample(zs, labels, [0], sample_length=60 * model.priors[0].raw_to_tokens, save_results=False) + torch.testing.assert_allclose(zs[0][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_2)) + model.priors[0].cpu() + + set_seed(0) + model.priors[1].cuda() + zs = model._sample(zs, labels, [1], sample_length=60 * model.priors[1].raw_to_tokens, save_results=False) + torch.testing.assert_allclose(zs[1][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_1)) + model.priors[1].cpu() + + set_seed(0) + model.priors[2].cuda() + zs = model._sample(zs, labels, [2], sample_length=60 * model.priors[2].raw_to_tokens, save_results=False) + torch.testing.assert_allclose(zs[2][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_0)) + + @slow + def test_fp16_slow_sampling(self): + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval().half().to("cuda") + labels = [i.cuda() for i in self.prepare_inputs(self.model_id)] + + set_seed(0) + zs = [torch.zeros(1, 0, dtype=torch.long).cuda() for _ in range(3)] + zs = model._sample(zs, labels, [0], sample_length=60 * model.priors[0].raw_to_tokens, save_results=False) + torch.testing.assert_allclose(zs[0][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_2)) diff --git a/tests/models/jukebox/test_tokenization_jukebox.py b/tests/models/jukebox/test_tokenization_jukebox.py new file mode 100644 index 0000000000000..7ce2585bdd64b --- /dev/null +++ b/tests/models/jukebox/test_tokenization_jukebox.py @@ -0,0 +1,209 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers import JukeboxTokenizer +from transformers.testing_utils import require_torch + + +class JukeboxTokenizationTest(unittest.TestCase): + tokenizer_class = JukeboxTokenizer + metas = dict( + artist="Zac Brown Band", + genres="Country", + lyrics="""I met a traveller from an antique land, + Who said "Two vast and trunkless legs of stone + Stand in the desert. . . . Near them, on the sand, + Half sunk a shattered visage lies, whose frown, + And wrinkled lip, and sneer of cold command, + Tell that its sculptor well those passions read + Which yet survive, stamped on these lifeless things, + The hand that mocked them, and the heart that fed; + And on the pedestal, these words appear: + My name is Ozymandias, King of Kings; + Look on my Works, ye Mighty, and despair! + Nothing beside remains. Round the decay + Of that colossal Wreck, boundless and bare + The lone and level sands stretch far away + """, + ) + + @require_torch + def test_1b_lyrics_tokenizer(self): + """ + how to run the same test with openAI + ... + """ + import torch + + tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics") + tokens = tokenizer(**self.metas)["input_ids"] + # fmt: off + EXPECTED_OUTPUT = [ + torch.tensor([[ + 0, 0, 0, 7169, 507, 9, 76, 39, 31, 46, 76, 27, + 76, 46, 44, 27, 48, 31, 38, 38, 31, 44, 76, 32, + 44, 41, 39, 76, 27, 40, 76, 27, 40, 46, 35, 43, + 47, 31, 76, 38, 27, 40, 30, 64, 78, 76, 76, 76, + 76, 76, 76, 76, 76, 23, 34, 41, 76, 45, 27, 35, + 30, 76, 71, 20, 49, 41, 76, 48, 27, 45, 46, 76, + 27, 40, 30, 76, 46, 44, 47, 40, 37, 38, 31, 45, + 45, 76, 38, 31, 33, 45, 76, 41, 32, 76, 45, 46, + 41, 40, 31, 78, 76, 76, 76, 76, 76, 76, 76, 76, + 19, 46, 27, 40, 30, 76, 35, 40, 76, 46, 34, 31, + 76, 30, 31, 45, 31, 44, 46, 63, 76, 63, 76, 63, + 76, 63, 76, 14, 31, 27, 44, 76, 46, 34, 31, 39, + 64, 76, 41, 40, 76, 46, 34, 31, 76, 45, 27, 40, + 30, 64, 78, 76, 76, 76, 76, 76, 76, 76, 76, 8, + 27, 38, 32, 76, 45, 47, 40, 37, 76, 27, 76, 45, + 34, 27, 46, 46, 31, 44, 31, 30, 76, 48, 35, 45, + 27, 33, 31, 76, 38, 35, 31, 45, 64, 76, 49, 34, + 41, 45, 31, 76, 32, 44, 41, 49, 40, 64, 78, 76, + 76, 76, 76, 76, 76, 76, 76, 1, 40, 30, 76, 49, + 44, 35, 40, 37, 38, 31, 30, 76, 38, 35, 42, 64, + 76, 27, 40, 30, 76, 45, 40, 31, 31, 44, 76, 41, + 32, 76, 29, 41, 38, 30, 76, 29, 41, 39, 39, 27, + 40, 30, 64, 78, 76, 76, 76, 76, 76, 76, 76, 76, + 20, 31, 38, 38, 76, 46, 34, 27, 46, 76, 35, 46, + 45, 76, 45, 29, 47, 38, 42, 46, 41, 44, 76, 49, + 31, 38, 38, 76, 46, 34, 41, 45, 31, 76, 42, 27, + 45, 45, 35, 41, 40, 45, 76, 44, 31, 27, 30, 78, + 76, 76, 76, 76, 76, 76, 76, 76, 23, 34, 35, 29, + 34, 76, 51, 31, 46, 76, 45, 47, 44, 48, 35, 48, + 31, 64, 76, 45, 46, 27, 39, 42, 31, 30, 76, 41, + 40, 76, 46, 34, 31, 45, 31, 76, 38, 35, 32, 31, + 38, 31, 45, 45, 76, 46, 34, 35, 40, 33, 45, 64, + 78, 76, 76, 76, 76, 76, 76, 76, 76, 20, 34, 31, + 76, 34, 27, 40, 30, 76, 46, 34, 27, 46, 76, 39, + 41, 29, 37, 31, 30, 76, 46, 34, 31, 39, 64, 76, + 27, 40, 30, 76, 46, 34, 31, 76, 34, 31, 27, 44, + 46, 76, 46, 34, 27, 46, 76, 32, 31, 30, 66, 78, + 76, 76, 76, 76, 76, 76, 76, 76, 1, 40, 30, 76, + 41, 40, 76, 46, 34, 31, 76, 42, 31, 30, 31, 45, + 46, 27, 38, 64, 76, 46, 34, 31, 45, 31, 76, 49, + 41, 44, 30, 45, 76, 27, 42, 42, 31, 27, 44, 65, + 78, 76, 76, 76, 76, 76, 76, 76, 76, 13, 51, 76, + 40, 27, 39, 31, 76, 35, 45, 76, 15, 52, 51, 39, + 27, 40, 30, 35, 27, 45, 64, 76, 11, 35, 40, 33, + 76, 41, 32, 76, 11, 35, 40, 33, 45, 66, 78, 76, + 76, 76, 76, 76, 76, 76, 76, 12, 41, 41, 37, 76, + 41, 40, 76, 39, 51, 76, 23, 41, 44, 37, 45, 64, + 76, 51, 31, 76, 13, 35, 33, 34, 46, 51, 64, 76, + 27, 40, 30, 76, 30, 31, 45, 42, 27, 35, 44, 67, + 78, 76, 76, 76, 76, 76, 76, 76, 76, 14, 41, 46, + 34, 35, 40, 33, 76, 28, 31, 45, 35, 30, 31, 76, + 44, 31, 39, 27, 35, 40, 45, 63, 76, 18, 41, 47, + 40, 30, 76, 46, 34, 31, 76, 30, 31, 29, 27, 51, + 78, 76, 76, 76, 76, 76, 76, 76, 76, 15, 32, 76, + 46, 34, 27, 46, 76, 29, 41, 38, 41, 45, 45, 27, + 38, 76, 23, 44, 31, 29, 37, 64, 76, 28, 41, 47, + 40, 30, 38, 31, 45, 45, 76, 27, 40, 30, 76, 28, + 27, 44, 31, 78, 76, 76, 76, 76, 76, 76, 76, 76, + 20, 34, 31, 76, 38, 41, 40, 31, 76, 27, 40, 30, + 76, 38, 31, 48, 31, 38, 76, 45, 27, 40, 30, 45, + 76, 45, 46, 44, 31, 46, 29, 34, 76, 32, 27, 44, + 76, 27, 49, 27, 51, 78, 76, 76, 76, 76, 76, 76, + 76, 76]]), + torch.tensor([[0, 0, 0, 1069, 11]]), + torch.tensor([[0, 0, 0, 1069, 11]]), + ] + # fmt: on + self.assertTrue(torch.allclose(tokens[0], EXPECTED_OUTPUT[0])) + self.assertTrue(torch.allclose(tokens[1], EXPECTED_OUTPUT[1])) + self.assertTrue(torch.allclose(tokens[2], EXPECTED_OUTPUT[2])) + + @require_torch + def test_5b_lyrics_tokenizer(self): + """ + The outputs are similar that open AI but do not have the same format as this one is adapted to the HF integration. + """ + import torch + + tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-5b-lyrics") + tokens = tokenizer(**self.metas)["input_ids"] + # fmt: off + EXPECTED_OUTPUT = [ + torch.tensor([[ + 0, 0, 0, 1069, 11, -1, -1, -1, -1, 9, 77, 39, + 31, 46, 77, 27, 77, 46, 44, 27, 48, 31, 38, 38, + 31, 44, 77, 32, 44, 41, 39, 77, 27, 40, 77, 27, + 40, 46, 35, 43, 47, 31, 77, 38, 27, 40, 30, 64, + 79, 77, 77, 77, 77, 77, 77, 77, 77, 23, 34, 41, + 77, 45, 27, 35, 30, 77, 72, 20, 49, 41, 77, 48, + 27, 45, 46, 77, 27, 40, 30, 77, 46, 44, 47, 40, + 37, 38, 31, 45, 45, 77, 38, 31, 33, 45, 77, 41, + 32, 77, 45, 46, 41, 40, 31, 79, 77, 77, 77, 77, + 77, 77, 77, 77, 19, 46, 27, 40, 30, 77, 35, 40, + 77, 46, 34, 31, 77, 30, 31, 45, 31, 44, 46, 63, + 77, 63, 77, 63, 77, 63, 77, 14, 31, 27, 44, 77, + 46, 34, 31, 39, 64, 77, 41, 40, 77, 46, 34, 31, + 77, 45, 27, 40, 30, 64, 79, 77, 77, 77, 77, 77, + 77, 77, 77, 8, 27, 38, 32, 77, 45, 47, 40, 37, + 77, 27, 77, 45, 34, 27, 46, 46, 31, 44, 31, 30, + 77, 48, 35, 45, 27, 33, 31, 77, 38, 35, 31, 45, + 64, 77, 49, 34, 41, 45, 31, 77, 32, 44, 41, 49, + 40, 64, 79, 77, 77, 77, 77, 77, 77, 77, 77, 1, + 40, 30, 77, 49, 44, 35, 40, 37, 38, 31, 30, 77, + 38, 35, 42, 64, 77, 27, 40, 30, 77, 45, 40, 31, + 31, 44, 77, 41, 32, 77, 29, 41, 38, 30, 77, 29, + 41, 39, 39, 27, 40, 30, 64, 79, 77, 77, 77, 77, + 77, 77, 77, 77, 20, 31, 38, 38, 77, 46, 34, 27, + 46, 77, 35, 46, 45, 77, 45, 29, 47, 38, 42, 46, + 41, 44, 77, 49, 31, 38, 38, 77, 46, 34, 41, 45, + 31, 77, 42, 27, 45, 45, 35, 41, 40, 45, 77, 44, + 31, 27, 30, 79, 77, 77, 77, 77, 77, 77, 77, 77, + 23, 34, 35, 29, 34, 77, 51, 31, 46, 77, 45, 47, + 44, 48, 35, 48, 31, 64, 77, 45, 46, 27, 39, 42, + 31, 30, 77, 41, 40, 77, 46, 34, 31, 45, 31, 77, + 38, 35, 32, 31, 38, 31, 45, 45, 77, 46, 34, 35, + 40, 33, 45, 64, 79, 77, 77, 77, 77, 77, 77, 77, + 77, 20, 34, 31, 77, 34, 27, 40, 30, 77, 46, 34, + 27, 46, 77, 39, 41, 29, 37, 31, 30, 77, 46, 34, + 31, 39, 64, 77, 27, 40, 30, 77, 46, 34, 31, 77, + 34, 31, 27, 44, 46, 77, 46, 34, 27, 46, 77, 32, + 31, 30, 66, 79, 77, 77, 77, 77, 77, 77, 77, 77, + 1, 40, 30, 77, 41, 40, 77, 46, 34, 31, 77, 42, + 31, 30, 31, 45, 46, 27, 38, 64, 77, 46, 34, 31, + 45, 31, 77, 49, 41, 44, 30, 45, 77, 27, 42, 42, + 31, 27, 44, 65, 79, 77, 77, 77, 77, 77, 77, 77, + 77, 13, 51, 77, 40, 27, 39, 31, 77, 35, 45, 77, + 15, 52, 51, 39, 27, 40, 30, 35, 27, 45, 64, 77, + 11, 35, 40, 33, 77, 41, 32, 77, 11, 35, 40, 33, + 45, 66, 79, 77, 77, 77, 77, 77, 77, 77, 77, 12, + 41, 41, 37, 77, 41, 40, 77, 39, 51, 77, 23, 41, + 44, 37, 45, 64, 77, 51, 31, 77, 13, 35, 33, 34, + 46, 51, 64, 77, 27, 40, 30, 77, 30, 31, 45, 42, + 27, 35, 44, 67, 79, 77, 77, 77, 77, 77, 77, 77, + 77, 14, 41, 46, 34, 35, 40, 33, 77, 28, 31, 45, + 35, 30, 31, 77, 44, 31, 39, 27, 35, 40, 45, 63, + 77, 18, 41, 47, 40, 30, 77, 46, 34, 31, 77, 30, + 31, 29, 27, 51, 79, 77, 77, 77, 77, 77, 77, 77, + 77, 15, 32, 77, 46, 34, 27, 46, 77, 29, 41, 38, + 41, 45, 45, 27, 38, 77, 23, 44, 31, 29, 37, 64, + 77, 28, 41, 47, 40, 30, 38, 31, 45, 45, 77, 27, + 40, 30, 77, 28, 27, 44, 31, 79, 77, 77, 77, 77, + 77, 77, 77, 77, 20, 34, 31, 77, 38, 41, 40, 31, + 77, 27, 40, 30, 77, 38, 31, 48, 31, 38, 77, 45, + 27, 40, 30, 45, 77, 45, 46, 44, 31, 46, 29, 34, + 77, 32, 27, 44, 77, 27, 49, 27, 51, 79, 77, 77, + 77, 77, 77, 77, 77, 77]]), + torch.tensor([[0, 0, 0, 1069, 11, -1, -1, -1, -1]]), + torch.tensor([[0, 0, 0, 1069, 11, -1, -1, -1, -1]]), + ] + # fmt: on + self.assertTrue(torch.allclose(tokens[0], EXPECTED_OUTPUT[0])) + self.assertTrue(torch.allclose(tokens[1], EXPECTED_OUTPUT[1])) + self.assertTrue(torch.allclose(tokens[2], EXPECTED_OUTPUT[2])) diff --git a/utils/check_repo.py b/utils/check_repo.py index 8b02185fa9bd5..632ca7af30ef1 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -51,6 +51,8 @@ "TableTransformerDecoder", # Building part of bigger (tested) model. "TimeSeriesTransformerEncoder", # Building part of bigger (tested) model. "TimeSeriesTransformerDecoder", # Building part of bigger (tested) model. + "JukeboxVQVAE", # Building part of bigger (tested) model. + "JukeboxPrior", # Building part of bigger (tested) model. "DeformableDetrEncoder", # Building part of bigger (tested) model. "DeformableDetrDecoder", # Building part of bigger (tested) model. "OPTDecoder", # Building part of bigger (tested) model. @@ -146,6 +148,8 @@ "CLIPSegTextModel", "EsmForProteinFolding", "TimeSeriesTransformerForPrediction", + "JukeboxVQVAE", + "JukeboxPrior", "PegasusXEncoder", "PegasusXDecoder", "PegasusXDecoderWrapper",