# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Wav2Vec2 checkpoint."""


import argparse
import os
from functools import reduce

import fairseq
import torch
from datasets import load_dataset

from transformers import Wav2Vec2Processor, logging
from transformers.models.data2vec.configuration_data2vec_audio import Data2VecAudioConfig

# Copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py
from transformers.models.data2vec.data2vec_audio import Data2VecAudioModel as Dummy  # noqa: F401
from transformers.models.data2vec.modeling_data2vec_audio import Data2VecAudioForCTC, Data2VecAudioModel


# Raise the transformers logging verbosity to INFO so that the per-tensor
# `logger.info(...)` messages emitted while copying weights are shown.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# fairseq parameter-name fragment -> HF module path (relative to the base model).
# In `recursively_load_weights`, a fairseq tensor name matches a key either when
# the key is a substring of the name, or when the key (minus any "w2v_model."
# prefix) equals the name's first dotted component. A "*" in the target path is
# replaced with the encoder layer index parsed from the fairseq name.
MAPPING = {
    # Feature projection.
    "post_extract_proj": "feature_projection.projection",
    "models.0.layer_norm": "feature_projection.layer_norm",
    # Per-layer transformer blocks ("*" = layer index).
    "self_attn.k_proj": "encoder.layers.*.attention.k_proj",
    "self_attn.v_proj": "encoder.layers.*.attention.v_proj",
    "self_attn.q_proj": "encoder.layers.*.attention.q_proj",
    "self_attn.out_proj": "encoder.layers.*.attention.out_proj",
    "self_attn_layer_norm": "encoder.layers.*.layer_norm",
    "fc1": "encoder.layers.*.feed_forward.intermediate_dense",
    "fc2": "encoder.layers.*.feed_forward.output_dense",
    "final_layer_norm": "encoder.layers.*.final_layer_norm",
    # Encoder output norm, fine-tuned-checkpoint names, and the mask embedding.
    "encoder.layer_norm": "encoder.layer_norm",
    "w2v_model.layer_norm": "feature_projection.layer_norm",
    "w2v_encoder.proj": "lm_head",
    "mask_emb": "masked_spec_embed",
}
# Target paths that live on the head model itself, so they are NOT prefixed
# with "data2vec_audio." when converting into a model with a head.
TOP_LEVEL_KEYS = ["lm_head"]


def set_recursively(hf_pointer, key, value, full_name, weight_type):
    """Overwrite the HF tensor addressed by ``key`` (and ``weight_type``) with ``value``.

    Args:
        hf_pointer: root object from which ``key`` is resolved (the HF model).
        key: dotted attribute path below ``hf_pointer``. It names a submodule, or
            the tensor itself when ``weight_type`` is None.
        value: tensor to install; its shape must equal the target's shape.
        full_name: name of the source (fairseq) parameter, used only in messages.
        weight_type: attribute of the resolved object to overwrite (the caller
            passes "weight", "weight_g", "weight_v" or "bias"), or None to
            overwrite the resolved object's ``.data`` directly.

    Raises:
        ValueError: if the target's shape differs from ``value.shape``.
    """
    for attribute in key.split("."):
        hf_pointer = getattr(hf_pointer, attribute)

    # Resolve the destination once so the shape check and the assignment
    # always address the same tensor.
    target = hf_pointer if weight_type is None else getattr(hf_pointer, weight_type)
    # Human-readable destination name; keeps `key` even when weight_type is None.
    target_name = key if weight_type is None else f"{key}.{weight_type}"

    if target.shape != value.shape:
        raise ValueError(f"Shape of hf {target_name} is {target.shape}, but should be {value.shape} for {full_name}")

    target.data = value

    logger.info(f"{target_name} was initialized from {full_name}.")


