mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
feat(image): support response_type
in the OpenAI API request (#2347)
* Change response_format type to string to match OpenAI Spec Signed-off-by: prajwal <prajwalnayak7@gmail.com> * updated response_type type to interface Signed-off-by: prajwal <prajwalnayak7@gmail.com> * feat: correctly parse generic struct Signed-off-by: mudler <mudler@localai.io> * add tests Signed-off-by: mudler <mudler@localai.io> --------- Signed-off-by: prajwal <prajwalnayak7@gmail.com> Signed-off-by: mudler <mudler@localai.io> Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com> Co-authored-by: mudler <mudler@localai.io>
This commit is contained in:
parent
087bceccac
commit
4d98dd9ce7
@ -27,9 +27,11 @@ type BackendConfig struct {
|
|||||||
Backend string `yaml:"backend"`
|
Backend string `yaml:"backend"`
|
||||||
TemplateConfig TemplateConfig `yaml:"template"`
|
TemplateConfig TemplateConfig `yaml:"template"`
|
||||||
|
|
||||||
PromptStrings, InputStrings []string `yaml:"-"`
|
PromptStrings, InputStrings []string `yaml:"-"`
|
||||||
InputToken [][]int `yaml:"-"`
|
InputToken [][]int `yaml:"-"`
|
||||||
functionCallString, functionCallNameString string `yaml:"-"`
|
functionCallString, functionCallNameString string `yaml:"-"`
|
||||||
|
ResponseFormat string `yaml:"-"`
|
||||||
|
ResponseFormatMap map[string]interface{} `yaml:"-"`
|
||||||
|
|
||||||
FunctionsConfig functions.FunctionsConfig `yaml:"function"`
|
FunctionsConfig functions.FunctionsConfig `yaml:"function"`
|
||||||
|
|
||||||
|
@ -183,8 +183,13 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
|
|||||||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
|
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.ResponseFormat.Type == "json_object" {
|
if config.ResponseFormatMap != nil {
|
||||||
input.Grammar = functions.JSONBNF
|
d := schema.ChatCompletionResponseFormat{}
|
||||||
|
dat, _ := json.Marshal(config.ResponseFormatMap)
|
||||||
|
_ = json.Unmarshal(dat, &d)
|
||||||
|
if d.Type == "json_object" {
|
||||||
|
input.Grammar = functions.JSONBNF
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Grammar = input.Grammar
|
config.Grammar = input.Grammar
|
||||||
|
@ -69,8 +69,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
|||||||
return fmt.Errorf("failed reading parameters from request:%w", err)
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.ResponseFormat.Type == "json_object" {
|
if config.ResponseFormatMap != nil {
|
||||||
input.Grammar = functions.JSONBNF
|
d := schema.ChatCompletionResponseFormat{}
|
||||||
|
dat, _ := json.Marshal(config.ResponseFormatMap)
|
||||||
|
_ = json.Unmarshal(dat, &d)
|
||||||
|
if d.Type == "json_object" {
|
||||||
|
input.Grammar = functions.JSONBNF
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Grammar = input.Grammar
|
config.Grammar = input.Grammar
|
||||||
|
@ -149,10 +149,8 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
|
|||||||
return fmt.Errorf("invalid value for 'size'")
|
return fmt.Errorf("invalid value for 'size'")
|
||||||
}
|
}
|
||||||
|
|
||||||
b64JSON := false
|
b64JSON := config.ResponseFormat == "b64_json"
|
||||||
if input.ResponseFormat.Type == "b64_json" {
|
|
||||||
b64JSON = true
|
|
||||||
}
|
|
||||||
// src and clip_skip
|
// src and clip_skip
|
||||||
var result []schema.Item
|
var result []schema.Item
|
||||||
for _, i := range config.PromptStrings {
|
for _, i := range config.PromptStrings {
|
||||||
|
@ -129,6 +129,15 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
|
|||||||
config.Maxtokens = input.Maxtokens
|
config.Maxtokens = input.Maxtokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.ResponseFormat != nil {
|
||||||
|
switch responseFormat := input.ResponseFormat.(type) {
|
||||||
|
case string:
|
||||||
|
config.ResponseFormat = responseFormat
|
||||||
|
case map[string]interface{}:
|
||||||
|
config.ResponseFormatMap = responseFormat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch stop := input.Stop.(type) {
|
switch stop := input.Stop.(type) {
|
||||||
case string:
|
case string:
|
||||||
if stop != "" {
|
if stop != "" {
|
||||||
|
@ -99,6 +99,8 @@ type OpenAIModel struct {
|
|||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ImageGenerationResponseFormat string
|
||||||
|
|
||||||
type ChatCompletionResponseFormatType string
|
type ChatCompletionResponseFormatType string
|
||||||
|
|
||||||
type ChatCompletionResponseFormat struct {
|
type ChatCompletionResponseFormat struct {
|
||||||
@ -114,7 +116,7 @@ type OpenAIRequest struct {
|
|||||||
// whisper
|
// whisper
|
||||||
File string `json:"file" validate:"required"`
|
File string `json:"file" validate:"required"`
|
||||||
//whisper/image
|
//whisper/image
|
||||||
ResponseFormat ChatCompletionResponseFormat `json:"response_format"`
|
ResponseFormat interface{} `json:"response_format,omitempty"`
|
||||||
// image
|
// image
|
||||||
Size string `json:"size"`
|
Size string `json:"size"`
|
||||||
// Prompt is read only by completion/image API calls
|
// Prompt is read only by completion/image API calls
|
||||||
|
@ -123,13 +123,36 @@ var _ = Describe("E2E test", func() {
|
|||||||
openai.ImageRequest{
|
openai.ImageRequest{
|
||||||
Prompt: "test",
|
Prompt: "test",
|
||||||
Size: openai.CreateImageSize512x512,
|
Size: openai.CreateImageSize512x512,
|
||||||
//ResponseFormat: openai.CreateImageResponseFormatURL,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
||||||
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
|
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
|
||||||
})
|
})
|
||||||
|
It("correctly changes the response format to url", func() {
|
||||||
|
resp, err := client.CreateImage(context.TODO(),
|
||||||
|
openai.ImageRequest{
|
||||||
|
Prompt: "test",
|
||||||
|
Size: openai.CreateImageSize512x512,
|
||||||
|
ResponseFormat: openai.CreateImageResponseFormatURL,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
||||||
|
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
|
||||||
|
})
|
||||||
|
It("correctly changes the response format to base64", func() {
|
||||||
|
resp, err := client.CreateImage(context.TODO(),
|
||||||
|
openai.ImageRequest{
|
||||||
|
Prompt: "test",
|
||||||
|
Size: openai.CreateImageSize512x512,
|
||||||
|
ResponseFormat: openai.CreateImageResponseFormatB64JSON,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
||||||
|
Expect(resp.Data[0].B64JSON).ToNot(BeEmpty(), fmt.Sprint(resp.Data[0].B64JSON))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
Context("embeddings", func() {
|
Context("embeddings", func() {
|
||||||
It("correctly", func() {
|
It("correctly", func() {
|
||||||
|
Loading…
Reference in New Issue
Block a user