# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script converts a Longformer model from huggingface transformers 4.0 or later to ONNX.
# It translates LongformerSelfAttention to the LongformerAttention operator in ONNX Runtime.
#
# Before running this script, prepare a python environment in Linux with PyTorch 1.9.0 and other packages installed.
# Then run "python setup.py install" in the ./torch_extensions directory. If your python version is not 3.8, you will
# need to update this script with the correct name of longformer_attention.cpython-*.so (search TODO below).
#
# It is tested on Ubuntu 18.04 with python 3.8, onnxruntime-gpu 1.11.0, PyTorch 1.9.0, transformers 4.18.0.
# Warning: PyTorch 1.10 or newer versions might encounter issues in exporting, but those versions are fine for benchmarking.
#
# Example commands to export the longformer base model in Linux:
#   conda create -n longformer python=3.8
#   conda activate longformer
#   python3 -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
#   python3 -m pip install coloredlogs flatbuffers numpy packaging sympy protobuf==3.20.1 onnx==1.12.0 transformers==4.18.0
#   python3 -m pip install -i https://test.pypi.org/simple/ ort-nightly-gpu
#   cd ./torch_extensions
#   rm -rf build
#   python setup.py install
#   cd ..
#   python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx
#   python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx --no_merge_qkv
#
# GPU is not needed for this script. You can run it on CPU. For --optimize_onnx, either the onnxruntime or the
# onnxruntime-gpu package works.
#
# For inference of the onnx model, you will need onnxruntime-gpu 1.7.0 or a newer version.
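#
# Example of inference with the exported model (illustrative sketch, not part of this tool; it assumes
# onnxruntime-gpu with the com.microsoft LongformerAttention kernel is installed, and that the feeds are
# int64 numpy arrays shaped like the dummy inputs created below):
#   import onnxruntime
#   session = onnxruntime.InferenceSession("longformer-base-4096.onnx", providers=["CUDAExecutionProvider"])
#   last_state, pooler = session.run(None, {"input_ids": input_ids, "attention_mask": attention_mask,
#                                           "global_attention_mask": global_attention_mask})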
import argparse
import inspect
import os
import sys
from pathlib import Path

import torch
import transformers
from longformer_helper import PRETRAINED_LONGFORMER_MODELS
from onnx import load_model
from packaging import version
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args
from transformers import LongformerModel, LongformerSelfAttention

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from onnx_model_bert import BertOnnxModel  # noqa: E402
from torch_onnx_export_helper import torch_onnx_export  # noqa: E402

# Supports format 0 or 1
weight_bias_format = 0


@parse_args("v", "v", "v", "v", "v", "v", "v", "i", "i")
def my_longformer_attention(
    g,
    input,
    weight,
    bias,
    mask,
    global_weight,
    global_bias,
    global_mask,
    num_heads,
    window,
):
    return g.op(
        "com.microsoft::LongformerAttention",
        input,
        weight,
        bias,
        mask,
        global_weight,
        global_bias,
        global_mask,
        num_heads_i=num_heads,
        window_i=window,
    )


# The namespace is onnxruntime, which is registered in longformer_attention.cpp.
register_custom_op_symbolic("onnxruntime::LongformerAttention", my_longformer_attention, 9)

# TODO: search the directory to find the correct output filename of "python setup.py install" when python version is not 3.8.
torch.ops.load_library(
    r"./torch_extensions/build/lib.linux-x86_64-3.8/longformer_attention.cpython-38-x86_64-linux-gnu.so"
)


def parse_arguments():
    """Parse arguments

    Returns:
        args: Namespace
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="longformer-base-4096",
        help="Checkpoint directory or pre-trained model names in the list: "
        + ", ".join(PRETRAINED_LONGFORMER_MODELS.keys()),
    )

    parser.add_argument(
        "--export_padding",
        required=False,
        action="store_true",
        help="Export padding logic to the ONNX graph. If not enabled, the user needs to pad inputs so that "
        "sequence length is a multiple of the window size.",
    )
    parser.set_defaults(export_padding=False)

    parser.add_argument(
        "--no_merge_qkv",
        required=False,
        action="store_true",
        help="Stack the weights of q, k and v on dimension 0 instead of dimension 1.",
    )
    parser.set_defaults(no_merge_qkv=False)

    parser.add_argument(
        "-o",
        "--optimize_onnx",
        required=False,
        action="store_true",
        help="Use optimizer.py to optimize the onnx model.",
    )
    parser.set_defaults(optimize_onnx=False)

    parser.add_argument(
        "-p",
        "--precision",
        required=False,
        type=str,
        default="fp32",
        choices=["fp32", "fp16"],
        help="Precision of model to run: fp32 for full precision, fp16 for mixed precision",
    )

    return parser.parse_args()
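# Illustration (a sketch, not used by the export path) of the two weight layouts that weight_bias_format
# selects. Format 1 merges the Q/K/V projections into a single (hidden_size, 3*hidden_size) matrix so one
# GEMM computes all three; format 0 (--no_merge_qkv) keeps a (3, hidden_size, hidden_size) stack.
# The helper name below is hypothetical and exists only to document the layouts.
def illustrate_weight_formats(hidden_size=768):
    q = torch.randn(hidden_size, hidden_size)
    k = torch.randn(hidden_size, hidden_size)
    v = torch.randn(hidden_size, hidden_size)
    format_0 = torch.stack((q, k, v), dim=0)  # format 0: (3, hidden_size, hidden_size)
    format_1 = torch.stack((q, k, v), dim=1).reshape(hidden_size, 3 * hidden_size)  # format 1: (hidden_size, 3*hidden_size)
    return format_0.shape, format_1.shape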
# Create a dummy input for ONNX export.
def get_dummy_inputs(config, export_padding, device):
    # When sequence length is a multiple of the window size, there is no padding logic in the ONNX graph.
    sequence_length = config.attention_window[0] + 1 if export_padding else config.attention_window[0]

    # Create dummy inputs
    input_ids = torch.arange(sequence_length).unsqueeze(0).to(device)
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
    attention_mask[:, sequence_length - 1] = 0  # last token is masked
    global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=device)
    global_attention_mask[:, 0] = 1  # first token is global token

    return input_ids, attention_mask, global_attention_mask


# A new function to replace LongformerSelfAttention.forward.
# For transformers 4.0.0
def my_longformer_self_attention_forward_4(
    self,
    hidden_states,
    attention_mask=None,
    is_index_masked=None,
    is_index_global_attn=None,
    is_global_attn=None,
):
    global_mask = is_index_global_attn.int()

    # The following check is based on the dummy inputs (only the first token is global).
    assert (
        len(global_mask.shape) == 2
        and global_mask.shape[0] == 1
        and global_mask.count_nonzero().item() == 1
        and global_mask.tolist()[0][0] == 1
    )

    input_mask = is_index_masked.float()
    # TODO: The filtering value may be -10000.0 or -inf. Check the huggingface implementation.
    input_mask = input_mask.masked_fill(is_index_masked, -10000.0)
    # Yet another way to generate it: input_mask = torch.masked_fill(attention_mask, is_index_global_attn, 0.0)

    # TODO: add postprocessing of ONNX model to calculate based on graph input: input_mask = (attention_mask - 1) * 10000.0
    # TODO: add postprocessing of ONNX model to use graph input directly: global_mask = global_attention_mask
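    # Note (added for clarity): at this point input_mask is an additive attention bias (0.0 for visible
    # tokens, -10000.0 for masked ones), and global_mask marks global-attention tokens with 1 and all
    # others with 0. The LongformerAttention custom op invoked below consumes these two conventions.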
    # The following check is based on the dummy inputs (only the last token is masked).
    assert (
        len(input_mask.shape) == 2
        and input_mask.shape[0] == 1
        and input_mask.count_nonzero().item() == 1
        and input_mask.tolist()[0][-1] == -10000.0
    )

    weight = torch.stack(
        (
            self.query.weight.transpose(0, 1),
            self.key.weight.transpose(0, 1),
            self.value.weight.transpose(0, 1),
        ),
        dim=weight_bias_format,
    )

    if weight_bias_format == 1:
        # Shape is (hidden_size, 3*hidden_size) for format 1, otherwise (3, hidden_size, hidden_size) by default.
        weight = weight.reshape(self.embed_dim, 3 * self.embed_dim)

    global_weight = torch.stack(
        (
            self.query_global.weight.transpose(0, 1),
            self.key_global.weight.transpose(0, 1),
            self.value_global.weight.transpose(0, 1),
        ),
        dim=weight_bias_format,
    )

    if weight_bias_format == 1:
        global_weight = global_weight.reshape(self.embed_dim, 3 * self.embed_dim)

    if weight_bias_format == 1:
        bias = torch.stack((self.query.bias, self.key.bias, self.value.bias), dim=0)
        bias = bias.reshape(3 * self.embed_dim)

        global_bias = torch.stack((self.query_global.bias, self.key_global.bias, self.value_global.bias), dim=0)
        global_bias = global_bias.reshape(3 * self.embed_dim)
    else:
        bias = torch.stack(
            (self.query.bias, self.key.bias, self.value.bias, self.key_global.bias, self.value_global.bias), dim=0
        )
        bias = bias.reshape(5 * self.embed_dim)

        global_bias = self.query_global.bias
        global_bias = global_bias.reshape(1 * self.embed_dim)

    attn_output = torch.ops.onnxruntime.LongformerAttention(
        hidden_states,
        weight,
        bias,
        input_mask,
        global_weight,
        global_bias,
        global_mask,
        self.num_heads,
        self.one_sided_attn_window_size,
    )

    assert attn_output.size() == hidden_states.size(), "Unexpected size"

    outputs = (attn_output,)
    return outputs


# For transformers 4.3.0
def my_longformer_self_attention_forward_4_3(
    self,
    hidden_states,
    attention_mask=None,
    is_index_masked=None,
    is_index_global_attn=None,
    is_global_attn=None,
    output_attentions=False,
):
    assert output_attentions is False
    return my_longformer_self_attention_forward_4(
        self,
        hidden_states,
        attention_mask,
        is_index_masked,
        is_index_global_attn,
        is_global_attn,
    )


# For transformers 4.3.2 or later versions
def my_longformer_self_attention_forward_4_3_2(
    self,
    hidden_states,
    attention_mask=None,
    layer_head_mask=None,
    is_index_masked=None,
    is_index_global_attn=None,
    is_global_attn=None,
    output_attentions=False,
):
    assert output_attentions is False
    assert layer_head_mask is None

    return my_longformer_self_attention_forward_4(
        self,
        hidden_states,
        attention_mask,
        is_index_masked,
        is_index_global_attn,
        is_global_attn,
    )


def export_longformer(model: LongformerModel, onnx_model_path: str, export_padding: bool):
    """Export longformer model to ONNX

    Args:
        model (LongformerModel): longformer model
        onnx_model_path (str): output onnx path
        export_padding (bool): whether to export padding logic to ONNX so that input text can be any length

    Raises:
        RuntimeError: This tool requires transformers 4.0.0 or later.
        RuntimeError: LongformerSelfAttention.forward arguments are different.
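
    Example (illustrative; assumes PRETRAINED_LONGFORMER_MODELS contains this model name):
        >>> model = LongformerModel.from_pretrained(PRETRAINED_LONGFORMER_MODELS["longformer-base-4096"])
        >>> export_longformer(model, "longformer-base-4096.onnx", export_padding=False)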
""" input_ids, attention_mask, global_attention_mask = get_dummy_inputs( model.config, export_padding, device=torch.device("cpu") ) _ = model( input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask, ) if version.parse(transformers.__version__) < version.parse("4.0.0"): raise RuntimeError("This tool requires transformers 4.0.0 or later.") # Here we replace LongformerSelfAttention.forward using our implementation for exporting ONNX model key = " ".join(inspect.getfullargspec(LongformerSelfAttention.forward).args) args_to_func = { "self hidden_states attention_mask layer_head_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3_2, "self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3, "self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn": my_longformer_self_attention_forward_4, } if key not in args_to_func: print( "Current arguments", inspect.getfullargspec(LongformerSelfAttention.forward).args, ) raise RuntimeError( "LongformerSelfAttention.forward arguments are different. Please install supported version (like transformers 4.3.0)." ) # Store for restoring later original_forward = LongformerSelfAttention.forward LongformerSelfAttention.forward = args_to_func[key] example_inputs = (input_ids, attention_mask, global_attention_mask) Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) torch_onnx_export( model, example_inputs, onnx_model_path, opset_version=12, input_names=["input_ids", "attention_mask", "global_attention_mask"], output_names=["last_state", "pooler"], dynamic_axes={ "input_ids": {0: "batch_size", 1: "sequence_length"}, "attention_mask": {0: "batch_size", 1: "sequence_length"}, "global_attention_mask": {0: "batch_size", 1: "sequence_length"}, "last_state": {0: "batch_size", 1: "sequence_length"}, "pooler": {0: "batch_size", 1: "sequence_length"}, }, custom_opsets={"com.microsoft": 1}, ) print(f"ONNX model exported to {onnx_model_path}") # Restore original implementation: LongformerSelfAttention.forward = original_forward def optimize_longformer(onnx_model_path: str, fp32_model_path: str, fp16_model_path=None): """Optimize longformer onnx model Args: onnx_model_path (str): path of original ONNX model. fp32_model_path (str): path of optimized fp32 model. fp16_model_path (str, optional): path of optimized fp16 model. Defaults to None. 
""" model = load_model(onnx_model_path, format=None, load_external_data=True) optimizer = BertOnnxModel(model) optimizer.optimize() use_external_data_format = False if fp32_model_path: optimizer.save_model_to_file(fp32_model_path, use_external_data_format) print(f"optimized fp32 model saved to {fp32_model_path}") if fp16_model_path: optimizer.convert_float_to_float16(keep_io_types=True) optimizer.save_model_to_file(fp16_model_path, use_external_data_format) print(f"optimized fp16 model saved to {fp16_model_path}") def main(args): model_name = args.model onnx_model_path = model_name + ".onnx" global weight_bias_format # noqa: PLW0603 weight_bias_format = 0 if args.no_merge_qkv else 1 model = LongformerModel.from_pretrained(PRETRAINED_LONGFORMER_MODELS[model_name]) export_longformer(model, onnx_model_path, args.export_padding) if args.optimize_onnx or args.precision != "fp32": fp32_model_path = model_name + f"_f{weight_bias_format}" + "_fp32.onnx" fp16_model_path = model_name + f"_f{weight_bias_format}" + "_fp16.onnx" if args.precision == "fp16" else None optimize_longformer(onnx_model_path, fp32_model_path, fp16_model_path) if __name__ == "__main__": args = parse_arguments() main(args)