    value vectors from global tokens
local_q (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]:
    query vectors from local tokens
local_k (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]:
    key vectors from local tokens
local_v (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]:
    value vectors from local tokens
mask (`torch.FloatTensor`) of shape [batch_size, padded_seq_len]: attention mask
dim (DimensionInfo): DimensionInfo wrapper for dimensions

Returns:
    output of shape `[batch_size, length, features]`, where length will be padded to a multiple of block_size r