def recursively_load_weights(fairseq_model, hf_model, is_headless):
    """Copy every tensor of a fairseq data2vec-audio model into ``hf_model``.

    Args:
        fairseq_model: loaded fairseq model; its ``state_dict()`` is the source.
        hf_model: destination model. When ``is_headless`` is True it is the base
            model itself; otherwise the base model is reached via
            ``hf_model.data2vec_audio``.
        is_headless: whether ``hf_model`` is the bare base model. For a model
            with a head, every mapped path not listed in ``TOP_LEVEL_KEYS`` is
            prefixed with ``"data2vec_audio."``.

    Feature-extractor tensors ("conv_layers") and positional-convolution tensors
    ("pos_conv") are delegated to dedicated loaders. Every other tensor is
    matched against ``MAPPING``, and each matching rule copies it. Tensors
    that no rule claims are collected and reported in a warning at the end.
    """
    unused_weights = []
    fairseq_dict = fairseq_model.state_dict()

    base_model = hf_model if is_headless else hf_model.data2vec_audio
    feature_extractor = base_model.feature_extractor
    pos_conv_embedding = base_model.encoder.pos_conv_embed

    for name, value in fairseq_dict.items():
        if "conv_layers" in name:
            load_conv_layer(name, value, feature_extractor, unused_weights)
            continue
        if "pos_conv" in name:
            load_pos_conv_layer(name, value, pos_conv_embedding, unused_weights)
            continue

        is_used = False
        # Every matching MAPPING entry is applied (there is no early exit).
        for key, mapped_key in MAPPING.items():
            if not is_headless and mapped_key not in TOP_LEVEL_KEYS:
                mapped_key = "data2vec_audio." + mapped_key
            # A key matches as a substring of the name, or, with any "w2v_model."
            # prefix stripped, as the name's first dotted component.
            if key not in name and key.split("w2v_model.")[-1] != name.split(".")[0]:
                continue

            is_used = True
            if "*" in mapped_key:
                # e.g. "encoder.layers.3.self_attn.k_proj.weight" -> "3"
                layer_index = name.split(key)[0].split(".")[-2]
                mapped_key = mapped_key.replace("*", layer_index)

            if "weight_g" in name:
                weight_type = "weight_g"
            elif "weight_v" in name:
                weight_type = "weight_v"
            elif "bias" in name:
                weight_type = "bias"
            elif "weight" in name:
                # TODO: don't match quantizer.weight_proj
                weight_type = "weight"
            else:
                weight_type = None
            set_recursively(hf_model, mapped_key, value, name, weight_type)

        if not is_used:
            unused_weights.append(name)

    logger.warning(f"Unused weights: {unused_weights}")


def access_by_string(module, path):
    """Follow the dot-separated attribute ``path`` starting at ``module``.

    ``access_by_string(m, "a.b.c")`` is equivalent to ``m.a.b.c``. An
    ``AttributeError`` propagates if any step of the path is missing.
    """
    current = module
    for attr_name in path.split("."):
        current = getattr(current, attr_name)
    return current


def set_weights(full_name, module, fsq_value, hf_weight_path):
    """Overwrite the HF tensor at ``module.<hf_weight_path>`` with a fairseq tensor.

    Args:
        full_name: name of the source fairseq parameter (used in messages).
        module: HF module from which ``hf_weight_path`` is resolved.
        fsq_value: source tensor taken from the fairseq state dict.
        hf_weight_path: dotted attribute path of the destination tensor.

    Raises:
        ValueError: if the source and destination shapes differ.
    """
    hf_weight = access_by_string(module, hf_weight_path)
    hf_value = hf_weight.data

    if fsq_value.shape != hf_value.shape:
        raise ValueError(f"{full_name} has size {fsq_value.shape}, but {hf_value.shape} was found.")
    hf_weight.data = fsq_value
    # The HF tensor is the one being initialized, from the fairseq tensor.
    logger.info(f"{hf_weight_path} was correctly initialized from {full_name}.")


def load_conv_layer(full_name, value, feature_extractor, unused_weights):
    """Copy one fairseq feature-extractor tensor into the HF feature extractor.

    The part of ``full_name`` after "conv_layers." is read as
    ``<layer_id>.<type_id>....<param>``. Sub-module index 0 maps to the HF
    ``conv`` and index 2 to ``layer_norm``. Any other index is recorded in
    ``unused_weights`` and nothing is copied.
    """
    suffix = full_name.split("conv_layers.")[-1]
    parts = suffix.split(".")
    layer_id = int(parts[0])
    type_id = int(parts[1])
    param_name = parts[-1]

    layer_type = {0: "conv", 2: "layer_norm"}.get(type_id)
    if layer_type is None:
        unused_weights.append(full_name)
        return

    set_weights(full_name, feature_extractor, value, f"conv_layers.{layer_id}.{layer_type}.{param_name}")


def load_pos_conv_layer(full_name, value, pos_conv_embeddings, unused_weights):
    """Copy one fairseq positional-convolution tensor into the HF pos-conv embedding.

    The part of ``full_name`` after "pos_conv." is read as
    ``<layer_id>.<type_id>....<param>``. Only sub-module index 0 (the
    convolution) is copied, to ``layers.<layer_id>.conv.<param>``. Every other
    index is recorded in ``unused_weights``.
    """
    suffix = full_name.split("pos_conv.")[-1]
    pieces = suffix.split(".")
    layer_id = int(pieces[0])
    type_id = int(pieces[1])
    param_name = pieces[-1]

    if type_id == 0:
        set_weights(full_name, pos_conv_embeddings, value, f"layers.{layer_id}.conv.{param_name}")
    else:
        unused_weights.append(full_name)


