[ML] bitsandbytes NF4 quantize, dequantize analysis

Youngrok Song
Sep 16, 2023

Analyzing the bitsandbytes quantize & dequantize process with real examples

  • model: polyglot-ko-5.8b (GPTNeoX)

1. bitsandbytes NF4

Information based on the paper:

The NF data type builds on Quantile Quantization

[Quantile Quantization]

  • an information-theoretically optimal data type
  • ensures each quantization bin has an equal number of values assigned (see the small numpy sketch below)
  • estimates the quantiles through the empirical cumulative distribution function
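
To make “equal number of values per bin” concrete, here is a small illustrative numpy sketch (my own toy example, not bitsandbytes code) that builds 16 bins from the empirical quantiles:

import numpy as np

# Toy illustration of quantile quantization: estimate 2^4 = 16 bins from the
# empirical CDF so that each bin receives (roughly) the same number of values.
x = np.random.randn(100_000)

# bin edges at the empirical quantiles 0/16, 1/16, ..., 16/16
edges = np.quantile(x, np.linspace(0, 1, 17))
counts, _ = np.histogram(x, bins=edges)
print(counts)  # ~6250 values per bin -> equal-sized bins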

[Limitations of QQ]

  • fast quantile approximation algorithms, such as SRAM quantiles, are used to estimate them
  • due to their approximate nature → large quantization errors

[Premise of NF4]

  • Expensive quantile estimates can be avoided when input tensors come from a distribution that is fixed up to a quantization constant
  • Pretrained NN weights usually have a zero-centered normal distribution
  • so all weights can be transformed into a single fixed distribution
  • verified with the Shapiro-Wilk test in the paper (about 7.5% of LLaMA weights are non-normally distributed); a sketch of such a check follows below

→ NF4 sets the arbitrary range [-1, 1]
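
As a hedged sketch of the kind of normality check mentioned above (scipy’s Shapiro-Wilk test, not the paper’s exact procedure), one could run:

import numpy as np
from scipy.stats import shapiro

# Stand-in for a real weight tensor; in practice this would be a sampled
# slice of a pretrained layer's weights.
weights = np.random.randn(50_000) * 0.02
sample = np.random.choice(weights, 500, replace=False)  # shapiro prefers small n

stat, p_value = shapiro(sample)
print(f"W={stat:.4f}, p={p_value:.4f}")  # p > 0.05 -> no evidence against normality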

2. Analyzing with GPTNeoX based model

2–1. Loading the model in NF4

Loading the model with

  • NF4 quantization (no double quant)
  • bfloat16 computation

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

gpu_num = 0
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/polyglot-ko-5.8b",
    device_map={"": "cuda:" + str(gpu_num)},
    quantization_config=bnb_config
)

Example of the loaded model weights

print(model.gpt_neox.layers[0].attention.query_key_value.weight.dtype)
# torch.uint8
print(model.gpt_neox.layers[0].attention.query_key_value.weight)
'''
Parameter containing:
Parameter(Params4bit([[103],
                      [184],
                      [110],
                      ...,
                      [  9],
                      [ 58],
                      [125]], device='cuda:0', dtype=torch.uint8))
'''

The model weights are stored in the uint8 data format
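
As a quick sanity check (a sketch; get_memory_footprint is a standard transformers method), the 4-bit model should take roughly half a byte per quantized weight, plus the modules kept in higher precision (embeddings, norms):

# Rough footprint check: ~0.5 bytes per quantized weight plus the
# non-quantized modules.
print(f"{model.get_memory_footprint() / 1024**3:.2f} GB")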

2–2. Analyzing Sample query_key_value weight

I’ll be dequantizing the following “query_key_value” weight

sample_layer_weight = model.gpt_neox.layers[0].attention.query_key_value.weight
print("SAMPLE LAYER", sample_layer_weight.shape, sample_layer_weight.dtype)
# SAMPLE LAYER torch.Size([25165824, 1]) torch.uint8
print(sample_layer_weight)
'''
Parameter containing:
Parameter(Params4bit([[103],
                      [184],
                      [110],
                      ...,
                      [  9],
                      [ 58],
                      [125]], device='cuda:0', dtype=torch.uint8))
'''

The shape of the quantized layer is [25165824, 1]. It represents the following:

  • dtype is torch.uint8 → each uint8 element packs two 4-bit values
  • therefore there are 2 * 25,165,824 → 50,331,648 weight values
  • this can be represented as 3 * 4096 * 4096
  • since GPTNeoX’s query_key_value weight has shape (3*hidden_dim, hidden_dim)
  • this means the weight values are quantized & flattened into a 1-D tensor

Let’s see how the values are packed into an 8-bit uint tensor

print([v[0] for v in sample_layer_weight[:2].cpu().numpy().tolist()])
# [103, 184]
print(format(sample_layer_weight[0].item(), '08b'))
# 01100111
print(format(sample_layer_weight[1].item(), '08b'))
# 10111000

The first uint8 value “01100111” can be split into “0110” (6) and “0111” (7).
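
As a small sketch (my own illustration; the real unpacking happens inside bitsandbytes’ CUDA kernel), the two 4-bit indices can be recovered from each packed uint8 with bit shifts:

# Unpack two 4-bit quantization indices from each packed uint8 value.
# The first weight of the pair sits in the high nibble, the second in the
# low nibble (as seen in the "01100111" example above).
packed = sample_layer_weight[:2].flatten()
high = (packed >> 4) & 0x0F   # first value of each pair
low = packed & 0x0F           # second value of each pair
print(high.tolist(), low.tolist())
# [6, 11] [7, 8]   <- 103 = 0110|0111, 184 = 1011|1000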

2–3. Dequantizing Sample query_key_value weight

To dequantize the tensor, we need the “quant_state”, which contains the following:

quant_state = sample_layer_weight.quant_state
absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
print("absmax",absmax,absmax.shape)
# absmax tensor([0.0492, 0.0487, 0.0561, ..., 0.0257, 0.0217, 0.0180], device='cuda:0') torch.Size([786432])

print("shape",shape) # (3*4096,4096) -> original shape
print("dtype",dtype) # dtype torch.float16
print("blocksize",blocksize) # 64
print("quant_type",quant_type) # quant_type nf4
print("data_type",data_type, len(data_type))
'''
data_type tensor([-1.0000, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0000,
0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0000],
device='cuda:0') 16
'''

shape and dtype describe the dequantized tensor

the shape of “absmax” can be interpreted as follows:

  • absmax holds the absolute maximum value of each block
  • absmax size * blocksize = 2 * packed tensor length
  • 786,432 * 64 = 50,331,648 = 2 * 25,165,824 (checked in the sketch below)
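
This bookkeeping can be verified with a small sketch using the values unpacked from quant_state above:

# Block bookkeeping check: one absmax entry per 64-weight block.
n_packed = sample_layer_weight.numel()    # 25,165,824 packed uint8 values
n_weights = n_packed * 2                  # 50,331,648 4-bit weights
n_blocks = n_weights // blocksize         # 50,331,648 / 64 = 786,432
print(n_blocks == absmax.numel())         # True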

data_type represents the thresholds of the quantization bins

  • e.g. if a scaled value x satisfies -0.5251 < x ≤ -0.3949
  • → it belongs to bin 3

sample code of binning:

if(x > 0.03979014977812767f)
  if(x > 0.3893125355243683f)      // 1
    if(x > 0.6427869200706482f)    // 11
      if(x > 0.8614784181118011f)  // 111
        return 0b1111;
      else
        return 0b1110;
...

I’ll be dequantizing the weight using bitsandbytes’ dequantize_4bit function:

import bitsandbytes.functional as F
dequantized = F.dequantize_4bit(sample_layer_weight, quant_state).to(torch.bfloat16)
print(dequantized.shape) # (3*4096, 4096) - (3*hidden_size, hidden_size)
# torch.Size([12288, 4096])
print(dequantized[0])
'''
tensor([-0.0045, 0.0000, 0.0166, ..., 0.0703, 0.0239, 0.0239],
device='cuda:0', dtype=torch.bfloat16)
'''
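
As a quick cross-check (a sketch using absmax and data_type from quant_state), the first two dequantized values should be the NF4 codebook entries at indices 6 and 7 scaled by the first block’s absmax:

# dequantized value = data_type[4-bit index] * absmax[block index]
# The first packed byte's nibbles were 6 and 7 (section 2-2).
manual_0 = (data_type[6] * absmax[0]).to(torch.bfloat16)
manual_1 = (data_type[7] * absmax[0]).to(torch.bfloat16)
print(manual_0, manual_1)
# should match dequantized[0][:2] -> tensor(-0.0045, ...) tensor(0., ...)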

3. Reproducing the Quantize, Dequantize process

I’ll be re-quantizing → dequantizing the dequantized weight from 2–3.

The functions implement the block-wise quantization equation from the paper:

X_quant = round( (127 / absmax(X_block)) * X_block ) = round(c * X_block), where c = 127 / absmax(X_block) is the quantization constant

Since NF4 uses the range [-1, 1], 127 is replaced with 1: each weight is scaled by c = 1 / absmax(X_block) and mapped to the nearest of the 16 NF4 values (data_type).

Quantizing the first and second values, [-0.0045, 0.0000]:

  • they were originally quantized into the uint8 value 01100111
  • 0110: 6, 0111: 7

First, the weight block containing these values and the absolute maximum of that block are as follows:

def get_absmax(x):
    return x.abs().max()

weight_block = dequantized[0][:blocksize].clone().cpu()
absmax = get_absmax(weight_block)
print("absmax of block:", absmax)
# absmax of block: tensor(0.0491, dtype=torch.bfloat16)

Quantizing the weights:

get_quantile below simply takes the argmin of the absolute difference to each data_type value; since the weight was already restored from quantization, nearest-value matching is enough here (just for simplicity)

  • quantile 0: x ≤ -0.6962
  • quantile 1: -0.6962 < x ≤ -0.5251
  • ...
  • quantile 6: -0.1848 < x ≤ -0.0911

import numpy as np

# the 16 NF4 values from quant_state's data_type, as a plain Python list
data_type_np = [-1.0000, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0000,
                0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0000]

def get_quantile(x, data_type):
    # index of the nearest data_type value
    return int(np.argmin([abs(dt - float(x)) for dt in data_type]))

def simple_quant(x, absmax, data_type):
    c = 1 / absmax     # quantization constant
    scaled = x * c     # scale into [-1, 1]
    q = get_quantile(scaled, data_type)
    return q

quantized1 = simple_quant(weight_block[0], absmax, data_type_np)
# c = 1/absmax: 20.3750
# scaled: -0.0913
# quantized: 6

quantized2 = simple_quant(weight_block[1], absmax, data_type_np)
# c = 1/absmax: 20.3750
# scaled: 0.
# quantized: 7

Looking at the scaled value -0.0913, we can see that it maps to quantile 6

  • it is close to -0.0911 but not exactly the same due to rounding in the implementation

This matches the original quantized nibble “0110” (a whole-block check follows below).
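
As a final sketch, the same simple_quant can be applied to the whole 64-value block and re-packed two indices per byte (first value in the high nibble, as seen in 2–2) to compare against the stored uint8 weights:

# Quantize the whole first block, re-pack two 4-bit indices per byte, and
# compare against the first 32 stored uint8 values of the packed weight.
q = [simple_quant(w, absmax, data_type_np) for w in weight_block]
repacked = [(q[2 * i] << 4) | q[2 * i + 1] for i in range(len(q) // 2)]
stored = sample_layer_weight[:32].flatten().cpu().tolist()
print(repacked == stored)  # expected: True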

Then dequantizing the weight is as follows:

def simple_dequant(x_q, absmax, data_type):
    dq = data_type[x_q]   # look up the NF4 value for the 4-bit index
    c = 1 / absmax
    return dq / c         # equivalent to data_type[x_q] * absmax

dequantized1 = simple_dequant(quantized1, absmax, data_type_np)
# dq: data_type[6] -> -0.0911
# c: 20.3750
# DEQUANTIZED tensor(-0.0045)
# Before quantization: -0.0045

dequantized2 = simple_dequant(quantized2, absmax, data_type_np)
# dq: data_type[7] -> 0.0000
# c: 20.3750
# DEQUANTIZED tensor(0.)
# Before quantization: 0.
