ntion used in Mega

Args:
    query (`torch.Tensor` of shape `(target_sequence_length, batch_size, hidden_size)`):
        The self (or target) sequence input used as query inputs for cross-attention
    key (`torch.Tensor` of shape `(source_sequence_length, batch_size, hidden_size)`):
        The cross (or source) sequence input used as keys in cross-attention
    value (`torch.Tensor` of shape `(source_sequence_length, batch_size, hidden_size)`):
        The cross (or source) sequence input used as values in cross-attention
    key_padding_mask (`torch.LongTensor` of shape `(batch_size, source_sequence_length)`, *optional*):
        Padding mask corresponding to the source sequence, where entries are 1 for *not masked* and 0 for
        *masked* tokens
    past_key_values (`tuple(torch.FloatTensor)`, *optional*):
        If provided, the hidden state returned from the previous timestep during incremental decoding; expects
        that prior cross-attention keys and values will be the last two items in the tuple
    output_attentions (`bool`, defaults to `False`):
        Whether or not to return the cross-attention weights.
    use_cache (`bool`, defaults to `False`):
        Whether to perform incremental decoding; uses `past_key_values` as the prior timestep, and returns the
        cross-attention key and value states for use in the next step

Returns:
    `tuple(torch.FloatTensor)` containing various elements depending on configuration ([`MegaConfig`]) and
    inputs:

    - **hidden_states** (`torch.FloatTensor` of shape `(target_sequence_length, batch_size, hidden_size)`) --
      Hidden states from target sequence updated by gated cross-attention
    - **attn_weights** (*optional*, returned when `output_attentions=True`) `torch.FloatTensor` of shape
      `(batch_size, source_sequence_length, target_sequence_length)` -- The pairwise cross-attention weights
      corresponding to each token in the source and target sequences
    - **cross_key** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape
      `(batch_size, source_sequence_length, config.shared_representation_size)` -- The cross-attention key
      state for use in the next step of incremental decoding
    - **cross_value** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape
      `(batch_size, source_sequence_length, config.hidden_size)` -- The cross-attention value state for use in
      the next step of incremental decoding
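
Because the returned tuple changes length with `output_attentions` and `use_cache`, it can help to label its elements by position. The sketch below is illustrative only: `name_cross_attention_outputs` is a hypothetical helper, not part of the library, and it assumes the elements appear in the same order as they are listed under Returns.

```python
from typing import Any, Dict, Tuple


def name_cross_attention_outputs(
    outputs: Tuple[Any, ...],
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Dict[str, Any]:
    """Hypothetical helper: label the elements of the returned tuple.

    Assumes the order documented under Returns: hidden_states first, then
    attn_weights (if requested), then cross_key and cross_value (if caching).
    """
    expected = 1 + (1 if output_attentions else 0) + (2 if use_cache else 0)
    if len(outputs) != expected:
        raise ValueError(f"expected {expected} elements, got {len(outputs)}")

    named = {"hidden_states": outputs[0]}
    index = 1
    if output_attentions:
        named["attn_weights"] = outputs[index]
        index += 1
    if use_cache:
        # The cached key and value are assumed to be the last two items.
        named["cross_key"] = outputs[index]
        named["cross_value"] = outputs[index + 1]
    return named
```

With both flags enabled, a call such as `name_cross_attention_outputs(outputs, output_attentions=True, use_cache=True)` would expect a four-element tuple and raise a `ValueError` otherwise.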