# PLBart

**DISCLAIMER:** If you see something strange, file a GitHub Issue and assign @gchhablani.

## Overview of PLBart

The PLBART model was proposed in Unified Pre-training for Program Understanding and Generation by Wasi Uddin Ahmad, Saikat Chakraborty, Baishakhi Ray, and Kai-Wei Chang. This is a BART-like model that can be used to perform code summarization, code generation, and code translation tasks. The pre-trained model `plbart-base` has been trained using a multilingual denoising task on Java, Python, and English.

According to the abstract:

Code summarization and generation empower conversion between programming language (PL) and natural language (NL), while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART, a sequence-to-sequence model capable of performing a broad spectrum of program and language understanding and generation tasks. PLBART is pre-trained on an extensive collection of Java and Python functions and associated NL text via denoising autoencoding. Experiments on code summarization in the English language, code generation, and code translation in seven programming languages show that PLBART outperforms or rivals state-of-the-art models. Moreover, experiments on discriminative tasks, e.g., program repair, clone detection, and vulnerable code detection, demonstrate PLBART's effectiveness in program understanding. Furthermore, analysis reveals that PLBART learns program syntax, style (e.g., identifier naming convention), logical flow (e.g., if block inside an else block is equivalent to else if block) that are crucial to program semantics and thus excels even with limited annotations.

This model was contributed by gchhablani. The authors' code can be found here.

## Training of PLBart

PLBart is a multilingual encoder-decoder (sequence-to-sequence) model primarily intended for code-to-text, text-to-code, and code-to-code tasks. Because the model is multilingual, it expects sequences in a specific format: a special language id token is added to both the source and target text. The source text format is `X [eos, src_lang_code]`, where `X` is the source text. The target text format is `[tgt_lang_code] X [eos]`. `bos` is never used.

However, for fine-tuning, the language token is sometimes omitted when only a single language is used. Please refer to the paper to learn more about this.

In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] encodes the source text format, and it should be wrapped inside the context manager [`~PLBartTokenizer.as_target_tokenizer`] to encode the target text format.
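As an illustration, here is a minimal sketch (not part of the original examples) that inspects where the special tokens end up, assuming the `uclanlp/plbart-base` checkpoint used in the examples below:

```python
from transformers import PLBartTokenizer

# Assumption: the multilingual `uclanlp/plbart-base` checkpoint, as in the examples below.
tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-base", src_lang="python", tgt_lang="en_XX")

# Source text: per the format above, the end-of-sequence token and the source
# language code are appended after the tokens of X.
source = tokenizer("def add(a, b): return a + b")
print(tokenizer.convert_ids_to_tokens(source["input_ids"]))

# Target text: encoded inside the `as_target_tokenizer` context manager so that the
# target language code (rather than the source one) is used as the special token.
with tokenizer.as_target_tokenizer():
    target = tokenizer("Add two numbers.")
print(tokenizer.convert_ids_to_tokens(target["input_ids"]))
```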

- Supervised training (a sketch of a full fine-tuning step follows the generation example below)

```python
>>> from transformers import PLBartForConditionalGeneration, PLBartTokenizer

>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-base", src_lang="python", tgt_lang="en_XX")
>>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
>>> expected_translation_english = "Returns the maximum value of a b c."
>>> inputs = tokenizer(example_python_phrase, return_tensors="pt")
>>> with tokenizer.as_target_tokenizer():
...     labels = tokenizer(expected_translation_english, return_tensors="pt")
>>> inputs["labels"] = labels["input_ids"]
>>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")
>>> # forward pass
>>> model(**inputs)
```
- Generation

  While generating the target text, set `decoder_start_token_id` to the target language id. The following example shows how to translate Python to English using the `uclanlp/plbart-python-en_XX` model.

```python
>>> from transformers import PLBartForConditionalGeneration, PLBartTokenizer

>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
>>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
>>> inputs = tokenizer(example_python_phrase, return_tensors="pt")
>>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-python-en_XX")
>>> translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id["en_XX"])
>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
"Returns the maximum value of a b c."
```

## PLBartConfig

[[autodoc]] PLBartConfig

## PLBartTokenizer

[[autodoc]] PLBartTokenizer
    - as_target_tokenizer
    - build_inputs_with_special_tokens

## PLBartModel

[[autodoc]] PLBartModel
    - forward

## PLBartForConditionalGeneration

[[autodoc]] PLBartForConditionalGeneration
    - forward

## PLBartForSequenceClassification

[[autodoc]] PLBartForSequenceClassification
    - forward

## PLBartForCausalLM

[[autodoc]] PLBartForCausalLM
    - forward