2024-01-05 17:04:46 +00:00
|
|
|
package backend
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"os"
|
|
|
|
"regexp"
|
|
|
|
"strings"
|
|
|
|
"sync"
|
|
|
|
"unicode/utf8"
|
|
|
|
|
2024-03-01 15:19:53 +00:00
|
|
|
"github.com/go-skynet/LocalAI/core/config"
|
|
|
|
|
2024-01-05 17:04:46 +00:00
|
|
|
"github.com/go-skynet/LocalAI/pkg/gallery"
|
|
|
|
"github.com/go-skynet/LocalAI/pkg/grpc"
|
|
|
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
|
|
|
"github.com/go-skynet/LocalAI/pkg/utils"
|
|
|
|
)
|
|
|
|
|
|
|
|
type LLMResponse struct {
|
|
|
|
Response string // should this be []byte?
|
|
|
|
Usage TokenUsage
|
|
|
|
}
|
|
|
|
|
|
|
|
// TokenUsage carries the counters reported alongside a prediction.
type TokenUsage struct {
	// Prompt is the length of the tokenized prompt.
	Prompt int

	// Completion counts generated output pieces.
	// NOTE(review): ModelInference increments it once per streamed UTF-8
	// rune, which is not necessarily one per model token.
	Completion int
}
|
|
|
|
|
2024-03-01 15:19:53 +00:00
|
|
|
func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
|
2024-01-05 17:04:46 +00:00
|
|
|
modelFile := c.Model
|
2024-03-07 13:37:45 +00:00
|
|
|
threads := c.Threads
|
2024-03-13 09:05:30 +00:00
|
|
|
if *threads == 0 && o.Threads != 0 {
|
|
|
|
threads = &o.Threads
|
2024-03-07 13:37:45 +00:00
|
|
|
}
|
2024-01-05 17:04:46 +00:00
|
|
|
grpcOpts := gRPCModelOpts(c)
|
|
|
|
|
2024-01-23 07:56:36 +00:00
|
|
|
var inferenceModel grpc.Backend
|
2024-01-05 17:04:46 +00:00
|
|
|
var err error
|
|
|
|
|
|
|
|
opts := modelOpts(c, o, []model.Option{
|
|
|
|
model.WithLoadGRPCLoadModelOpts(grpcOpts),
|
2024-03-13 09:05:30 +00:00
|
|
|
model.WithThreads(uint32(*threads)), // some models uses this to allocate threads during startup
|
2024-01-05 17:04:46 +00:00
|
|
|
model.WithAssetDir(o.AssetsDestination),
|
|
|
|
model.WithModel(modelFile),
|
|
|
|
model.WithContext(o.Context),
|
|
|
|
})
|
|
|
|
|
|
|
|
if c.Backend != "" {
|
|
|
|
opts = append(opts, model.WithBackendString(c.Backend))
|
|
|
|
}
|
|
|
|
|
|
|
|
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
|
|
|
if o.AutoloadGalleries { // experimental
|
|
|
|
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
|
|
|
|
utils.ResetDownloadTimers()
|
|
|
|
// if we failed to load the model, we try to download it
|
|
|
|
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if c.Backend == "" {
|
|
|
|
inferenceModel, err = loader.GreedyLoader(opts...)
|
|
|
|
} else {
|
|
|
|
inferenceModel, err = loader.BackendLoader(opts...)
|
|
|
|
}
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
|
|
|
fn := func() (LLMResponse, error) {
|
|
|
|
opts := gRPCPredictOpts(c, loader.ModelPath)
|
|
|
|
opts.Prompt = s
|
|
|
|
opts.Images = images
|
|
|
|
|
|
|
|
tokenUsage := TokenUsage{}
|
|
|
|
|
|
|
|
// check the per-model feature flag for usage, since tokenCallback may have a cost.
|
|
|
|
// Defaults to off as for now it is still experimental
|
|
|
|
if c.FeatureFlag.Enabled("usage") {
|
|
|
|
userTokenCallback := tokenCallback
|
|
|
|
if userTokenCallback == nil {
|
|
|
|
userTokenCallback = func(token string, usage TokenUsage) bool {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
|
|
|
|
if pErr == nil && promptInfo.Length > 0 {
|
|
|
|
tokenUsage.Prompt = int(promptInfo.Length)
|
|
|
|
}
|
|
|
|
|
|
|
|
tokenCallback = func(token string, usage TokenUsage) bool {
|
|
|
|
tokenUsage.Completion++
|
|
|
|
return userTokenCallback(token, tokenUsage)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if tokenCallback != nil {
|
|
|
|
ss := ""
|
|
|
|
|
|
|
|
var partialRune []byte
|
|
|
|
err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) {
|
|
|
|
partialRune = append(partialRune, chars...)
|
|
|
|
|
|
|
|
for len(partialRune) > 0 {
|
|
|
|
r, size := utf8.DecodeRune(partialRune)
|
|
|
|
if r == utf8.RuneError {
|
|
|
|
// incomplete rune, wait for more bytes
|
|
|
|
break
|
|
|
|
}
|
|
|
|
|
|
|
|
tokenCallback(string(r), tokenUsage)
|
|
|
|
ss += string(r)
|
|
|
|
|
|
|
|
partialRune = partialRune[size:]
|
|
|
|
}
|
|
|
|
})
|
|
|
|
return LLMResponse{
|
|
|
|
Response: ss,
|
|
|
|
Usage: tokenUsage,
|
|
|
|
}, err
|
|
|
|
} else {
|
|
|
|
// TODO: Is the chicken bit the only way to get here? is that acceptable?
|
|
|
|
reply, err := inferenceModel.Predict(ctx, opts)
|
|
|
|
if err != nil {
|
|
|
|
return LLMResponse{}, err
|
|
|
|
}
|
|
|
|
return LLMResponse{
|
|
|
|
Response: string(reply.Message),
|
|
|
|
Usage: tokenUsage,
|
|
|
|
}, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return fn, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
var (
	// cutstrings caches compiled Cutstrings patterns, keyed by their source
	// text, so each pattern is compiled only once by Finetune.
	cutstrings = make(map[string]*regexp.Regexp)

	// mu guards every read and write of cutstrings.
	mu sync.Mutex
)
|
|
|
|
|
2024-03-01 15:19:53 +00:00
|
|
|
func Finetune(config config.BackendConfig, input, prediction string) string {
|
2024-01-05 17:04:46 +00:00
|
|
|
if config.Echo {
|
|
|
|
prediction = input + prediction
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, c := range config.Cutstrings {
|
|
|
|
mu.Lock()
|
|
|
|
reg, ok := cutstrings[c]
|
|
|
|
if !ok {
|
|
|
|
cutstrings[c] = regexp.MustCompile(c)
|
|
|
|
reg = cutstrings[c]
|
|
|
|
}
|
|
|
|
mu.Unlock()
|
|
|
|
prediction = reg.ReplaceAllString(prediction, "")
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, c := range config.TrimSpace {
|
|
|
|
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, c := range config.TrimSuffix {
|
|
|
|
prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c))
|
|
|
|
}
|
|
|
|
return prediction
|
|
|
|
}
|