Skip to content

[AIRADSW-567] Fix int8 models qlinearconv#4969

Merged
causten merged 9 commits into
developfrom
fix_int8_models
Jun 18, 2026
Merged

[AIRADSW-567] Fix int8 models qlinearconv#4969
causten merged 9 commits into
developfrom
fix_int8_models

Conversation

@urpetkov-amd

Copy link
Copy Markdown
Collaborator

Motivation

QLinearConv models that have a bias and use per-tensor (scalar) weight quantization fail to parse. The parser throws while building the bias dequantizelinear:

.../migraphx/check_shapes.hpp:220: same_dims: dequantizelinear: Dimensions do not match

This blocks loading common quantized models such as resnet50_int8. The goal of this PR is to make QLinearConv with bias parse correctly for both per-tensor and per-axis (per-channel) weight scales, and to add a regression test for the previously-uncovered per-tensor case.

This is a regression introduced by #4521. Before that change, the int32 bias was added to the convolution result without being dequantized (numerically incorrect, but it parsed). #4521 added the correct bias dequantization (bias_scale = x_scale * w_scale), but only handled the per-axis weight-scale shape, so per-tensor models began throwing at parse time.

Technical Details

When QLinearConv has a bias, the parser dequantizes the int32 bias using a scale computed as x_scale * w_scale:

auto bcast_scale_x = info.add_instruction(
    migraphx::make_op("multibroadcast", {{"out_lens", in_scale_w->get_shape().lens()}}),
    in_scale_x);

auto bias_scale =
    info.add_instruction(migraphx::make_op("mul"), bcast_scale_x, in_scale_w);

auto dquant_bias =
    info.add_instruction(migraphx::make_op("dequantizelinear"), in_b, bias_scale);

bias_scale inherits the shape of the weight scale (in_scale_w):

  • Per-axis weight scalein_scale_w is {out_channels}, so bias_scale is {out_channels} and matches the bias {out_channels}. Works.
  • Per-tensor weight scalein_scale_w is a scalar {1}, so bias_scale is {1} while the bias is {out_channels}.

Fix

Broadcast bias_scale to the bias shape before the bias dequantizelinear:

auto bias_scale =
    info.add_instruction(migraphx::make_op("mul"), bcast_scale_x, in_scale_w);

// dequantizelinear requires the scale to match the bias dims. Broadcast handles
// per-tensor weight scale (scalar); it is a no-op for per-axis weight scale.
auto bcast_bias_scale = info.add_instruction(
    migraphx::make_op("multibroadcast", {{"out_lens", in_b->get_shape().lens()}}),
    bias_scale);

auto dquant_bias = info.add_instruction(
    migraphx::make_op("dequantizelinear"), in_b, bcast_bias_scale);

This is a minimal, single-instruction change confined to the if(args.size() > 8) (bias-present) branch of parse_qlinearconv:

  • Per-tensor: {1}{out_channels} — fixes the crash.
  • Per-axis: {out_channels}{out_channels} — an identity broadcast, so numerics are unchanged and the existing qlinearconv_perchannel_weightbias_test still passes.

Related

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

Comment thread src/onnx/parse_qlinearconv.cpp Outdated
@codecov

codecov Bot commented Jun 16, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #4969   +/-   ##
========================================
  Coverage    92.73%   92.73%           
========================================
  Files          592      592           
  Lines        31289    31289           
========================================
  Hits         29015    29015           
  Misses        2274     2274           
Files with missing lines Coverage Δ
src/onnx/parse_qlinearconv.cpp 74.39% <100.00%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@urpetkov-amd urpetkov-amd changed the title Fix int8 models qlinearconv [AIRADSW-567] Fix int8 models qlinearconv Jun 16, 2026
@TedThemistokleous TedThemistokleous added bugfix Fixes a bug found in the code. high priority A PR with high priority for review and merging. simple small or simple changes labels Jun 16, 2026

@CharlieL7 CharlieL7 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fix looks correct. Please add an ONNX parse test for the per tensor parsing case in test/onnx/parse/qlinearconv_test.cpp

Comment thread src/onnx/parse_qlinearconv.cpp Outdated
@gh-app-migraphx-bot-pr-write

gh-app-migraphx-bot-pr-write Bot commented Jun 16, 2026

