ch.Tensor` of shape `(sequence_length, batch_size, hidden_size)`):
        Hidden states to be updated by Mega's self-attention
    padding_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked*
        or 0 for *masked*
    causal_mask (`torch.LongTensor` of shape `(sequence_length, sequence_length)`, *optional*):
        Indicates which inputs are to be ignored due to causal attention, where elements are either 1 for *not
        masked* or 0 for *masked*
    past_key_values (`tuple(torch.Tensor)`, *optional*):
        The hidden states returned from the previous timestep during incremental decoding; expects that
        self-attention key, value, and EMA states are the first 3 entries in the tuple
    output_attentions (`bool`, default `False`):
        Whether to return self-attention weights
    use_cache (`bool`, default `False`):
        Whether to perform incremental decoding; uses `past_key_values` as prior state, and returns the updated
        states for use in the next step

Returns:
    `tuple(torch.FloatTensor)` containing various elements depending on configuration ([`MegaConfig`]) and
    inputs:
    - **hidden_states** (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`) -- Hidden
      states from target sequence updated by Mega's self-attention
    - **attn_weights** (*optional*, returned when `output_attentions=True`) `torch.FloatTensor` of shape
      `(batch_size, 1, sequence_length, sequence_length)` -- The self-attention weights corresponding to how
      each token in the input sequence attends to every other token
    - **self_key** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size,
      sequence_length, config.shared_representation_size)` -- The self-attention key state for use in the next
      step of incremental decoding
    - **self_value** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size,
      sequence_length, config.hidden_size)` -- The self-attention value state for use in the next step of
      incremental decoding
    - **self_ema_state** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape
      `(batch_size, config.ndim)` -- The incremental EMA state for use in the next step of incremental
      decoding.

Input embedding dimension should be
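
The return layout described under `Returns:` can be sketched as a small helper that names the tuple entries according to the two flags. This is an illustrative sketch only: `label_mega_outputs` is a hypothetical helper, not part of the library, and it assumes the optional entries appear in the order the docstring lists them and are simply omitted when their flag is off.

```python
def label_mega_outputs(outputs, output_attentions=False, use_cache=False):
    """Map a returned tuple to the names used in the docstring.

    Hypothetical helper: assumes the documented order, with optional
    entries left out when the corresponding flag is False.
    """
    names = ["hidden_states"]
    if output_attentions:
        # Present only when output_attentions=True
        names.append("attn_weights")
    if use_cache:
        # Present only when use_cache=True (incremental decoding state)
        names.extend(["self_key", "self_value", "self_ema_state"])

    if len(outputs) != len(names):
        raise ValueError(
            f"expected {len(names)} entries for these flags, got {len(outputs)}"
        )
    return dict(zip(names, outputs))


# Usage with placeholder values standing in for tensors
labeled = label_mega_outputs(("h", "k", "v", "ema"), use_cache=True)
print(list(labeled))
```

Naming the entries once, right where the tuple is received, avoids relying on positional indices that shift when `output_attentions` or `use_cache` changes.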