Enhancing GPTQv2 Format Support in vLLM: Analysis and Implementation
Deep technical analysis of GPTQv2 format limitations in vLLM, and implementation of CUDA kernel adaptations to enable efficient low-bit/asymmetric quantization inference.

Issue (closed): #26343
Pull request (merged): #26092
Commit: 5cc6bdd

Introduction

vLLM, one of the leading LLM inference frameworks, currently lacks robust support for GPTQv2 format (an upgraded version of the GPTQv1 format) models, particularly those using low-bit (2/3-bit) or asymmetric quantization. While vLLM doesn’t raise explicit errors when loading such models, it incorrectly treats them as GPTQv1 format, resulting in degraded inference quality and characteristic gibberish outputs (consisting of repeated !!!, details in this issue). This limitation stems from differences in zero point handling between the GPTQv1 and GPTQv2 checkpoint formats, which vLLM’s existing GPTQ GeMM kernels don’t account for.

This post presents a comprehensive analysis of this limitation and documents the implementation of kernel adaptations (i.e., in this PR) that enable proper GPTQv2 support while maintaining backward compatibility. Through careful investigation of vLLM’s quantization support and targeted CUDA kernel modifications, I enable robust inference for GPTQv2 format models, especially low-bit or asymmetric ones, with vLLM — contributing a step forward towards efficient LLM deployment.

In my case, I use vLLM to serve some low-bit (e.g., 2-bit), asymmetrically quantized models stored in GPTQv2 format, and encountered this issue.

Background and Preliminaries

Before diving into the technical details, I’ll briefly introduce the background (e.g., why GPTQv2) and some preliminaries (e.g., the checkpoint format) to help follow the technical parts.

Weight Quantization of LLMs

Weight quantization, which quantizes high-precision model weights (e.g., 16/32-bit) into fewer bits (e.g., 2/3/4/8-bit), has become a common practice in LLM deployment, especially in resource-constrained scenarios. Technically, weight quantization maps the large range of high-precision weights (e.g., within $[-65504, 65504]$ for FP16) into a limited range of quantized weights (e.g., $[0, 2^b - 1]$ for a $b$-bit unsigned integer). This mapping typically involves a scaling factor ($scale$) that compresses the range, and a bias ($zero$) that shifts the zero point.

We denote $w_o$ as the original weight within the range $[w_{min}, w_{max}]$ (usually for a group of weights), and $w_q$ as the quantized weight. Then, a simple $b$-bit integer quantization is formulated as:

\[ scale = \frac{w_{max} - w_{min}}{2^b - 1} \]

\[ zero = - \mathrm{round}\left(\frac{w_{min}}{scale}\right) \]

\[ w_q = \mathrm{clamp}\left(\mathrm{round}\left(\frac{w_o}{scale}\right) + zero,\ 0,\ 2^b - 1\right) \]

To recover the original weight $\hat{w}_o$ during dequantization:

\[ \hat{w}_o = (w_q - zero) \cdot scale \]

Based on this formulation, quantization methods are categorized by whether the zero point is required:

- Symmetric quantization assumes $w_{max} = - w_{min}$, so $zero = \mathrm{round}(\frac{2^b - 1}{2})$ is constant. In this case, $zero$ provides no additional information given the quantization bits.
- Asymmetric quantization makes no such assumption. $zero$ therefore varies across groups of weights and is necessary for accurately recovering the original weights.

Note that most GPTQ implementations target 4-bit symmetric quantization. However, to reduce the quantization error at lower bit-widths, asymmetric quantization is necessary.
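To make the formulas concrete, here is a minimal PyTorch sketch of this quantize/dequantize round trip (illustrative only; shown per-tensor rather than per-group for brevity):

```python
import torch

def quantize(w: torch.Tensor, bits: int):
    # scale compresses the range; zero shifts the zero point (asymmetric).
    qmax = 2 ** bits - 1
    scale = (w.max() - w.min()) / qmax
    zero = -torch.round(w.min() / scale)
    w_q = torch.clamp(torch.round(w / scale) + zero, 0, qmax)
    return w_q, scale, zero

def dequantize(w_q: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor):
    return (w_q - zero) * scale

w = torch.randn(128)                      # one "group" of weights
w_q, scale, zero = quantize(w, bits=2)
w_hat = dequantize(w_q, scale, zero)
print((w - w_hat).abs().max())            # quantization error
```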
GPTQ: Quantization and Checkpoint Format

GPTQ[1] is one of the most popular post-training quantization methods for generative transformers (mainly LLMs and VLMs). It utilizes approximate second-order information (the inverse layer Hessian) to reduce quantization errors.

Besides the method, GPTQ can also refer to the specific checkpoint format adopted by GPTQ-quantized models (e.g., by AutoGPTQ), with details explained in this blog[3]. GPTQ is widely supported by the community, including 1) quantization libraries that implement the GPTQ quantization method or support exporting to the GPTQ format (without implementing the quantization method themselves), and 2) kernel libraries and inference frameworks that support inference with models in the GPTQ checkpoint format, as listed below:

- Quantization libraries: e.g., AutoGPTQ and GPTQModel.
- Computing (CUDA) kernels: e.g., Marlin, ExLlamaV2, and BitBLAS.

SOTA LLM inference frameworks, including vLLM, integrate the above quantization and kernel libraries to support efficient LLM deployment with GPTQ quantization.
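For orientation, the per-layer tensors in a GPTQ-format checkpoint look roughly as follows (shapes based on the description in [3]; exact packing conventions vary slightly between exporters, and this helper is purely illustrative):

```python
def gptq_tensor_shapes(in_features: int, out_features: int,
                       bits: int, group_size: int) -> dict:
    """Approximate shapes of the per-layer tensors in a GPTQ checkpoint."""
    pack = 32 // bits                  # b-bit values packed into each int32
    n_groups = in_features // group_size
    return {
        "qweight": (in_features // pack, out_features),  # packed weights (int32)
        "qzeros": (n_groups, out_features // pack),      # packed zero points (int32)
        "scales": (n_groups, out_features),              # per-group scales (fp16)
        "g_idx": (in_features,),                         # group index per input channel (int32)
    }

print(gptq_tensor_shapes(4096, 4096, bits=2, group_size=128))
```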
From GPTQv1 to GPTQv2

GPTQv2 is an upgraded version of GPTQ (by a different team, though), in both the quantization method and the checkpoint format. Specifically:

- Quantization method: GPTQv2 (published as GPTAQ[2]) extends GPTQ with asymmetric calibration to further reduce quantization error.
- Checkpoint format: the GPTQv1 format stores zero points with a “-1” offset (the stored value is $zero - 1$, and kernels add the 1 back during dequantization), while the GPTQv2 format stores the actual zero points directly.

Just like GPTQv1, the quantization method and the checkpoint format are not coupled. So you can, for example:

- quantize with the GPTQv2 method but keep the GPTQv1 checkpoint format (v2=True only, in GPTQModel), or
- quantize with the original GPTQ method but export the GPTQv2 checkpoint format (format="gptq_v2" only, in GPTQModel).

Note that the conversion between the GPTQv2 and GPTQv1 formats is not fully reversible — you can convert GPTQv1 to GPTQv2 losslessly, but not GPTQv2 to GPTQv1. This is due to the “-1” offset of GPTQv1 mentioned above: the actual zero point range is suppressed because $0 - 1$ is clamped back to $0$. For example, in INT2 quantization, the effective range shrinks from $[0, 3]$ to $[1, 3]$. Therefore, the GPTQv2 format is the preferable choice for asymmetric quantization, especially at low bit-widths.
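A minimal sketch of this asymmetry, assuming unpacked integer zero points (real checkpoints store them bit-packed into qzeros; packing is omitted here for clarity):

```python
import torch

def zeros_v1_to_v2(zeros_v1: torch.Tensor) -> torch.Tensor:
    # GPTQv1 stores zero - 1; adding 1 back recovers the actual zero point.
    return zeros_v1 + 1

def zeros_v2_to_v1(zeros_v2: torch.Tensor, bits: int) -> torch.Tensor:
    # The reverse direction is lossy: a true zero point of 0 would have to be
    # stored as -1, which gets clamped to 0 (the "-1" issue above).
    return torch.clamp(zeros_v2 - 1, 0, 2 ** bits - 1)

z_v2 = torch.tensor([0, 1, 2, 3])        # actual 2-bit zero points (GPTQv2)
z_v1 = zeros_v2_to_v1(z_v2, bits=2)      # tensor([0, 0, 1, 2]): 0 and 1 collide
print(zeros_v1_to_v2(z_v1))              # tensor([1, 1, 2, 3]) != z_v2
```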
Cause Analysis

With the preliminaries covered, we can finally dive into the technical details of why and how vLLM fails for GPTQv2 in my case — even though it does have partial support, as I found after careful investigation.

vLLM’s Kernel Routing Hierarchy

The first step is to understand how vLLM routes computing kernels for different quantizations, like GPTQ. This is implemented in the model execution part of the LLMEngine. For simplicity, we only consider dense models (no MoE). The calling hierarchy is as follows:

1. Model-level quantization configuration:
   - In the model implementations (vllm/model_executor/models), vllm_config: VllmConfig is passed to the model at initialization, which contains quant_config: QuantizationConfig.
   - Each quantization extends QuantizationConfig with quantization-specific overrides (e.g., GPTQConfig).
   - Given the quant_config, ModelConfig._verify_quantization will select one quantization method from a priority list (see below).
2. Layer-level quantization configuration (linear methods):
   - Each quantized linear layer obtains its quant_method from quant_config.get_quant_method.
   - get_quant_method returns a specific linear method class (inherited from LinearMethodBase, e.g., GPTQLinearMethod) depending on the quantization configuration.
   - The layer’s forward pass then calls LinearMethodBase.apply (e.g., GPTQLinearMethod.apply calls gptq_gemm).
3. (Optional) Kernel selection (linear kernels). vLLM supports several ways of routing to low-level CUDA kernels from a linear method class:
   - GPTQLinearMethod directly routes to gptq_gemm, a registered custom op of vLLM (implemented in csrc/quantization/gptq).
   - Other methods use an MPLinearKernel class as an interface for routing to the kernel (in vllm/model_executor/layers/quantization/kernels/mixed_precision). For example, GPTQBitBLASLinearMethod routes to BitBLASLinearKernel, and GPTQMarlinLinearMethod calls choose_mp_linear_kernel for flexible routing.
   - BitBLASLinearKernel in turn calls Matmul from the bitblas Python library.

The diagram below summarizes this routing for GPTQ:
    graph TD
    A[VllmConfig] --> B[ModelConfig._verify_quantization]
    B --> |"Priority: gptq_marlin > gptq_bitblas > gptq"| E[QuantizationConfig.get_quant_method]
    
    E -->|4/8-bit + sym| J[GPTQMarlinLinearMethod]
    E -->|4/8-bit + sym| K[GPTQBitBLASLinearMethod]
    E -->|2/3/4/8-bit + sym/asym| L[GPTQLinearMethod]
    
    J --> JJ[MarlinLinearKernel]
    K --> KK[BitBLASLinearKernel]
    L --> LL[Direct Kernel Call]
    
    JJ --> M["gptq_marlin_gemm<br/>CUDA kernel"]
    KK --> N["bitblas.Matmul<br/>External library"]
    LL --> O["gptq_gemm<br/>CUDA kernel"]
    
    M --> Q["✅ Marlin: gptq/gptq_v2"]
    N --> R["✅ BitBLAS: gptq/gptq_v2"] 
    O --> S["❌ GPTQ: gptq only"]
    
    style J fill:#90EE90
    style K fill:#90EE90
    style L fill:#FFB6C1
    style Q fill:#90EE90
    style R fill:#90EE90
    style S fill:#FFB6C1
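To make the hierarchy concrete, here is a schematic sketch of the method selection (illustrative only; the names mirror vLLM’s, but the real checks in ModelConfig._verify_quantization and the per-method compatibility tests are more involved):

```python
# Schematic GPTQ method selection by priority (not vLLM's actual code).
PRIORITY = ["gptq_marlin", "gptq_bitblas", "gptq"]  # high -> low

SUPPORT = {  # mirrors the support matrix in the next section
    "gptq_marlin":  lambda bits, sym: bits in (4, 8) and sym,
    "gptq_bitblas": lambda bits, sym: bits in (4, 8) and sym,
    "gptq":         lambda bits, sym: bits in (2, 3, 4, 8),
}

def select_quant_method(bits: int, sym: bool) -> str:
    for name in PRIORITY:
        if SUPPORT[name](bits, sym):
            return name
    raise ValueError("unsupported GPTQ configuration")

print(select_quant_method(bits=4, sym=True))   # gptq_marlin -> MarlinLinearKernel
print(select_quant_method(bits=2, sym=False))  # gptq        -> gptq_gemm fallback
```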
vLLM’s Support for GPTQ(v2) Format

vLLM integrates several optimized kernels for GPTQ format models, as listed in GPTQ: Quantization and Checkpoint Format, including Marlin, ExLlamaV2, BitBLAS, etc. vLLM also has fallback kernels for quantization configurations that these optimized kernels don’t support. Following the analysis in vLLM’s Kernel Routing Hierarchy, I summarize the support matrix below (by linear methods):
| Method | Bits | Sym | GPTQ Format |
| --- | --- | --- | --- |
| GPTQMarlin | 4, 8 | True | gptq, gptq_v2 |
| GPTQBitBLAS | 4, 8 | True | gptq, gptq_v2 |
| GPTQ | 2, 3, 4, 8 | Any | gptq |

Notes:

- The priority of methods (checked in ModelConfig._verify_quantization) is gptq_marlin > gptq_bitblas > gptq.
- GPTQLinearMethod supports 2/3-bit quantization and asymmetric quantization. However, it lacks GPTQv2 format support (which both of the other two methods have).

As a result, neither 2/3-bit nor asymmetric quantization in GPTQv2 format is supported by vLLM, which motivates this PR.
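To make the gap concrete, the snippet below serves a 2-bit, asymmetric, GPTQv2-format checkpoint (the model name is a placeholder); per the matrix above it can only be routed to GPTQLinearMethod, and before this fix the generations degrade into the gibberish described in the introduction:

```python
from vllm import LLM, SamplingParams

# Placeholder: any 2-bit, asymmetric (sym=False), GPTQv2-format checkpoint.
llm = LLM(model="your-org/your-model-2bit-gptq_v2", quantization="gptq")

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
# Without the kernel fix: repeated "!!!"-style gibberish, because the GPTQv2
# zero points are interpreted as GPTQv1 (off by one).
```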
Solution

Based on the above analysis, vLLM’s GPTQ linear methods lack support for 2/3-bit and asymmetric quantization in GPTQv2 format, and require adaptation to robustly support GPTQv2 format models of various configurations. To add such support, at least one linear method has to be added or adapted.
Approach: Adapt GPTQ Linear Method & Kernel
To adapt one of the existing linear methods, the candidates are:

- GPTQMarlinLinearMethod (lacking 2/3-bit and asymmetric support): this also requires modifying vLLM’s Marlin CUDA kernel, which is dedicated to 4/8-bit symmetric quantization — not a good choice.
- GPTQBitBLASLinearMethod (lacking 2/3-bit and asymmetric support): this only requires modifying the linear method/kernel (Python code), since the bitblas library itself supports the bits and sym we want — a reasonable choice, but it requires the optional bitblas package to be installed.
- GPTQLinearMethod (lacking GPTQv2 format support): this also requires modifying vLLM’s gptq_gemm CUDA kernel, but only to adapt the zero point handling logic — preferred.

So, the plan is to adapt GPTQLinearMethod (with GPTQConfig) and gptq_gemm to add proper GPTQv2 format support.

Details of Adaptation: Conditioned on Format

During this linear method & kernel adaptation, there are three points to keep in mind:

1. The linear method must know whether the loaded checkpoint uses the GPTQv1 or GPTQv2 format.
2. The kernel must adjust its zero point handling according to that format.
3. Existing behavior must be preserved: GPTQv1 models and the other GPTQ linear methods must not be affected.

In response to Pt. 1 and 2:
- Add a use_v2_format: bool attribute to GPTQLinearMethod that indicates whether checkpoint_format == "gptq_v2".
- Add a bool use_v2_format argument to gptq_gemm, which accepts GPTQLinearMethod.use_v2_format as input.
- In gptq_gemm, update the zero point handling logic to be conditioned on use_v2_format. For example:

```cpp
// In `reconstruct_exllama_2bit_kernel`:
// Previous: zeros[i] + 1 (hardcoded for the GPTQv1 format)
// Now: zeros[i] + zero_offset (conditioned on `use_v2_format`)
int zero_offset = use_v2_format ? 0 : 1;
// ...
// Every dequantization step that previously used `zeros[i] + 1`
// now uses `zeros[i] + zero_offset` instead.
```
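On the Python side, a minimal sketch of the corresponding change, assuming the attribute and argument names from the bullets above (schematic, not the exact PR diff; the real gptq_gemm op takes additional arguments such as g_idx and the exllama flag):

```python
def gptq_gemm(x, qweight, qzeros, scales, use_v2_format):
    """Placeholder for vLLM's registered custom op (csrc/quantization/gptq),
    extended with the use_v2_format flag; the real op takes more arguments."""
    raise NotImplementedError

class GPTQLinearMethod:
    def __init__(self, quant_config):
        self.quant_config = quant_config
        # Pt. 1: remember whether the checkpoint declares the GPTQv2 format.
        self.use_v2_format = quant_config.checkpoint_format == "gptq_v2"

    def apply(self, layer, x, bias=None):
        # Pt. 2: thread the flag down to the CUDA kernel, which then picks
        # zero_offset = 0 (v2) or 1 (v1) as shown above.
        out = gptq_gemm(x, layer.qweight, layer.qzeros, layer.scales,
                        use_v2_format=self.use_v2_format)
        if bias is not None:
            out = out + bias
        return out
```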
When testing, I found that the original gptq_gemm is buggy at 4-bit even with a symmetrically quantized model in GPTQv1 format — out of scope for this PR. To ensure Pt. 3, review vLLM’s Support for GPTQ(v2) Format: GPTQMarlinLinearMethod and GPTQBitBLASLinearMethod are not affected, as they already support the GPTQv2 format (though limited to 4/8-bit symmetric quantization).

TODO: Add some performance benchmarks. Currently I’ve found that 2-bit gptq_gemm is slower during decoding (GeMV) than prefilling (GeMM).

Conclusion
This post details the development of GPTQv2 format support in vLLM, which addresses a significant gap in low-bit asymmetric quantization inference with SOTA LLM inference frameworks. Questions and discussions are welcome.

Possible future works:

- Fix the 4-bit bug of gptq_gemm.
- Improve the decoding (GeMV) performance of 2-bit gptq_gemm.

References

[1] Frantar, Elias, et al. “GPTQ: Accurate Post-Training Quantization for Generative Pre-Trained Transformers.” arXiv preprint arXiv:2210.17323 (2022).
[2] Li, Yuhang, et al. “GPTAQ: Efficient Finetuning-Free Quantization for Asymmetric Calibration.” arXiv preprint arXiv:2504.02692 (2025).
[3] de Kok, Daniël. “GPTQ Checkpoint Format.” Daniël’s Website, 7 Aug. 2024, danieldk.eu/GPTQ-Checkpoint-Format.