Copy link
Copy Markdown
Test Batch New Rate (91abf8) Old Rate (241f7a) Diff Status
torchvision-resnet50 64 3,165.98 3,158.05 0.25%
torchvision-resnet50_fp16 64 5,991.67 6,678.10 -10.28% 🔴
torchvision-densenet121 32 2,670.25 2,705.33 -1.30%
torchvision-densenet121_fp16 32 4,555.74 4,480.31 1.68%
torchvision-inceptionv3 32 680.96 1,772.37 -61.58% 🔴
torchvision-inceptionv3_fp16 32 2,731.77 2,613.31 4.53%
cadene-inceptionv4 16 823.83 444.89 85.17% 🔆
cadene-resnext64x4 16 785.04 412.50 90.32% 🔆
slim-mobilenet 64 6,555.02 8,419.17 -22.14% 🔴
slim-nasnetalarge 64 228.89 nan nan
slim-resnet50v2 64 3,330.18 1,641.82 102.83% 🔆
bert-mrpc-onnx 8 1,171.47 1,165.20 0.54%
bert-mrpc-tf 1 496.54 492.10 0.90%
pytorch-examples-wlang-gru 1 329.56 334.96 -1.61%
pytorch-examples-wlang-lstm 1 451.34 469.54 -3.88%
torchvision-resnet50_1 1 765.74 760.64 0.67%
cadene-dpn92_1 1 452.94 457.65 -1.03%
cadene-resnext101_1 1 363.74 363.68 0.01%
onnx-taau-downsample 1 401.40 401.93 -0.13%
dlrm-criteoterabyte 1 32.70 12.36 164.54% 🔆
dlrm-criteoterabyte_fp16 1 52.61 29.66 77.39% 🔆
agentmodel 1 9,751.34 10,523.13 -7.33% 🔴
unet_fp16 2 57.26 57.23 0.05%
resnet50v1_fp16 1 975.37 959.87 1.62%
resnet50v1_int8 1 935.31 942.92 -0.81%
bert_base_cased_fp16 64 1,102.62 1,092.47 0.93%
bert_large_uncased_fp16 32 347.60 345.19 0.70%
bert_large_fp16 1 204.74 203.23 0.74%
distilgpt2_fp16 16 2,096.30 2,096.77 -0.02%
yolov5s 1 571.79 564.74 1.25%
tinyllama 1 46.01 46.00 0.02%
vicuna-fastchat 1 43.99 44.04 -0.11%
whisper-tiny-encoder 1 419.66 419.55 0.03%
whisper-tiny-decoder 1 414.20 414.85 -0.16%
llama2_7b 1 20.47 20.42 0.23%
qwen1.5-7b 1 23.67 21.75 8.82% 🔆
phi3-3.8b 1 26.84 26.81 0.12%
llama3-8b 1 21.83 21.83 0.00%
whisper-large-encoder 1 10.32 6.50 58.78% 🔆
whisper-large-decoder 1 106.71 106.80 -0.08%
mistral-7b 1 23.86 23.85 0.05%
FLUX.1-schnell 1 754.56 763.88 -1.22%

Regressions detected 🔴

@gh-app-migraphx-bot-pr-write

gh-app-migraphx-bot-pr-write Bot commented Jun 16, 2026

Copy link
Copy Markdown
Test Status Result
bert-mrpc-onnx PASSED: MIGraphX meets tolerance
bert-mrpc-tf ERROR - check error output
traceback
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 377, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 313, in main
import tensorflow as tf
File "/usr/local/lib/python3.10/dist-packages/tensorflow/init.py", line 38, in
from tensorflow.python.tools import module_util as _module_util
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/init.py", line 36, in
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/pywrap_tensorflow.py", line 26, in
self_check.preload_check()
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/platform/self_check.py", line 63, in preload_check
from tensorflow.python.platform import _pywrap_cpu_feature_guard
ImportError: libamdhip64.so.6: cannot open shared object file: No such file or directory
pytorch-examples-wlang-gru PASSED: MIGraphX meets tolerance
pytorch-examples-wlang-lstm PASSED: MIGraphX meets tolerance
dlrm-criteoterabyte PASSED: MIGraphX meets tolerance
agentmodel PASSED: MIGraphX meets tolerance
unet PASSED: MIGraphX meets tolerance
resnet50v1 PASSED: MIGraphX meets tolerance
bert_base_cased_fp16 PASSED: MIGraphX meets tolerance
bert_large_uncased_fp16 🔴 FAILED: MIGraphX is not within tolerance - check verbose output
bert_large PASSED: MIGraphX meets tolerance
yolov5s PASSED: MIGraphX meets tolerance
tinyllama PASSED: MIGraphX meets tolerance
vicuna-fastchat PASSED: MIGraphX meets tolerance
whisper-tiny-encoder PASSED: MIGraphX meets tolerance
whisper-tiny-decoder PASSED: MIGraphX meets tolerance
distilgpt2_fp16 PASSED: MIGraphX meets tolerance
llama2_7b PASSED: MIGraphX meets tolerance
qwen1.5-7b PASSED: MIGraphX meets tolerance
phi3-3.8b PASSED: MIGraphX meets tolerance
llama3-8b PASSED: MIGraphX meets tolerance
whisper-large-encoder ERROR - check error output
traceback
2026-06-17 09:11:18.762790 [WARN] [/data/src/onnx/onnx_parser.cpp:282] Model has unbound symbolic dimension(s): batch_size, encoder_sequence_length, feature_size. These default to 1 and may cause unexpected behavior. Try setting --dim-param @<name> <value> or --input-dim @<input> <dims> if program compilation fails.
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 377, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 224, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /data/src/include/migraphx/op/convolution.hpp:113: normalize_compute_shape: CONVOLUTION: mismatched channel numbers: input channels (1) != weights channels (80) * group (1)
whisper-large-decoder PASSED: MIGraphX meets tolerance
mistral-7b PASSED: MIGraphX meets tolerance
FLUX.1-schnell PASSED: MIGraphX meets tolerance

@urpetkov-amd urpetkov-amd requested a review from CharlieL7 June 17, 2026 11:58
@urpetkov-amd

Copy link
Copy Markdown
Collaborator Author

@CharlieL7

I've added it. Thanks.

@causten causten merged commit 08d5af3 into develop Jun 18, 2026
40 checks passed
@causten causten deleted the fix_int8_models branch June 18, 2026 01:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bugfix Fixes a bug found in the code. high priority A PR with high priority for review and merging. simple small or simple changes

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants