Mirror of https://github.com/mudler/LocalAI.git, synced 2024-06-07 19:40:48 +00:00
feat: Add Bitsandbytes quantization for transformer backend enhancement #1775 and fix: Transformer backend error on CUDA #1774 (#1823)
* Fixes #1775 and #1774: add BitsAndBytes quantization and fix embeddings on CUDA devices
* Manage 4-bit and 8-bit quantization: select the BitsAndBytes options through the quantization: parameter in the model YAML
* Fix compilation errors on non-CUDA environments
This commit is contained in:
Parent: a6b540737f
Commit: 3882130911
@@ -30,6 +30,7 @@ dependencies:
   - async-timeout==4.0.3
   - attrs==23.1.0
   - bark==0.1.5
+  - bitsandbytes==0.43.0
   - boto3==1.28.61
   - botocore==1.31.61
   - certifi==2023.7.22
@@ -23,7 +23,7 @@ if XPU:
     from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
     from transformers import AutoTokenizer, AutoModel, set_seed
 else:
-    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed
+    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, set_seed, BitsAndBytesConfig
 
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -75,18 +75,50 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             A Result object that contains the result of the LoadModel operation.
         """
         model_name = request.Model
+
+        compute = "auto"
+        if request.F16Memory == True:
+            compute=torch.bfloat16
+
+        self.CUDA = request.CUDA
+
+        device_map="cpu"
+
+        quantization = None
+
+        if self.CUDA:
+            if request.Device:
+                device_map=request.Device
+            else:
+                device_map="cuda:0"
+            if request.Quantization == "bnb_4bit":
+                quantization = BitsAndBytesConfig(
+                    load_in_4bit = True,
+                    bnb_4bit_compute_dtype = compute,
+                    bnb_4bit_quant_type = "nf4",
+                    bnb_4bit_use_double_quant = True,
+                    load_in_8bit = False,
+                )
+            elif request.Quantization == "bnb_8bit":
+                quantization = BitsAndBytesConfig(
+                    load_in_4bit=False,
+                    bnb_4bit_compute_dtype = None,
+                    load_in_8bit=True,
+                )
+
         try:
             if request.Type == "AutoModelForCausalLM":
                 if XPU:
+                    if quantization == "xpu_4bit":
+                        xpu_4bit = True
                     self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode,
-                                                                      device_map="xpu", load_in_4bit=True)
+                                                                      device_map="xpu", load_in_4bit=xpu_4bit)
                 else:
-                    self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode)
+                    self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map, torch_dtype=compute)
             else:
-                self.model = AutoModel.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode)
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-            self.CUDA = False
+                self.model = AutoModel.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map, torch_dtype=compute)
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
             self.XPU = False
 
             if XPU:
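For context, the BitsAndBytesConfig path above follows the usual Hugging Face transformers recipe for 4-bit NF4 loading. A minimal standalone sketch of the same idea outside the gRPC backend (the model id is a placeholder, not part of this commit):

# Minimal sketch of 4-bit NF4 loading with bitsandbytes, mirroring the "bnb_4bit" branch above.
# "some-org/some-model" is a placeholder; any causal LM from the Hub would do.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

quantization = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize weights to 4 bit at load time
    bnb_4bit_compute_dtype=torch.bfloat16,  # matmuls run in bf16, like the F16Memory path above
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
)

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",
    quantization_config=quantization,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("some-org/some-model")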
@@ -97,13 +129,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             except Exception as err:
                 print("Not using XPU:", err, file=sys.stderr)
 
-        if request.CUDA or torch.cuda.is_available():
-            try:
-                print("Loading model", model_name, "to CUDA.", file=sys.stderr)
-                self.model = self.model.to("cuda")
-                self.CUDA = True
-            except Exception as err:
-                print("Not using CUDA:", err, file=sys.stderr)
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         # Implement your logic here for the LoadModel service
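The removed block relocated the model to CUDA after loading. With the changes above, placement happens at load time through device_map, which is the route required for bitsandbytes-quantized weights (transformers rejects .to("cuda") on 4-bit/8-bit models). A minimal sketch, again with a placeholder model id:

# Minimal sketch: device placement via device_map at load time instead of a post-hoc .to("cuda").
# "some-org/some-model" is a placeholder; device_map handling requires the accelerate package.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("some-org/some-model", device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained("some-org/some-model")

inputs = tokenizer("hello", return_tensors="pt").to(model.device)  # move the inputs to the model, not the model to the inputs
with torch.no_grad():
    outputs = model(**inputs)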
@@ -130,13 +155,17 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         encoded_input = self.tokenizer(request.Embeddings, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
 
         # Create word embeddings
-        model_output = self.model(**encoded_input)
+        if self.CUDA:
+            encoded_input = encoded_input.to("cuda")
+
+        with torch.no_grad():
+            model_output = self.model(**encoded_input)
 
         # Pool to get sentence embeddings; i.e. generate one 1024 vector for the entire sentence
-        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']).detach().numpy()
+        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
         print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
         print("Embeddings:", sentence_embeddings, file=sys.stderr)
-        return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
+        return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings[0])
 
     def Predict(self, request, context):
         """
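The mean_pooling helper used above is defined elsewhere in the backend and is not part of this diff; a typical masked mean-pooling implementation (the standard sentence-transformers recipe, shown here only for reference) looks like this:

# Typical masked mean pooling over token embeddings; shown for reference only.
import torch

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element holds the per-token hidden states
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum the embeddings of real (non-padding) tokens and divide by how many there are.
    return torch.sum(token_embeddings * mask, dim=1) / torch.clamp(mask.sum(dim=1), min=1e-9)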
@@ -163,12 +192,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if XPU:
             inputs = inputs.to("xpu")
 
-        outputs = self.model.generate(inputs,max_new_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
-
-        generated_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-        # Remove prompt from response if present
-        if request.Prompt in generated_text:
-            generated_text = generated_text.replace(request.Prompt, "")
+        outputs = self.model.generate(inputs,max_new_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP, do_sample=True, pad_token_id=self.tokenizer.eos_token_id)
+        generated_text = self.tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0]
 
         return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
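The slicing in batch_decode replaces the old string-based prompt removal: generate returns the prompt tokens followed by the newly generated tokens, so dropping the first inputs.shape[1] columns leaves only the completion. A small self-contained sketch of the same idea (placeholder model id):

# Small sketch of prompt stripping by token slicing, as in the hunk above.
# "some-org/some-model" is a placeholder used only for illustration.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("some-org/some-model")
model = AutoModelForCausalLM.from_pretrained("some-org/some-model")

inputs = tokenizer("Once upon a time", return_tensors="pt").input_ids
outputs = model.generate(inputs, max_new_tokens=16, do_sample=True,
                         pad_token_id=tokenizer.eos_token_id)

# outputs has shape (batch, prompt_len + new_tokens); skip the prompt columns before decoding.
completion = tokenizer.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0]
print(completion)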