| | import click |
| | import os |
| | import onnx |
| | from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector |
| | from sparseml.onnx.utils import ONNXGraph |
| | @click.command() |
| | @click.option('--input-file', help='Path to the input ONNX model file') |
| | @click.option('--output-file', help='Output path for the modified model') |
| | def modify_model(input_file, output_file): |
| | model = onnx.load(input_file, load_external_data=False) |
| | model = KeyValueCacheInjector(model_path=os.path.dirname(input_file)).apply(model) |
| | graph = ONNXGraph(model) |
| | graph.delete_orphaned_node_branches() |
| | onnx.save(model, output_file) |
| | print(f"Modified model saved to: {output_file}") |
| | if __name__ == '__main__': |
| | modify_model() |