diff --git a/Makefile b/Makefile index 87337d64..263fe413 100644 --- a/Makefile +++ b/Makefile @@ -65,6 +65,7 @@ gpt4all: @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gptj_/g' {} + @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} + @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} + + @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/regex_escape/gpt4allregex_escape/g' {} + mv ./gpt4all/gpt4all-backend/llama.cpp/llama_util.h ./gpt4all/gpt4all-backend/llama.cpp/gptjllama_util.h ## BERT embeddings diff --git a/api/gallery.go b/api/gallery.go index cb165f84..fac48622 100644 --- a/api/gallery.go +++ b/api/gallery.go @@ -5,6 +5,8 @@ import ( "fmt" "io/ioutil" "net/http" + "net/url" + "strings" "sync" "github.com/go-skynet/LocalAI/pkg/gallery" @@ -63,8 +65,15 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { updateError := func(e error) { g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true}) } + + url, err := op.req.DecodeURL() + if err != nil { + updateError(err) + continue + } + // Send a GET request to the URL - response, err := http.Get(op.req.URL) + response, err := http.Get(url) if err != nil { updateError(err) continue @@ -113,6 +122,43 @@ type ApplyGalleryModelRequest struct { AdditionalFiles []gallery.File `json:"files"` } +const ( + githubURI = "github:" +) + +func (request ApplyGalleryModelRequest) DecodeURL() (string, error) { + input := request.URL + var rawURL string + + if strings.HasPrefix(input, githubURI) { + parts := strings.Split(input, ":") + repoParts := strings.Split(parts[1], "@") + branch := "main" + + if len(repoParts) > 1 { + branch = repoParts[1] + } + + repoPath := strings.Split(repoParts[0], "/") + org := repoPath[0] + project := repoPath[1] + projectPath := strings.Join(repoPath[2:], "/") + + rawURL = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) + } else if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") { + // Handle regular URLs + u, err := url.Parse(input) + if err != nil { + return "", fmt.Errorf("invalid URL: %w", err) + } + rawURL = u.String() + } else { + return "", fmt.Errorf("invalid URL format") + } + + return rawURL, nil +} + func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { diff --git a/api/gallery_test.go b/api/gallery_test.go new file mode 100644 index 00000000..1c92c0d5 --- /dev/null +++ b/api/gallery_test.go @@ -0,0 +1,30 @@ +package api_test + +import ( + . "github.com/go-skynet/LocalAI/api" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Gallery API tests", func() { + Context("requests", func() { + It("parses github with a branch", func() { + req := ApplyGalleryModelRequest{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"} + str, err := req.DecodeURL() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) + }) + It("parses github without a branch", func() { + req := ApplyGalleryModelRequest{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml"} + str, err := req.DecodeURL() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) + }) + It("parses URLS", func() { + req := ApplyGalleryModelRequest{URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"} + str, err := req.DecodeURL() + Expect(err).ToNot(HaveOccurred()) + Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) + }) + }) +})