python
>>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration

>>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large')
>>> tokenizer = AutoTokenizer.from_pretrained('google/pegasus-large')

>>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np')

>>> # Generate Summary
>>> summary_ids = model.generate(inputs['input_ids']).sequences
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
```

Mask filling example:

```python
>>> import jax
>>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration

>>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
>>> TXT = f"My friends are {tokenizer.mask_token} but they eat too many carbs."

>>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
>>> input_ids = tokenizer([TXT], return_tensors="np")["input_ids"]
>>> logits = model(input_ids).logits

>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
>>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
>>> values, predictions = jax.lax.top_k(probs, k=5)

>>> tokenizer.decode(predictions).split()
``` r