# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import re

import requests
import torch

# The `models.*` imports below come from the original BLIP repository, which must be importable:
#   git clone https://github.com/salesforce/BLIP.git
from models.blip import blip_decoder
from models.blip_itm import blip_itm
from models.blip_vqa import blip_vqa
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from transformers import (
    BertTokenizer,
    BlipConfig,
    BlipForConditionalGeneration,
    BlipForImageTextRetrieval,
    BlipForQuestionAnswering,
)


def load_demo_image(image_size, device):
    """
    Download the BLIP demo image and preprocess it into a single-image batch.

    Args:
        image_size (`int`):
            Height and width the image is resized to (bicubic interpolation).
        device:
            Target passed to `Tensor.to` for the resulting tensor.

    Returns:
        The resized, normalized image tensor with a leading batch dimension of size 1, placed on `device`.

    Raises:
        `requests.HTTPError`: If the server answers with an error status code.
        `requests.Timeout`: If the server does not respond within the timeout.
    """
    img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"

    # A timeout keeps a stalled connection from hanging the conversion forever, and the context manager
    # releases the streamed connection once the image has been read.
    with requests.get(img_url, stream=True, timeout=30) as response:
        # Fail with a clear HTTP error instead of letting PIL choke on an error page.
        response.raise_for_status()
        # `convert` is called while the stream is still open.
        raw_image = Image.open(response.raw).convert("RGB")

    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]
    )
    # unsqueeze(0) adds the batch dimension.
    image = transform(raw_image).unsqueeze(0).to(device)
    return image


def rename_key(key):
    """
    Translate a parameter name from the original BLIP checkpoint into the transformers naming scheme.

    The substitutions are applied in the order listed below, each one operating on the result of the previous
    one, and every occurrence of a pattern is replaced. Patterns are matched as literal substrings (not as
    regular expressions), so characters such as `.` only ever match themselves.

    Args:
        key (`str`):
            Parameter name as found in the original state dict.

    Returns:
        `str`: The renamed key. A key that contains none of the patterns is returned unchanged.
    """
    # Order matters: e.g. "visual_encoder" must become "vision_model.encoder" before the "encoder.*"
    # patterns can match, and "attn" must become "self_attn" before "self_attn.proj" is rewritten.
    replacements = (
        ("visual_encoder", "vision_model.encoder"),
        ("blocks", "layers"),
        ("attn", "self_attn"),
        ("norm1", "layer_norm1"),
        ("norm2", "layer_norm2"),
        ("encoder.norm", "post_layernorm"),
        ("encoder.patch_embed.proj", "embeddings.patch_embedding"),
        ("encoder.pos_embed", "embeddings.position_embedding"),
        ("encoder.cls_token", "embeddings.class_embedding"),
        ("self_attn.proj", "self_attn.projection"),
    )
    for old, new in replacements:
        key = key.replace(old, new)

    return key


def _rename_state_dict_keys(state_dict):
    """Rename every key of `state_dict` in place with `rename_key` and return the same mapping."""
    # Iterate over a snapshot of the keys because the mapping is mutated inside the loop.
    for key in list(state_dict):
        state_dict[rename_key(key)] = state_dict.pop(key)
    return state_dict


@torch.no_grad()
def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak model's weights to transformers design.

    Three original BLIP checkpoints (captioning, VQA and image-text matching) are loaded from their fixed
    URLs, their parameter names are translated with `rename_key`, and the weights are loaded into the
    corresponding transformers models. Each converted model is run on the demo image and its output is
    compared against a hard-coded reference before it is saved.

    Args:
        pytorch_dump_folder_path (`str` or `None`):
            Folder the captioning model is saved to. The VQA and ITM models are saved to the same path with
            the suffixes `"_vqa"` and `"_itm"`. Nothing is saved when `None`.
        config_path (`str`, *optional*):
            Passed to `BlipConfig.from_pretrained`. When `None`, a `BlipConfig` with `projection_dim=512`
            and empty text/vision sub-configs is used.

    Raises:
        `AssertionError`: If a converted model does not reproduce the reference output.
    """
    if config_path is not None:
        config = BlipConfig.from_pretrained(config_path)
    else:
        config = BlipConfig(projection_dim=512, text_config={}, vision_config={})

    image_size = 384
    image = load_demo_image(image_size=image_size, device="cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # ---- Image captioning ----
    model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"

    pt_model = blip_decoder(pretrained=model_url, image_size=image_size, vit="base")
    pt_model = pt_model.eval()

    hf_model = BlipForConditionalGeneration(config).eval()
    hf_model.load_state_dict(_rename_state_dict_keys(pt_model.state_dict()))

    input_ids = tokenizer(["a picture of"]).input_ids

    # Generation with a text prompt.
    out = hf_model.generate(image, input_ids)
    assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]

    # Generation from the image alone.
    out = hf_model.generate(image)
    assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]

    if pytorch_dump_folder_path is not None:
        hf_model.save_pretrained(pytorch_dump_folder_path)

    # ---- Visual question answering ----
    model_url = (
        "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth"
    )

    vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base")
    vqa_model.eval()

    hf_vqa_model = BlipForQuestionAnswering(config)
    hf_vqa_model.load_state_dict(_rename_state_dict_keys(vqa_model.state_dict()))
    # Switch to eval mode before running the check, like the other two converted models.
    hf_vqa_model.eval()

    question = ["How many dogs are in this image?"]
    question_input_ids = tokenizer(question, return_tensors="pt").input_ids

    answer = hf_vqa_model.generate(question_input_ids, image)
    decoded_answer = tokenizer.decode(answer[0])
    print(decoded_answer)

    assert decoded_answer == "[UNK] 1 [SEP]"
    if pytorch_dump_folder_path is not None:
        hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa")

    # ---- Image-text matching / retrieval ----
    model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth"

    itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base")
    itm_model.eval()

    hf_itm_model = BlipForImageTextRetrieval(config)
    hf_itm_model.load_state_dict(_rename_state_dict_keys(itm_model.state_dict()))
    hf_itm_model.eval()

    question = ["A picture of a woman with a dog sitting in a beach"]
    question_input_ids = tokenizer(
        question,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=35,
    ).input_ids

    out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True)
    out = hf_itm_model(question_input_ids, image, use_itm_head=False)

    # Compare floats with a tolerance: bit-exact results are not stable across hardware / library versions.
    tolerance = 1e-4
    assert abs(out[0].item() - 0.2110687494277954) < tolerance
    assert abs(torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() - 0.45698845386505127) < tolerance

    if pytorch_dump_folder_path is not None:
        hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm")


if __name__ == "__main__":
    # Command-line entry point: parse the output folder and optional config path, then run the conversion.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    args = parser.parse_args()

    # `convert_blip_checkpoint` takes only the dump folder and the config path, and the parser defines no
    # `--checkpoint_path` option, so only these two arguments are forwarded.
    convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)
