Unable to export GLM models to ONNX #35021

Closed

xenova opened this issue Nov 29, 2024 · 1 comment

xenova commented Nov 29, 2024

System Info

  • transformers version: 4.47.0.dev0
  • Platform: Linux-6.5.0-1025-azure-x86_64-with-glibc2.31
  • Python version: 3.12.1
  • Huggingface_hub version: 0.26.1
  • Safetensors version: 0.4.5
  • Accelerate version: 1.0.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.0+cu124 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:

Who can help?

@ArthurZucker @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run the following script:

import os
import torch
from transformers import (
    AutoProcessor,
    GlmForCausalLM,
    DynamicCache,
)

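# torch.onnx.export traces a flat tuple of positional tensor arguments, so this
# wrapper accepts positional inputs, rebuilds a DynamicCache from the flattened
# past key/values, and flattens the output cache back into individual
# "present.{i}.key" / "present.{i}.value" tensors that can be named in the graph.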
class PatchedGlmForCausalLM(GlmForCausalLM):
    def forward(self, *args):
        input_ids, attention_mask, position_ids, *past_key_values_args = args

        # Convert past_key_values list to DynamicCache
        if len(past_key_values_args) == 0:
            past_key_values = None
        else:
            past_key_values = DynamicCache(self.config.num_hidden_layers)
            for i in range(self.config.num_hidden_layers):
                key = past_key_values_args.pop(0)
                value = past_key_values_args.pop(0)
                past_key_values.update(key_states=key, value_states=value, layer_idx=i)

        o = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )

        flattened_past_key_values_outputs = {
            "logits": o.logits,
        }
        output_past_key_values: DynamicCache = o.past_key_values
        for i, (key, value) in enumerate(
            zip(output_past_key_values.key_cache, output_past_key_values.value_cache)
        ):
            flattened_past_key_values_outputs[f"present.{i}.key"] = key
            flattened_past_key_values_outputs[f"present.{i}.value"] = value

        return flattened_past_key_values_outputs


# Constants
OUTPUT_FOLDER = "output"
TEXT_MODEL_NAME = "model.onnx"
TEMP_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "temp")
FINAL_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "onnx")


# Load model and processor
model_id = "hf-internal-testing/tiny-random-GlmForCausalLM"
model = PatchedGlmForCausalLM.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)


# Save model configs and processor
model.config.save_pretrained(OUTPUT_FOLDER)
model.generation_config.save_pretrained(OUTPUT_FOLDER)
processor.save_pretrained(OUTPUT_FOLDER)
os.makedirs(TEMP_MODEL_OUTPUT_FOLDER, exist_ok=True)


# Configuration values
## Text model
text_config = model.config
num_heads = text_config.num_attention_heads
num_key_value_heads = text_config.num_key_value_heads
head_dim = text_config.head_dim
num_layers = text_config.num_hidden_layers
hidden_size = text_config.hidden_size


# Dummy input sizes
batch_size = 2
sequence_length = 16
past_sequence_length = 0

## Text inputs
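# Zero-length past key/values (past_sequence_length == 0), laid out as
# (batch_size, num_key_value_heads, past_sequence_length, head_dim).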
dummy_past_key_values_kwargs = {
    f"past_key_values.{i}.{key}": torch.zeros(
        batch_size,
        num_key_value_heads,
        past_sequence_length,
        head_dim,
        dtype=torch.float32,
    )
    for i in range(num_layers)
    for key in ["key", "value"]
}
input_ids = torch.randint(
    0, text_config.vocab_size,
    (batch_size, sequence_length),
)
attention_mask = torch.ones(batch_size, sequence_length + past_sequence_length, dtype=torch.int64)
position_ids = torch.ones(batch_size, sequence_length, dtype=torch.int64)

text_inputs = dict(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    **dummy_past_key_values_kwargs,
)
text_inputs_positional = tuple(text_inputs.values())
text_outputs = model.forward(*text_inputs_positional)  # Test forward pass


# ONNX Exports
## Text model
TEXT_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, TEXT_MODEL_NAME)
torch.onnx.export(
    model,
    args=text_inputs_positional,
    f=TEXT_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=list(text_inputs.keys()),
    output_names=["logits"]
    + [f"present.{i}.{key}" for i in range(num_layers) for key in ["key", "value"]],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        **{
            f"past_key_values.{i}.{key}": {0: "batch_size", 2: "past_sequence_length"}
            for i in range(num_layers)
            for key in ["key", "value"]
        },
        "logits": {0: "batch_size", 1: "sequence_length"},
        **{
            f"present.{i}.{key}": {0: "batch_size", 2: "total_sequence_length"}
            for i in range(num_layers)
            for key in ["key", "value"]
        },
    },
)
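
For reference, if the export were to succeed, the resulting graph could be sanity-checked against the eager outputs roughly as follows (a sketch, assuming onnxruntime is installed; it reuses text_inputs and text_outputs from the script above):

import numpy as np
import onnxruntime as ort

# Run the exported graph on the same dummy inputs used for tracing.
session = ort.InferenceSession(TEXT_MODEL_OUTPUT_PATH)
onnx_logits = session.run(
    ["logits"],
    {name: tensor.numpy() for name, tensor in text_inputs.items()},
)[0]

# Compare against the eager forward pass computed earlier.
np.testing.assert_allclose(
    onnx_logits,
    text_outputs["logits"].detach().numpy(),
    atol=1e-4,  # tolerance is a guess, not a measured bound
)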

Instead, running the script produces this error:

Traceback (most recent call last):
  File "/workspaces/glm.py", line 110, in <module>
    torch.onnx.export(
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/__init__.py", line 375, in export
    export(
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
    graph = _optimize_graph(
            ^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 639, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 1836, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_helper.py", line 369, in wrapper
    return fn(g, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_opset11.py", line 519, in cat
    return opset9.cat(g, tensor_list, dim)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_helper.py", line 281, in wrapper
    return fn(g, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_opset9.py", line 575, in cat
    assert all(a)
AssertionError

Expected behavior

The model should export correctly. This may in fact be an ONNX bug, but I'm not 100% sure; models like Gemma export correctly with the same script, so the failure appears to be GLM-specific.
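
For comparison, here is a minimal control export for a tiny Gemma checkpoint through the same positional-argument path (a sketch; the checkpoint id and the cache-free, tuple-output setup are assumptions made for brevity, not part of the original report):

import torch
from transformers import GemmaForCausalLM

model_id = "hf-internal-testing/tiny-random-GemmaForCausalLM"  # assumed tiny test checkpoint
model = GemmaForCausalLM.from_pretrained(model_id).eval()
model.config.use_cache = False    # skip KV-cache outputs to keep the control minimal
model.config.return_dict = False  # return a plain tuple so the tracer sees only tensors

batch_size, sequence_length = 2, 16
input_ids = torch.randint(0, model.config.vocab_size, (batch_size, sequence_length))
attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.int64)
position_ids = torch.arange(sequence_length, dtype=torch.int64).expand(batch_size, -1)

torch.onnx.export(
    model,
    args=(input_ids, attention_mask, position_ids),
    f="gemma-control.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["input_ids", "attention_mask", "position_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
    },
)

If this control exports cleanly while the GLM script above fails with the same settings, that points at the GLM modeling code rather than the export harness.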


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Jan 7, 2025