mly initialized
>>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")

>>> tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

>>> text = "My friends are cool but they eat too many carbs."
>>> input_ids = tokenizer.encode(text, max_length=1024, return_tensors="np")
>>> encoder_outputs = model.encode(input_ids)

>>> decoder_start_token_id = model.config.decoder.bos_token_id
>>> decoder_input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
>>> logits = outputs.logits
```
Nr