189 lines
14 KiB
Python
189 lines
14 KiB
Python
|
from torch.nn import Linear
|
||
|
from torch.nn.parameter import Parameter
|
||
|
|
||
|
import bz2
|
||
|
import torch
|
||
|
import base64
|
||
|
import ctypes
|
||
|
from transformers.utils import logging
|
||
|
|
||
|
from typing import List
|
||
|
from functools import partial
|
||
|
|
||
|
# Module-level logger obtained through transformers' logging helpers.
logger = logging.get_logger(name=__name__)
|
||
|
|
||
|
try:
|
||
|
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
|
||
|
|
||
|
class Kernel:
    """Expose the named entry points of one compiled kernel module as attributes.

    A single ``LazyKernelCModule`` is built from ``code`` and shared by every
    ``KernelFunction`` created here; each name in ``function_names`` becomes an
    attribute of the instance holding the corresponding ``KernelFunction``.

    Args:
        code: raw module bytes handed to ``LazyKernelCModule`` (presumably a
            compiled CUDA binary -- its format is not visible here).
        function_names: kernel symbol names to expose as attributes.
    """

    def __init__(self, code: bytes, function_names: List[str]):
        self.code = code
        self._function_names = function_names
        shared_module = LazyKernelCModule(code)
        self._cmodule = shared_module
        # Build every launcher first, then attach them; later duplicates win, as with repeated setattr.
        launchers = {fn_name: KernelFunction(shared_module, fn_name) for fn_name in function_names}
        for attr_name, launcher in launchers.items():
            setattr(self, attr_name, launcher)
