feat: config files and SSE (#83)

Signed-off-by: mudler <mudler@mocaccino.org>
Signed-off-by: Tyler Gillson <tyler.gillson@gmail.com>
Co-authored-by: Tyler Gillson <tyler.gillson@gmail.com>
This commit is contained in:
Ettore Di Giacinto 2023-04-27 06:18:18 +02:00 committed by GitHub
parent 4e2061636e
commit c806eae0de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 984 additions and 419 deletions

View File

@ -1 +1,2 @@
models models
examples/chatbot-ui/models

3
.gitignore vendored
View File

@ -10,6 +10,5 @@ local-ai
!charts/* !charts/*
# Ignore models # Ignore models
models/*.bin models/*
models/ggml-*
test-models/ test-models/

View File

@ -102,9 +102,11 @@ run: prepare
test-models/testmodel: test-models/testmodel:
mkdir test-models mkdir test-models
wget https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerberas-111m-q4_0.bin -O test-models/testmodel wget https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerberas-111m-q4_0.bin -O test-models/testmodel
cp tests/fixtures/* test-models
test: prepare test-models/testmodel test: prepare test-models/testmodel
@C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} MODELS_PATH=$(abspath ./)/test-models $(GOCMD) test -v ./... cp tests/fixtures/* test-models
@C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) test -v -timeout 20m ./...
## Help: ## Help:
help: ## Show this help. help: ## Show this help.

View File

@ -50,6 +50,9 @@ git clone https://github.com/go-skynet/LocalAI
cd LocalAI cd LocalAI
# (optional) Checkout a specific LocalAI tag
# git checkout -b build <TAG>
# copy your models to models/ # copy your models to models/
cp your-model.bin models/ cp your-model.bin models/
@ -80,6 +83,9 @@ git clone https://github.com/go-skynet/LocalAI
cd LocalAI cd LocalAI
# (optional) Checkout a specific LocalAI tag
# git checkout -b build <TAG>
# Download gpt4all-j to models/ # Download gpt4all-j to models/
wget https://gpt4all.io/models/ggml-gpt4all-j.bin -O models/ggml-gpt4all-j wget https://gpt4all.io/models/ggml-gpt4all-j.bin -O models/ggml-gpt4all-j
@ -106,6 +112,12 @@ curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/jso
``` ```
</details> </details>
To build locally, run `make build` (see below).
## Other examples
To see other examples on how to integrate with other projects, see: [examples](https://github.com/go-skynet/LocalAI/tree/master/examples/).
## Prompt templates ## Prompt templates
The API doesn't inject a default prompt for talking to the model. You have to use a prompt similar to what's described in the standford-alpaca docs: https://github.com/tatsu-lab/stanford_alpaca#data-release. The API doesn't inject a default prompt for talking to the model. You have to use a prompt similar to what's described in the standford-alpaca docs: https://github.com/tatsu-lab/stanford_alpaca#data-release.
@ -169,6 +181,9 @@ Once the server is running, you can start making requests to it using HTTP, usin
</details> </details>
## Advanced configuration
### Supported OpenAI API endpoints ### Supported OpenAI API endpoints
You can check out the [OpenAI API reference](https://platform.openai.com/docs/api-reference/chat/create). You can check out the [OpenAI API reference](https://platform.openai.com/docs/api-reference/chat/create).
@ -223,22 +238,11 @@ curl http://localhost:8080/v1/models
</details> </details>
## Using other models
gpt4all (https://github.com/nomic-ai/gpt4all) works as well, however the original model needs to be converted (same applies for old alpaca models, too):
```bash
wget -O tokenizer.model https://huggingface.co/decapoda-research/llama-30b-hf/resolve/main/tokenizer.model
mkdir models
cp gpt4all.. models/
git clone https://gist.github.com/eiz/828bddec6162a023114ce19146cb2b82
pip install sentencepiece
python 828bddec6162a023114ce19146cb2b82/gistfile1.txt models tokenizer.model
# There will be a new model with the ".tmp" extension, you have to use that one!
```
## Helm Chart Installation (run LocalAI in Kubernetes) ## Helm Chart Installation (run LocalAI in Kubernetes)
LocalAI can be installed inside Kubernetes with helm.
<details>
The local-ai Helm chart supports two options for the LocalAI server's models directory: The local-ai Helm chart supports two options for the LocalAI server's models directory:
1. Basic deployment with no persistent volume. You must manually update the Deployment to configure your own models directory. 1. Basic deployment with no persistent volume. You must manually update the Deployment to configure your own models directory.
@ -258,6 +262,12 @@ The local-ai Helm chart supports two options for the LocalAI server's models dir
``` ```
This will update the local-ai Deployment to mount the PV that was provisioned by the DataVolume. This will update the local-ai Deployment to mount the PV that was provisioned by the DataVolume.
</details>
## Blog posts
- https://medium.com/@tyler_97636/k8sgpt-localai-unlock-kubernetes-superpowers-for-free-584790de9b65
## Windows compatibility ## Windows compatibility
It should work, however you need to make sure you give enough resources to the container. See https://github.com/go-skynet/LocalAI/issues/2 It should work, however you need to make sure you give enough resources to the container. See https://github.com/go-skynet/LocalAI/issues/2
@ -335,17 +345,25 @@ AutoGPT currently doesn't allow to set a different API URL, but there is a PR op
</details> </details>
## Projects already using LocalAI to run local models
Feel free to open up a PR to get your project listed!
- [Kairos](https://github.com/kairos-io/kairos)
- [k8sgpt](https://github.com/k8sgpt-ai/k8sgpt#running-local-models)
## Short-term roadmap ## Short-term roadmap
- [x] Mimic OpenAI API (https://github.com/go-skynet/LocalAI/issues/10) - [x] Mimic OpenAI API (https://github.com/go-skynet/LocalAI/issues/10)
- [ ] Binary releases (https://github.com/go-skynet/LocalAI/issues/6) - [ ] Binary releases (https://github.com/go-skynet/LocalAI/issues/6)
- [ ] Upstream our golang bindings to llama.cpp (https://github.com/ggerganov/llama.cpp/issues/351) - [ ] Upstream our golang bindings to llama.cpp (https://github.com/ggerganov/llama.cpp/issues/351) and gpt4all
- [x] Multi-model support - [x] Multi-model support
- [ ] Have a webUI! - [ ] Have a webUI!
- [ ] Allow configuration of defaults for models. - [ ] Allow configuration of defaults for models.
- [ ] Enable automatic downloading of models from a curated gallery, with only free-licensed models. - [ ] Enable automatic downloading of models from a curated gallery, with only free-licensed models.
[![LocalAI Star history Chart](https://api.star-history.com/svg?repos=go-skynet/LocalAI&type=Date)](https://star-history.com/#go-skynet/LocalAI&Date)
## License ## License
MIT MIT

View File

@ -1,16 +1,9 @@
package api package api
import ( import (
"encoding/json"
"errors" "errors"
"fmt"
"strings"
"sync"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/recover" "github.com/gofiber/fiber/v2/middleware/recover"
@ -18,375 +11,7 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
// APIError provides error information returned by the OpenAI API. func App(configFile string, loader *model.ModelLoader, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App {
type APIError struct {
Code any `json:"code,omitempty"`
Message string `json:"message"`
Param *string `json:"param,omitempty"`
Type string `json:"type"`
}
type ErrorResponse struct {
Error *APIError `json:"error,omitempty"`
}
type OpenAIResponse struct {
Created int `json:"created,omitempty"`
Object string `json:"chat.completion,omitempty"`
ID string `json:"id,omitempty"`
Model string `json:"model,omitempty"`
Choices []Choice `json:"choices,omitempty"`
}
type Choice struct {
Index int `json:"index,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Message *Message `json:"message,omitempty"`
Text string `json:"text,omitempty"`
}
type Message struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
}
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
}
type OpenAIRequest struct {
Model string `json:"model"`
// Prompt is read only by completion API calls
Prompt string `json:"prompt"`
Stop string `json:"stop"`
// Messages is read only by chat/completion API calls
Messages []Message `json:"messages"`
Echo bool `json:"echo"`
// Common options between all the API calls
TopP float64 `json:"top_p"`
TopK int `json:"top_k"`
Temperature float64 `json:"temperature"`
Maxtokens int `json:"max_tokens"`
N int `json:"n"`
// Custom parameters - not present in the OpenAI API
Batch int `json:"batch"`
F16 bool `json:"f16kv"`
IgnoreEOS bool `json:"ignore_eos"`
RepeatPenalty float64 `json:"repeat_penalty"`
Keep int `json:"n_keep"`
Seed int `json:"seed"`
}
// https://platform.openai.com/docs/api-reference/completions
func openAIEndpoint(chat, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
var err error
var model *llama.LLama
var gptModel *gptj.GPTJ
var gpt2Model *gpt2.GPT2
var stableLMModel *gpt2.StableLM
input := new(OpenAIRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
modelFile := input.Model
received, _ := json.Marshal(input)
log.Debug().Msgf("Request received: %s", string(received))
// Set model from bearer token, if available
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelFile == "" {
models, _ := loader.ListModels()
if len(models) > 0 {
modelFile = models[0]
log.Debug().Msgf("No model specified, using: %s", modelFile)
}
}
// If no model is found or specified, we bail out
if modelFile == "" && !bearerExists {
return fmt.Errorf("no model specified")
}
// If a model is found in bearer token takes precedence
if bearerExists {
log.Debug().Msgf("Using model from bearer token: %s", bearer)
modelFile = bearer
}
// Try to load the model
var llamaerr, gpt2err, gptjerr, stableerr error
llamaOpts := []llama.ModelOption{}
if ctx != 0 {
llamaOpts = append(llamaOpts, llama.SetContext(ctx))
}
if f16 {
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
}
// TODO: this is ugly, better identifying the model somehow! however, it is a good stab for a first implementation..
model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...)
if llamaerr != nil {
gptModel, gptjerr = loader.LoadGPTJModel(modelFile)
if gptjerr != nil {
gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
if gpt2err != nil {
stableLMModel, stableerr = loader.LoadStableLMModel(modelFile)
if stableerr != nil {
return fmt.Errorf("llama: %s gpt: %s gpt2: %s stableLM: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error(), stableerr.Error()) // llama failed first, so we want to catch both errors
}
}
}
}
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
mutexMap.Lock()
l, ok := mutexes[modelFile]
if !ok {
m := &sync.Mutex{}
mutexes[modelFile] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock()
// Set the parameters for the language model prediction
topP := input.TopP
if topP == 0 {
topP = 0.7
}
topK := input.TopK
if topK == 0 {
topK = 80
}
temperature := input.Temperature
if temperature == 0 {
temperature = 0.9
}
tokens := input.Maxtokens
if tokens == 0 {
tokens = 512
}
predInput := input.Prompt
if chat {
mess := []string{}
// TODO: encode roles
for _, i := range input.Messages {
mess = append(mess, i.Content)
}
predInput = strings.Join(mess, "\n")
}
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(modelFile, struct {
Input string
}{Input: predInput})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
result := []Choice{}
n := input.N
if input.N == 0 {
n = 1
}
var predFunc func() (string, error)
switch {
case stableLMModel != nil:
predFunc = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(temperature),
gpt2.SetTopP(topP),
gpt2.SetTopK(topK),
gpt2.SetTokens(tokens),
gpt2.SetThreads(threads),
}
if input.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
}
if input.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
}
return stableLMModel.Predict(
predInput,
predictOptions...,
)
}
case gpt2Model != nil:
predFunc = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(temperature),
gpt2.SetTopP(topP),
gpt2.SetTopK(topK),
gpt2.SetTokens(tokens),
gpt2.SetThreads(threads),
}
if input.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
}
if input.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
}
return gpt2Model.Predict(
predInput,
predictOptions...,
)
}
case gptModel != nil:
predFunc = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gptj.PredictOption{
gptj.SetTemperature(temperature),
gptj.SetTopP(topP),
gptj.SetTopK(topK),
gptj.SetTokens(tokens),
gptj.SetThreads(threads),
}
if input.Batch != 0 {
predictOptions = append(predictOptions, gptj.SetBatch(input.Batch))
}
if input.Seed != 0 {
predictOptions = append(predictOptions, gptj.SetSeed(input.Seed))
}
return gptModel.Predict(
predInput,
predictOptions...,
)
}
case model != nil:
predFunc = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []llama.PredictOption{
llama.SetTemperature(temperature),
llama.SetTopP(topP),
llama.SetTopK(topK),
llama.SetTokens(tokens),
llama.SetThreads(threads),
}
if debug {
predictOptions = append(predictOptions, llama.Debug)
}
if input.Stop != "" {
predictOptions = append(predictOptions, llama.SetStopWords(input.Stop))
}
if input.RepeatPenalty != 0 {
predictOptions = append(predictOptions, llama.SetPenalty(input.RepeatPenalty))
}
if input.Keep != 0 {
predictOptions = append(predictOptions, llama.SetNKeep(input.Keep))
}
if input.Batch != 0 {
predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
}
if input.F16 {
predictOptions = append(predictOptions, llama.EnableF16KV)
}
if input.IgnoreEOS {
predictOptions = append(predictOptions, llama.IgnoreEOS)
}
if input.Seed != 0 {
predictOptions = append(predictOptions, llama.SetSeed(input.Seed))
}
return model.Predict(
predInput,
predictOptions...,
)
}
}
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
return err
}
if input.Echo {
prediction = predInput + prediction
}
if chat {
result = append(result, Choice{Message: &Message{Role: "assistant", Content: prediction}})
} else {
result = append(result, Choice{Text: prediction})
}
}
jsonResult, _ := json.Marshal(result)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
})
}
}
func listModels(loader *model.ModelLoader) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, err := loader.ListModels()
if err != nil {
return err
}
dataModels := []OpenAIModel{}
for _, m := range models {
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
}
return c.JSON(struct {
Object string `json:"object"`
Data []OpenAIModel `json:"data"`
}{
Object: "list",
Data: dataModels,
})
}
}
func App(loader *model.ModelLoader, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App {
zerolog.SetGlobalLevel(zerolog.InfoLevel) zerolog.SetGlobalLevel(zerolog.InfoLevel)
if debug { if debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel) zerolog.SetGlobalLevel(zerolog.DebugLevel)
@ -415,23 +40,35 @@ func App(loader *model.ModelLoader, threads, ctxSize int, f16 bool, debug, disab
}, },
}) })
cm := make(ConfigMerger)
if err := cm.LoadConfigs(loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}
if configFile != "" {
if err := cm.LoadConfigFile(configFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}
if debug {
for k, v := range cm {
log.Debug().Msgf("Model: %s (config: %+v)", k, v)
}
}
// Default middleware config // Default middleware config
app.Use(recover.New()) app.Use(recover.New())
app.Use(cors.New()) app.Use(cors.New())
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
mu := map[string]*sync.Mutex{}
var mumutex = &sync.Mutex{}
// openAI compatible API endpoint // openAI compatible API endpoint
app.Post("/v1/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu)) app.Post("/v1/chat/completions", openAIEndpoint(cm, true, debug, loader, threads, ctxSize, f16))
app.Post("/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu)) app.Post("/chat/completions", openAIEndpoint(cm, true, debug, loader, threads, ctxSize, f16))
app.Post("/v1/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu)) app.Post("/v1/completions", openAIEndpoint(cm, false, debug, loader, threads, ctxSize, f16))
app.Post("/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu)) app.Post("/completions", openAIEndpoint(cm, false, debug, loader, threads, ctxSize, f16))
app.Get("/v1/models", listModels(loader)) app.Get("/v1/models", listModels(loader, cm))
app.Get("/models", listModels(loader)) app.Get("/models", listModels(loader, cm))
return app return app
} }

View File

@ -21,7 +21,7 @@ var _ = Describe("API test", func() {
Context("API query", func() { Context("API query", func() {
BeforeEach(func() { BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
app = App(modelLoader, 1, 512, false, false, true) app = App("", modelLoader, 1, 512, false, true, true)
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")
@ -40,7 +40,7 @@ var _ = Describe("API test", func() {
It("returns the models list", func() { It("returns the models list", func() {
models, err := client.ListModels(context.TODO()) models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(1)) Expect(len(models.Models)).To(Equal(3))
Expect(models.Models[0].ID).To(Equal("testmodel")) Expect(models.Models[0].ID).To(Equal("testmodel"))
}) })
It("can generate completions", func() { It("can generate completions", func() {
@ -49,10 +49,73 @@ var _ = Describe("API test", func() {
Expect(len(resp.Choices)).To(Equal(1)) Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Text).ToNot(BeEmpty()) Expect(resp.Choices[0].Text).ToNot(BeEmpty())
}) })
It("can generate chat completions ", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "testmodel", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("can generate completions from model configs", func() {
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "gpt4all", Prompt: "abcdedfghikl"})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Text).ToNot(BeEmpty())
})
It("can generate chat completions from model configs", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("returns errors", func() { It("returns errors", func() {
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: llama: model does not exist")) Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: llama: model does not exist"))
}) })
})
Context("Config file", func() {
BeforeEach(func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
app = App(os.Getenv("CONFIG_FILE"), modelLoader, 1, 512, false, true, true)
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
// Wait for API to be ready
client = openai.NewClientWithConfig(defaultConfig)
Eventually(func() error {
_, err := client.ListModels(context.TODO())
return err
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
app.Shutdown()
})
It("can generate chat completions from config file", func() {
models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(5))
Expect(models.Models[0].ID).To(Equal("testmodel"))
})
It("can generate chat completions from config file", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("can generate chat completions from config file", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
}) })
}) })

100
api/config.go Normal file
View File

@ -0,0 +1,100 @@
package api
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"gopkg.in/yaml.v3"
)
type Config struct {
OpenAIRequest `yaml:"parameters"`
Name string `yaml:"name"`
StopWords []string `yaml:"stopwords"`
Cutstrings []string `yaml:"cutstrings"`
TrimSpace []string `yaml:"trimspace"`
ContextSize int `yaml:"context_size"`
F16 bool `yaml:"f16"`
Threads int `yaml:"threads"`
Debug bool `yaml:"debug"`
Roles map[string]string `yaml:"roles"`
TemplateConfig TemplateConfig `yaml:"template"`
}
type TemplateConfig struct {
Completion string `yaml:"completion"`
Chat string `yaml:"chat"`
}
type ConfigMerger map[string]Config
func ReadConfigFile(file string) ([]*Config, error) {
c := &[]*Config{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
return *c, nil
}
func ReadConfig(file string) (*Config, error) {
c := &Config{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
}
return c, nil
}
func (cm ConfigMerger) LoadConfigFile(file string) error {
c, err := ReadConfigFile(file)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}
for _, cc := range c {
cm[cc.Name] = *cc
}
return nil
}
func (cm ConfigMerger) LoadConfig(file string) error {
c, err := ReadConfig(file)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}
cm[c.Name] = *c
return nil
}
func (cm ConfigMerger) LoadConfigs(path string) error {
files, err := ioutil.ReadDir(path)
if err != nil {
return err
}
for _, file := range files {
// Skip templates, YAML and .keep files
if !strings.Contains(file.Name(), ".yaml") {
continue
}
c, err := ReadConfig(filepath.Join(path, file.Name()))
if err == nil {
cm[c.Name] = *c
}
}
return nil
}

396
api/openai.go Normal file
View File

@ -0,0 +1,396 @@
package api
import (
"bufio"
"encoding/json"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
// APIError provides error information returned by the OpenAI API.
type APIError struct {
Code any `json:"code,omitempty"`
Message string `json:"message"`
Param *string `json:"param,omitempty"`
Type string `json:"type"`
}
type ErrorResponse struct {
Error *APIError `json:"error,omitempty"`
}
type OpenAIResponse struct {
Created int `json:"created,omitempty"`
Object string `json:"object,omitempty"`
ID string `json:"id,omitempty"`
Model string `json:"model,omitempty"`
Choices []Choice `json:"choices,omitempty"`
}
type Choice struct {
Index int `json:"index,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Message *Message `json:"message,omitempty"`
Delta *Message `json:"delta,omitempty"`
Text string `json:"text,omitempty"`
}
type Message struct {
Role string `json:"role,omitempty" yaml:"role"`
Content string `json:"content,omitempty" yaml:"content"`
}
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
}
type OpenAIRequest struct {
Model string `json:"model" yaml:"model"`
// Prompt is read only by completion API calls
Prompt string `json:"prompt" yaml:"prompt"`
Stop string `json:"stop" yaml:"stop"`
// Messages is read only by chat/completion API calls
Messages []Message `json:"messages" yaml:"messages"`
Stream bool `json:"stream"`
Echo bool `json:"echo"`
// Common options between all the API calls
TopP float64 `json:"top_p" yaml:"top_p"`
TopK int `json:"top_k" yaml:"top_k"`
Temperature float64 `json:"temperature" yaml:"temperature"`
Maxtokens int `json:"max_tokens" yaml:"max_tokens"`
N int `json:"n"`
// Custom parameters - not present in the OpenAI API
Batch int `json:"batch" yaml:"batch"`
F16 bool `json:"f16" yaml:"f16"`
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"`
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
Keep int `json:"n_keep" yaml:"n_keep"`
Seed int `json:"seed" yaml:"seed"`
}
func defaultRequest(modelFile string) OpenAIRequest {
return OpenAIRequest{
TopP: 0.7,
TopK: 80,
Maxtokens: 512,
Temperature: 0.9,
Model: modelFile,
}
}
func updateConfig(config *Config, input *OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != 0 {
config.TopK = input.TopK
}
if input.TopP != 0 {
config.TopP = input.TopP
}
if input.Temperature != 0 {
config.Temperature = input.Temperature
}
if input.Maxtokens != 0 {
config.Maxtokens = input.Maxtokens
}
if input.Stop != "" {
config.StopWords = append(config.StopWords, input.Stop)
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.F16 {
config.F16 = input.F16
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != 0 {
config.Seed = input.Seed
}
}
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}
// https://platform.openai.com/docs/api-reference/completions
func openAIEndpoint(cm ConfigMerger, chat, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(OpenAIRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
if input.Stream {
log.Debug().Msgf("Stream request received")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
c.Set("Content-Type", "text/event-stream; charset=utf-8")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
modelFile := input.Model
received, _ := json.Marshal(input)
log.Debug().Msgf("Request received: %s", string(received))
// Set model from bearer token, if available
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
// If no model was specified, take the first available
if modelFile == "" && !bearerExists {
models, _ := loader.ListModels()
if len(models) > 0 {
modelFile = models[0]
log.Debug().Msgf("No model specified, using: %s", modelFile)
} else {
log.Debug().Msgf("No model specified, returning error")
return fmt.Errorf("no model specified")
}
}
// If a model is found in bearer token takes precedence
if bearerExists {
log.Debug().Msgf("Using model from bearer token: %s", bearer)
modelFile = bearer
}
// Load a config file if present after the model name
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
if err := cm.LoadConfig(modelConfig); err != nil {
return fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
}
var config *Config
cfg, exists := cm[modelFile]
if !exists {
config = &Config{
OpenAIRequest: defaultRequest(modelFile),
}
} else {
config = &cfg
}
// Set the parameters for the language model prediction
updateConfig(config, input)
if threads != 0 {
config.Threads = threads
}
if ctx != 0 {
config.ContextSize = ctx
}
if f16 {
config.F16 = true
}
if debug {
config.Debug = true
}
log.Debug().Msgf("Parameter Config: %+v", config)
predInput := input.Prompt
if chat {
mess := []string{}
for _, i := range input.Messages {
r := config.Roles[i.Role]
if r == "" {
r = i.Role
}
content := fmt.Sprint(r, " ", i.Content)
mess = append(mess, content)
}
predInput = strings.Join(mess, "\n")
}
templateFile := config.Model
if config.TemplateConfig.Chat != "" && chat {
templateFile = config.TemplateConfig.Chat
}
if config.TemplateConfig.Completion != "" && !chat {
templateFile = config.TemplateConfig.Completion
}
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(templateFile, struct {
Input string
}{Input: predInput})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
result := []Choice{}
n := input.N
if input.N == 0 {
n = 1
}
// get the model function to call for the result
predFunc, err := ModelInference(predInput, loader, *config)
if err != nil {
return err
}
finetunePrediction := func(prediction string) string {
if config.Echo {
prediction = predInput + prediction
}
for _, c := range config.Cutstrings {
mu.Lock()
reg, ok := cutstrings[c]
if !ok {
cutstrings[c] = regexp.MustCompile(c)
reg = cutstrings[c]
}
mu.Unlock()
prediction = reg.ReplaceAllString(prediction, "")
}
for _, c := range config.TrimSpace {
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
}
return prediction
}
for i := 0; i < n; i++ {
prediction, err := predFunc()
if err != nil {
return err
}
prediction = finetunePrediction(prediction)
if chat {
if input.Stream {
result = append(result, Choice{Delta: &Message{Role: "assistant", Content: prediction}})
} else {
result = append(result, Choice{Message: &Message{Role: "assistant", Content: prediction}})
}
} else {
result = append(result, Choice{Text: prediction})
}
}
resp := &OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
}
if input.Stream && chat {
resp.Object = "chat.completion.chunk"
} else if chat {
resp.Object = "chat.completion"
} else {
resp.Object = "text_completion"
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
if input.Stream {
log.Debug().Msgf("Handling stream request")
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
fmt.Fprintf(w, "event: data\n")
w.Flush()
fmt.Fprintf(w, "data: %s\n\n", jsonResult)
w.Flush()
fmt.Fprintf(w, "event: data\n")
w.Flush()
resp := &OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{Choice{FinishReason: "stop"}},
}
respData, _ := json.Marshal(resp)
fmt.Fprintf(w, "data: %s\n\n", respData)
w.Flush()
// fmt.Fprintf(w, "data: [DONE]\n\n")
// w.Flush()
}))
return nil
} else {
// Return the prediction in the response body
return c.JSON(resp)
}
}
}
func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, err := loader.ListModels()
if err != nil {
return err
}
var mm map[string]interface{} = map[string]interface{}{}
dataModels := []OpenAIModel{}
for _, m := range models {
mm[m] = nil
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
}
for k := range cm {
if _, exists := mm[k]; !exists {
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"})
}
}
return c.JSON(struct {
Object string `json:"object"`
Data []OpenAIModel `json:"data"`
}{
Object: "list",
Data: dataModels,
})
}
}

188
api/prediction.go Normal file
View File

@ -0,0 +1,188 @@
package api
import (
"fmt"
"sync"
model "github.com/go-skynet/LocalAI/pkg/model"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp"
)
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
var mutexMap sync.Mutex
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (string, error), error) {
var model *llama.LLama
var gptModel *gptj.GPTJ
var gpt2Model *gpt2.GPT2
var stableLMModel *gpt2.StableLM
modelFile := c.Model
// Try to load the model
var llamaerr, gpt2err, gptjerr, stableerr error
llamaOpts := []llama.ModelOption{}
if c.ContextSize != 0 {
llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize))
}
if c.F16 {
llamaOpts = append(llamaOpts, llama.EnableF16Memory)
}
// TODO: this is ugly, better identifying the model somehow! however, it is a good stab for a first implementation..
model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...)
if llamaerr != nil {
gptModel, gptjerr = loader.LoadGPTJModel(modelFile)
if gptjerr != nil {
gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
if gpt2err != nil {
stableLMModel, stableerr = loader.LoadStableLMModel(modelFile)
if stableerr != nil {
return nil, fmt.Errorf("llama: %s gpt: %s gpt2: %s stableLM: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error(), stableerr.Error()) // llama failed first, so we want to catch both errors
}
}
}
}
var fn func() (string, error)
switch {
case stableLMModel != nil:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(c.Temperature),
gpt2.SetTopP(c.TopP),
gpt2.SetTopK(c.TopK),
gpt2.SetTokens(c.Maxtokens),
gpt2.SetThreads(c.Threads),
}
if c.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
}
if c.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
}
return stableLMModel.Predict(
s,
predictOptions...,
)
}
case gpt2Model != nil:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(c.Temperature),
gpt2.SetTopP(c.TopP),
gpt2.SetTopK(c.TopK),
gpt2.SetTokens(c.Maxtokens),
gpt2.SetThreads(c.Threads),
}
if c.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
}
if c.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
}
return gpt2Model.Predict(
s,
predictOptions...,
)
}
case gptModel != nil:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gptj.PredictOption{
gptj.SetTemperature(c.Temperature),
gptj.SetTopP(c.TopP),
gptj.SetTopK(c.TopK),
gptj.SetTokens(c.Maxtokens),
gptj.SetThreads(c.Threads),
}
if c.Batch != 0 {
predictOptions = append(predictOptions, gptj.SetBatch(c.Batch))
}
if c.Seed != 0 {
predictOptions = append(predictOptions, gptj.SetSeed(c.Seed))
}
return gptModel.Predict(
s,
predictOptions...,
)
}
case model != nil:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []llama.PredictOption{
llama.SetTemperature(c.Temperature),
llama.SetTopP(c.TopP),
llama.SetTopK(c.TopK),
llama.SetTokens(c.Maxtokens),
llama.SetThreads(c.Threads),
}
if c.Debug {
predictOptions = append(predictOptions, llama.Debug)
}
predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...))
if c.RepeatPenalty != 0 {
predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty))
}
if c.Keep != 0 {
predictOptions = append(predictOptions, llama.SetNKeep(c.Keep))
}
if c.Batch != 0 {
predictOptions = append(predictOptions, llama.SetBatch(c.Batch))
}
if c.F16 {
predictOptions = append(predictOptions, llama.EnableF16KV)
}
if c.IgnoreEOS {
predictOptions = append(predictOptions, llama.IgnoreEOS)
}
if c.Seed != 0 {
predictOptions = append(predictOptions, llama.SetSeed(c.Seed))
}
return model.Predict(
s,
predictOptions...,
)
}
}
return func() (string, error) {
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
mutexMap.Lock()
l, ok := mutexes[modelFile]
if !ok {
m := &sync.Mutex{}
mutexes[modelFile] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock()
return fn()
}, nil
}

11
examples/README.md Normal file
View File

@ -0,0 +1,11 @@
# Examples
Here is a list of projects that can easily be integrated with the LocalAI backend.
## Projects
- [chatbot-ui](https://github.com/go-skynet/LocalAI/tree/master/examples/chatbot-ui/) (by [@mudler](https://github.com/mudler))
## Want to contribute?
Create an issue, and put `Example: <description>` in the title! We will post your examples here.

View File

@ -0,0 +1,26 @@
# chatbot-ui
Example of integration with [mckaywrigley/chatbot-ui](https://github.com/mckaywrigley/chatbot-ui).
![Screenshot from 2023-04-26 23-59-55](https://user-images.githubusercontent.com/2420543/234715439-98d12e03-d3ce-4f94-ab54-2b256808e05e.png)
## Setup
```bash
# Clone LocalAI
git clone https://github.com/go-skynet/LocalAI
cd LocalAI/examples/chatbot-ui
# (optional) Checkout a specific LocalAI tag
# git checkout -b build <TAG>
# Download gpt4all-j to models/
wget https://gpt4all.io/models/ggml-gpt4all-j.bin -O models/ggml-gpt4all-j
# start with docker-compose
docker compose up -d --build
```
Open http://localhost:3000 for the Web UI.

View File

@ -0,0 +1,24 @@
version: '3.6'
services:
api:
image: quay.io/go-skynet/local-ai:latest
build:
context: ../../
dockerfile: Dockerfile
ports:
- 8080:8080
environment:
- DEBUG=true
- MODELS_PATH=/models
volumes:
- ./models:/models:cached
command: ["/usr/bin/local-ai" ]
chatgpt:
image: ghcr.io/mckaywrigley/chatbot-ui:main
ports:
- 3000:3000
environment:
- 'OPENAI_API_KEY=sk-XXXXXXXXXXXXXXXXXXXX'
- 'OPENAI_API_HOST=http://api:8080'

View File

@ -0,0 +1 @@
{{.Input}}

View File

@ -0,0 +1,17 @@
name: gpt-3.5-turbo
parameters:
model: ggml-gpt4all-j
top_k: 80
temperature: 0.2
top_p: 0.7
context_size: 1024
threads: 14
stopwords:
- "HUMAN:"
- "GPT:"
roles:
user: " "
system: " "
template:
completion: completion
chat: gpt4all

View File

@ -0,0 +1,4 @@
The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
### Prompt:
{{.Input}}
### Response:

View File

@ -50,6 +50,11 @@ func main() {
EnvVars: []string{"MODELS_PATH"}, EnvVars: []string{"MODELS_PATH"},
Value: path, Value: path,
}, },
&cli.StringFlag{
Name: "config-file",
DefaultText: "Config file",
EnvVars: []string{"CONFIG_FILE"},
},
&cli.StringFlag{ &cli.StringFlag{
Name: "address", Name: "address",
DefaultText: "Bind address for the API server.", DefaultText: "Bind address for the API server.",
@ -80,7 +85,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
UsageText: `local-ai [options]`, UsageText: `local-ai [options]`,
Copyright: "go-skynet authors", Copyright: "go-skynet authors",
Action: func(ctx *cli.Context) error { Action: func(ctx *cli.Context) error {
return api.App(model.NewModelLoader(ctx.String("models-path")), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false).Listen(ctx.String("address")) return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false).Listen(ctx.String("address"))
}, },
} }

View File

@ -18,7 +18,7 @@ import (
) )
type ModelLoader struct { type ModelLoader struct {
modelPath string ModelPath string
mu sync.Mutex mu sync.Mutex
models map[string]*llama.LLama models map[string]*llama.LLama
@ -31,7 +31,7 @@ type ModelLoader struct {
func NewModelLoader(modelPath string) *ModelLoader { func NewModelLoader(modelPath string) *ModelLoader {
return &ModelLoader{ return &ModelLoader{
modelPath: modelPath, ModelPath: modelPath,
gpt2models: make(map[string]*gpt2.GPT2), gpt2models: make(map[string]*gpt2.GPT2),
gptmodels: make(map[string]*gptj.GPTJ), gptmodels: make(map[string]*gptj.GPTJ),
gptstablelmmodels: make(map[string]*gpt2.StableLM), gptstablelmmodels: make(map[string]*gpt2.StableLM),
@ -41,12 +41,12 @@ func NewModelLoader(modelPath string) *ModelLoader {
} }
func (ml *ModelLoader) ExistsInModelPath(s string) bool { func (ml *ModelLoader) ExistsInModelPath(s string) bool {
_, err := os.Stat(filepath.Join(ml.modelPath, s)) _, err := os.Stat(filepath.Join(ml.ModelPath, s))
return err == nil return err == nil
} }
func (ml *ModelLoader) ListModels() ([]string, error) { func (ml *ModelLoader) ListModels() ([]string, error) {
files, err := ioutil.ReadDir(ml.modelPath) files, err := ioutil.ReadDir(ml.ModelPath)
if err != nil { if err != nil {
return []string{}, err return []string{}, err
} }
@ -70,7 +70,19 @@ func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string,
m, ok := ml.promptsTemplates[modelName] m, ok := ml.promptsTemplates[modelName]
if !ok { if !ok {
return "", fmt.Errorf("no prompt template available") modelFile := filepath.Join(ml.ModelPath, modelName)
if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
return "", err
}
t, exists := ml.promptsTemplates[modelName]
if exists {
m = t
}
}
if m == nil {
return "", nil
} }
var buf bytes.Buffer var buf bytes.Buffer
@ -88,14 +100,14 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
} }
// Check if the model path exists // Check if the model path exists
// skip any error here - we run anyway if a template is not exist // skip any error here - we run anyway if a template does not exist
modelTemplateFile := fmt.Sprintf("%s.tmpl", modelName) modelTemplateFile := fmt.Sprintf("%s.tmpl", modelName)
if !ml.ExistsInModelPath(modelTemplateFile) { if !ml.ExistsInModelPath(modelTemplateFile) {
return nil return nil
} }
dat, err := os.ReadFile(filepath.Join(ml.modelPath, modelTemplateFile)) dat, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile))
if err != nil { if err != nil {
return err return err
} }
@ -125,7 +137,7 @@ func (ml *ModelLoader) LoadStableLMModel(modelName string) (*gpt2.StableLM, erro
} }
// Load the model and keep it in memory for later use // Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName) modelFile := filepath.Join(ml.ModelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile) log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model, err := gpt2.NewStableLM(modelFile) model, err := gpt2.NewStableLM(modelFile)
@ -164,7 +176,7 @@ func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
} }
// Load the model and keep it in memory for later use // Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName) modelFile := filepath.Join(ml.ModelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile) log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model, err := gpt2.New(modelFile) model, err := gpt2.New(modelFile)
@ -207,7 +219,7 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
} }
// Load the model and keep it in memory for later use // Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName) modelFile := filepath.Join(ml.ModelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile) log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model, err := gptj.New(modelFile) model, err := gptj.New(modelFile)
@ -256,7 +268,7 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
} }
// Load the model and keep it in memory for later use // Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName) modelFile := filepath.Join(ml.ModelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile) log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model, err := llama.New(modelFile, opts...) model, err := llama.New(modelFile, opts...)

1
tests/fixtures/completion.tmpl vendored Normal file
View File

@ -0,0 +1 @@
{{.Input}}

28
tests/fixtures/config.yaml vendored Normal file
View File

@ -0,0 +1,28 @@
- name: list1
parameters:
model: testmodel
context_size: 512
threads: 10
stopwords:
- "HUMAN:"
- "### Response:"
roles:
user: "HUMAN:"
system: "GPT:"
template:
completion: completion
chat: ggml-gpt4all-j
- name: list2
parameters:
model: testmodel
context_size: 512
threads: 10
stopwords:
- "HUMAN:"
- "### Response:"
roles:
user: "HUMAN:"
system: "GPT:"
template:
completion: completion
chat: ggml-gpt4all-j

4
tests/fixtures/ggml-gpt4all-j.tmpl vendored Normal file
View File

@ -0,0 +1,4 @@
The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
### Prompt:
{{.Input}}
### Response:

14
tests/fixtures/gpt4.yaml vendored Normal file
View File

@ -0,0 +1,14 @@
name: gpt4all
parameters:
model: testmodel
context_size: 512
threads: 10
stopwords:
- "HUMAN:"
- "### Response:"
roles:
user: "HUMAN:"
system: "GPT:"
template:
completion: completion
chat: ggml-gpt4all-j

14
tests/fixtures/gpt4_2.yaml vendored Normal file
View File

@ -0,0 +1,14 @@
name: gpt4all-2
parameters:
model: testmodel
context_size: 1024
threads: 5
stopwords:
- "HUMAN:"
- "### Response:"
roles:
user: "HUMAN:"
system: "GPT:"
template:
completion: completion
chat: ggml-gpt4all-j