mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
feat(transformers): support also text generation (#1630)
* feat(transformers): support also text generation Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * embedded: set seed -1 --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
d5d82ba344
commit
5e335eaead
@ -15,7 +15,7 @@ import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
import torch
|
||||
|
||||
import torch.cuda
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
@ -70,12 +70,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
try:
|
||||
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True) # trust_remote_code is needed to use the encode method with embeddings models like jinai-v2
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
if request.CUDA:
|
||||
if request.CUDA or torch.cuda.is_available():
|
||||
try:
|
||||
# TODO: also tensorflow, make configurable
|
||||
import torch.cuda
|
||||
if torch.cuda.is_available():
|
||||
print("Loading model", model_name, "to CUDA.", file=sys.stderr)
|
||||
self.model = self.model.to("cuda")
|
||||
except Exception as err:
|
||||
@ -113,6 +109,47 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
print("Embeddings:", sentence_embeddings, file=sys.stderr)
|
||||
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
|
||||
|
||||
def Predict(self, request, context):
|
||||
"""
|
||||
Generates text based on the given prompt and sampling parameters.
|
||||
|
||||
Args:
|
||||
request: The predict request.
|
||||
context: The gRPC context.
|
||||
|
||||
Returns:
|
||||
backend_pb2.Reply: The predict result.
|
||||
"""
|
||||
if request.TopP == 0:
|
||||
request.TopP = 0.9
|
||||
|
||||
max_tokens = 200
|
||||
if request.Tokens > 0:
|
||||
max_tokens = request.Tokens
|
||||
|
||||
inputs = self.tokenizer.tokenizer(request.Prompt, return_tensors="pt").input_ids
|
||||
outputs = self.model.generate(inputs,max_tokens=max_tokens, temperature=request.Temperature, top_p=request.TopP)
|
||||
|
||||
generated_text = self.tokenizer.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
||||
# Remove prompt from response if present
|
||||
if request.Prompt in generated_text:
|
||||
generated_text = generated_text.replace(request.Prompt, "")
|
||||
|
||||
return backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
|
||||
|
||||
def PredictStream(self, request, context):
|
||||
"""
|
||||
Generates text based on the given prompt and sampling parameters, and streams the results.
|
||||
|
||||
Args:
|
||||
request: The predict stream request.
|
||||
context: The gRPC context.
|
||||
|
||||
Returns:
|
||||
backend_pb2.Result: The predict stream result.
|
||||
"""
|
||||
yield self.Predict(request, context)
|
||||
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
||||
|
@ -5,6 +5,7 @@ parameters:
|
||||
temperature: 0.2
|
||||
top_k: 40
|
||||
top_p: 0.95
|
||||
seed: -1
|
||||
template:
|
||||
chat_message: |
|
||||
<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "user"}}user{{end}}
|
||||
|
@ -17,6 +17,7 @@ parameters:
|
||||
temperature: 0.2
|
||||
top_k: 40
|
||||
top_p: 0.95
|
||||
seed: -1
|
||||
|
||||
template:
|
||||
chat: |
|
||||
|
@ -5,6 +5,7 @@ parameters:
|
||||
temperature: 0.2
|
||||
top_k: 40
|
||||
top_p: 0.95
|
||||
seed: -1
|
||||
template:
|
||||
chat_message: |
|
||||
<|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "user"}}user{{end}}
|
||||
|
@ -4,6 +4,7 @@ parameters:
|
||||
model: huggingface://TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF/mixtral-8x7b-instruct-v0.1.Q2_K.gguf
|
||||
temperature: 0.2
|
||||
top_k: 40
|
||||
seed: -1
|
||||
top_p: 0.95
|
||||
template:
|
||||
chat: &chat |
|
||||
|
@ -4,6 +4,7 @@ parameters:
|
||||
model: huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q8_0.gguf
|
||||
temperature: 0.2
|
||||
top_k: 40
|
||||
seed: -1
|
||||
top_p: 0.95
|
||||
template:
|
||||
chat_message: |
|
||||
|
@ -10,6 +10,7 @@ parameters:
|
||||
temperature: 0.2
|
||||
top_k: 40
|
||||
top_p: 0.95
|
||||
seed: -1
|
||||
template:
|
||||
chat: &template |
|
||||
Instruct: {{.Input}}
|
||||
|
Loading…
Reference in New Issue
Block a user