|
||
|
|
||
|
quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrd
NtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Ru
dm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWY
|
||
|
|
||
|
# Decode the embedded payload (base64, then bz2) and expose the listed kernel
# entry points as attributes of the module-level `kernels` object.
kernels = Kernel(
    code=bz2.decompress(base64.b64decode(quantization_code)),
    function_names=[
        "int4WeightCompression",
        "int4WeightExtractionFloat",
        "int4WeightExtractionHalf",
        "int8WeightExtractionFloat",
        "int8WeightExtractionHalf",
    ],
)
|
||
|
except Exception as exception:
|
||
|
kernels = None
|
||
|
logger.warning("Failed to load cpm_kernels:" + str(exception))
|
||
|
|
||
|
|
||
|
class W8A16Linear(torch.autograd.Function):
    """Autograd op computing ``inp @ dequantize(quant_w).T`` for a quantized weight.

    The weight is dequantized on the fly with ``extract_weight_to_half`` in both
    passes, so only the compact quantized weight and its scales are kept alive.
    Gradients flow to ``inp`` only: ``quant_w`` must be int8 (asserted by
    ``extract_weight_to_half``), and integer tensors can never require grad.
    """

    @staticmethod
    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
        """Return a tensor shaped ``inp.shape[:-1] + (quant_w.size(0),)``.

        Args:
            inp: activations whose last dimension is the input feature size.
            quant_w: int8 quantized weight; its first dimension is out_features.
            scale_w: per-row scales (half or bfloat16).
            weight_bit_width: 8 or 4, forwarded to ``extract_weight_to_half``.
        """
        ctx.inp_shape = inp.size()
        ctx.weight_bit_width = weight_bit_width
        out_features = quant_w.size(0)
        inp = inp.contiguous().view(-1, inp.size(-1))
        weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
        output = inp.mm(weight.t())
        # Only the weight tensors are needed in backward. The input activation was
        # previously saved solely to build a weight gradient that autograd discards.
        ctx.save_for_backward(quant_w, scale_w)
        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return the gradient w.r.t. ``inp``; every other input gets None.

        No gradient is produced for ``quant_w``: it is an int8 tensor and cannot
        require grad. For 4-bit weights the old weight gradient also had the
        unpacked shape, which never matched ``quant_w``.
        """
        quant_w, scale_w = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
            grad_output = grad_output.contiguous().view(-1, weight.size(0))
            grad_input = grad_output.mm(weight).view(ctx.inp_shape)
        return grad_input, None, None, None
|
||
|
|
||
|
|
||
|
def compress_int4_weight(weight: torch.Tensor):  # (n, m)
    """Compress an (n, m) int8 matrix of 4-bit values into an (n, m // 2) int8 tensor.

    Launches the ``int4WeightCompression`` kernel from cpm_kernels with one block
    per row (``gridDim = (n, 1, 1)``) and ``min(round_up(m // 2, 32), 1024)``
    threads per block, on the current CUDA stream of ``weight``'s device.
    Presumably two 4-bit values are packed per output byte (implied by the
    halved column count; the kernel source is not visible here).

    Args:
        weight: (n, m) quantized weight on a CUDA device; m must be even.

    Returns:
        (n, m // 2) int8 tensor on the same device as ``weight``.

    Raises:
        RuntimeError: if cpm_kernels failed to load at import time.
        AssertionError: if m is odd.
    """
    if kernels is None:
        raise RuntimeError("cpm_kernels is required for int4 weight compression but failed to load")
    # The kernel receives only a raw pointer plus (n, m), so it needs a dense row-major buffer.
    weight = weight.contiguous()
    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        if m % 2 != 0:
            # Explicit raise so the check survives `python -O`; same exception type as before.
            raise AssertionError("int4 compression requires an even number of columns")
        m = m // 2
        out = torch.empty(n, m, dtype=torch.int8, device=weight.device)
        stream = torch.cuda.current_stream()

        gridDim = (n, 1, 1)
        blockDim = (min(round_up(m, 32), 1024), 1, 1)

        kernels.int4WeightCompression(
            gridDim,
            blockDim,
            0,
            stream,
            [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
        )
        return out
|
||
|
|
||
|
|
||
|
def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
    """Dequantize an int8-stored weight to ``scale_list.dtype`` (half or bfloat16).

    * 8-bit: returns ``weight * scale_list[:, None]`` computed with plain torch ops,
      so this path works without cpm_kernels and on any device.
    * 4-bit: ``weight`` holds packed values, shape (n, m); a cpm_kernels CUDA
      kernel writes an (n, 2 * m) output on ``weight``'s device.

    Args:
        weight: int8 quantized weight, one row per output feature.
        scale_list: per-row scales, dtype torch.half or torch.bfloat16.
        source_bit_width: 8 or 4.

    Returns:
        Dequantized weight with dtype ``scale_list.dtype``.

    Raises:
        AssertionError: on unsupported dtypes or bit widths.
        RuntimeError: for 4-bit weights when cpm_kernels is unavailable or has no
            extraction kernel registered for ``scale_list.dtype``.
    """
    assert scale_list.dtype in [torch.half, torch.bfloat16]
    assert weight.dtype in [torch.int8]
    if source_bit_width == 8:
        return weight.to(scale_list.dtype) * scale_list[:, None]
    if source_bit_width != 4:
        # Explicit raise (not `assert False`): under `python -O` the assert vanished and the
        # code fell through to an UnboundLocalError. The exception type is unchanged.
        raise AssertionError("Unsupported bit-width")

    if kernels is None:
        raise RuntimeError("cpm_kernels is required to extract int4 weights but failed to load")
    kernel_name = "int4WeightExtractionHalf" if scale_list.dtype == torch.half else "int4WeightExtractionBFloat16"
    # Only names registered when `kernels` was built exist as attributes; the bfloat16
    # variant is not in that list, so report it clearly instead of an opaque AttributeError.
    func = getattr(kernels, kernel_name, None)
    if func is None:
        raise RuntimeError(
            f"Kernel {kernel_name!r} is not available in the loaded cpm_kernels module; "
            f"int4 weights with {scale_list.dtype} scales are not supported"
        )

    # The kernel receives raw pointers plus (n, m) only, so both buffers must be dense.
    weight = weight.contiguous()
    scale_list = scale_list.contiguous()
    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device=weight.device)
        stream = torch.cuda.current_stream()

        gridDim = (n, 1, 1)
        blockDim = (min(round_up(m, 32), 1024), 1, 1)

        func(
            gridDim,
            blockDim,
            0,
            stream,
            [
                ctypes.c_void_p(weight.data_ptr()),
                ctypes.c_void_p(scale_list.data_ptr()),
                ctypes.c_void_p(out.data_ptr()),
                ctypes.c_int32(n),
                ctypes.c_int32(m),
            ],
        )
        return out
|
||
|
|
||
|
|
||
|
class QuantizedLinear(torch.nn.Module):
    """Linear layer whose weight is stored as per-row symmetric int8/int4 values.

    For each output row, ``scale = max|w_row| / (2 ** (bits - 1) - 1)`` and the stored
    weight is ``round(w_row / scale)`` as int8. With ``weight_bit_width == 4`` the int8
    matrix is further packed by ``compress_int4_weight`` (CUDA only). The weight,
    scale and bias are registered as Parameters with ``requires_grad=False``.

    Args:
        weight_bit_width: 8 or 4.
        weight: float weight of shape (out_features, in_features). Required: its
            shape is needed even when ``empty_init`` is set.
        bias: optional bias, moved to ``device``.
        device: device for the stored parameters.
        dtype: dtype of the scale buffer when ``empty_init`` is set (otherwise the
            scale inherits ``weight``'s dtype).
        empty_init: allocate uninitialized buffers of the right shape instead of
            quantizing ``weight`` (e.g. before loading a quantized checkpoint).
        *args, **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``weight`` is None.
    """

    def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args,
                 **kwargs):
        super().__init__()
        self.weight_bit_width = weight_bit_width

        # Previously `weight.shape` was read before the None check, so a None weight
        # crashed with an AttributeError; report the requirement explicitly instead.
        if weight is None:
            raise ValueError("QuantizedLinear requires `weight` to determine the layer shape")
        shape = weight.shape

        if empty_init:
            self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device)
            self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device)
        else:
            self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)
            # All-zero rows get scale 0. Dividing by it gives 0/0 = NaN, and casting NaN to
            # int8 is undefined. Divide such rows by 1 instead; their stored scale stays 0,
            # so they still dequantize to zeros.
            divisor = self.weight_scale.masked_fill(self.weight_scale == 0, 1)
            self.weight = torch.round(weight / divisor[:, None]).to(torch.int8)
            if weight_bit_width == 4:
                self.weight = compress_int4_weight(self.weight)

        self.weight = Parameter(self.weight.to(device), requires_grad=False)
        self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False)
        self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None

    def forward(self, input):
        """Apply the quantized linear map, then add the bias when present."""
        output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
        if self.bias is not None:
            output = output + self.bias
        return output
|
||
|
|
||
|
|
||
|
def quantize(model, weight_bit_width, empty_init=False, device=None):
    """Swap the linear projections of every transformer layer for QuantizedLinear.

    For each entry of ``model.layers``, the attention ``query_key_value`` and
    ``dense`` projections and the MLP ``dense_h_to_4h`` / ``dense_4h_to_h``
    projections are replaced in place. Each original weight is copied to the
    current CUDA device for quantization. The quantized parameters are placed on
    ``device``, or on the original weight's device when ``device`` is None.

    Returns:
        The same ``model`` object, modified in place.
    """
    # (parent attribute, child attribute) pairs, in the order they are replaced.
    replaced_projections = (
        ("self_attention", "query_key_value"),
        ("self_attention", "dense"),
        ("mlp", "dense_h_to_4h"),
        ("mlp", "dense_4h_to_h"),
    )
    for layer in model.layers:
        for parent_attr, child_attr in replaced_projections:
            owner = getattr(layer, parent_attr)
            source = getattr(owner, child_attr)
            replacement = QuantizedLinear(
                weight_bit_width=weight_bit_width,
                weight=source.weight.to(torch.cuda.current_device()),
                bias=source.bias,
                dtype=source.weight.dtype,
                device=device if device is not None else source.weight.device,
                empty_init=empty_init,
            )
            setattr(owner, child_attr, replacement)
    return model
|