@torch.no_grad()
def convert_wav2vec2_checkpoint(
    checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
):
    """
    Copy/paste/tweak model's weights to transformers design.

    Loads a fairseq data2vec-audio checkpoint and copies its weights into a
    freshly initialized HF model. It then checks that both models give matching
    outputs on four LibriSpeech dummy samples, and saves the HF model to
    ``pytorch_dump_folder_path`` along with the processor (fine-tuned) or only
    its feature extractor (pretrained).

    Args:
        checkpoint_path: path to the fairseq checkpoint file.
        pytorch_dump_folder_path: output directory passed to ``save_pretrained``.
        config_path: optional argument for ``Data2VecAudioConfig.from_pretrained``;
            a default ``Data2VecAudioConfig()`` is used when None.
        dict_path: accepted for CLI compatibility; not used by this function.
        is_finetuned: True converts into ``Data2VecAudioForCTC``; False converts
            into the headless ``Data2VecAudioModel``. In the False case,
            "final_proj.0.*" keys are renamed and the result is written as
            "converted.pt" next to ``checkpoint_path``.

    Raises:
        Exception: if the HF and fairseq outputs are not ``allclose`` at ``atol=1e-3``.
    """
    if config_path is not None:
        config = Data2VecAudioConfig.from_pretrained(config_path)
    else:
        config = Data2VecAudioConfig()

    if not is_finetuned:
        # Modify final_proj layer name
        hf_wav2vec = Data2VecAudioModel(config)
        data2vec_checkpoint_dir = os.path.dirname(checkpoint_path)

        # Load onto CPU: the checkpoint may have been saved from GPU tensors,
        # and everything below (including the output comparison) runs on CPU.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        state_dict["model"]["final_proj.weight"] = state_dict["model"].pop("final_proj.0.weight")
        state_dict["model"]["final_proj.bias"] = state_dict["model"].pop("final_proj.0.bias")
        converted_ckpt = os.path.join(data2vec_checkpoint_dir, "converted.pt")
        torch.save(state_dict, converted_ckpt)
    else:
        hf_wav2vec = Data2VecAudioForCTC(config)
        converted_ckpt = checkpoint_path

    def load_data2vec(path):
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([path])
        return model[0].eval()

    model = load_data2vec(converted_ckpt)

    recursively_load_weights(model, hf_wav2vec, not is_finetuned)

    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-lv60")

    ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
    input_audio = [x["array"] for x in ds[:4]["audio"]]

    inputs = processor(input_audio, return_tensors="pt", padding=True)

    input_values = inputs.input_values
    attention_mask = inputs.attention_mask
    # fairseq receives a padding mask, passed as the complement of the HF attention mask.
    padding_mask = 1 - attention_mask

    hf_wav2vec.eval()
    model.eval()
    if is_finetuned:
        their_output = model(source=input_values, padding_mask=padding_mask, mask=False, features_only=True)[
            "encoder_out"
        ].transpose(0, 1)
        our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["logits"]

        pred_ids = torch.argmax(our_output, dim=-1)
        output_string = processor.batch_decode(pred_ids)

        print(f"Expected Output: {ds[:4]['text']}, Pred: {output_string}")
    else:
        their_output = model(source=input_values, padding_mask=padding_mask, mask=False, features_only=True)[
            "layer_results"
        ][-1][0].transpose(0, 1)
        our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["last_hidden_state"]

    print(our_output.shape, their_output.shape)
    max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
    print(f"max_absolute_diff = {max_absolute_diff}")
    # Conversion is accepted when every element agrees within atol=1e-3.
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)

    if is_finetuned:
        processor.save_pretrained(pytorch_dump_folder_path)
    else:
        processor.feature_extractor.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    # Command-line entry point: parse paths/flags and run the conversion.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
    parser.add_argument(
        "--dict_path",
        default=None,
        type=str,
        help="Path to dict of fine-tuned model (currently unused by the conversion)",
    )
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    parser.add_argument(
        "--not_finetuned",
        action="store_true",
        help=(
            "Set this flag if the checkpoint is a pretrained (not fine-tuned) model; "
            "by default it is converted as a CTC fine-tuned model"
        ),
    )
    args = parser.parse_args()
    convert_wav2vec2_checkpoint(
        args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
    )
