From 474cc403ff58f9cd80bf41ddb73d376c0339fdc2 Mon Sep 17 00:00:00 2001 From: Neko Ayaka Date: Mon, 26 Aug 2024 16:42:51 +0800 Subject: [PATCH] feat: implemented neuri jsonschema helpers Signed-off-by: Neko Ayaka --- .vscode/launch.json | 2 +- README.md | 4 +- go.mod | 3 + go.sum | 2 + pkg/neuri/formats/jsonschema/jsonschema.go | 538 ++++++++++++++++++ .../formats/jsonschema/jsonschema_test.go | 493 ++++++++++++++++ 6 files changed, 1040 insertions(+), 2 deletions(-) create mode 100644 pkg/neuri/formats/jsonschema/jsonschema.go create mode 100644 pkg/neuri/formats/jsonschema/jsonschema_test.go diff --git a/.vscode/launch.json b/.vscode/launch.json index 0c48e1d..34b7fbc 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,7 +12,7 @@ "program": "${workspaceFolder}/cmd/lingticio/api-server/main.go", "args": [ "-e", - "${workspaceFolder}/cmd/lingticio/api-server/.env" + "${workspaceFolder}/cmd/lingticio/api-server/.env", ] } ] diff --git a/README.md b/README.md index ffc832f..14d2de8 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,9 @@ - [ ] Intelli-routing - [ ] Load balancing - [ ] Semantic cache -- [ ] Structured data +- [x] Structured data + - [x] Powered by Neuri + - [ ] Support any order of structured data - [ ] Function calling - [ ] Instruction mapping - [ ] Generative streaming diff --git a/go.mod b/go.mod index 3a29baf..7176451 100644 --- a/go.mod +++ b/go.mod @@ -19,9 +19,11 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/rivo/uniseg v0.4.7 github.com/samber/lo v1.46.0 + github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/sashabaranov/go-openai v1.29.0 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.9.0 github.com/vektah/gqlparser/v2 v2.5.16 go.uber.org/fx v1.22.0 go.uber.org/zap v1.27.0 @@ -57,6 +59,7 @@ require ( github.com/onsi/gomega v1.31.1 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect + 
// Regular-expression fragments used to recognise JSON values embedded in
// free-form text. Every fragment matches the textual (encoded) form of the
// value, so the string-like ones include the surrounding double quotes.
const (
	// STRING_INNER matches the body of a JSON string: any run of characters
	// other than a double quote or backslash, interleaved with backslash escapes.
	STRING_INNER = `[^"\\]*(?:\\.[^"\\]*)*`
	// STRING matches a complete double-quoted JSON string.
	STRING = `"` + STRING_INNER + `"`
	// INTEGER matches an optionally negative integer without leading zeros.
	INTEGER = `(-)?(0|[1-9][0-9]*)`
	// NUMBER matches an INTEGER with optional fraction and exponent parts.
	NUMBER = INTEGER + `(\.[0-9]+)?([eE][+-]?[0-9]+)?`
	// BOOLEAN matches the literals true and false.
	BOOLEAN = `(true|false)`
	// NULL matches the literal null.
	NULL = `null`
	// WHITESPACE is the default pattern placed between JSON tokens.
	WHITESPACE = `\s*`
	// DATE_TIME matches a quoted date-time such as "2023-06-13T15:30:00Z", with
	// an optional three-digit fraction and an optional trailing Z.
	DATE_TIME = `"(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]{3})?(Z)?"`
	// DATE matches a quoted calendar date such as "2023-06-13".
	DATE = `"(?:\d{4})-(?:0[1-9]|1[0-2])-(?:0[1-9]|[1-2][0-9]|3[0-1])"`
	// TIME matches a quoted time of day such as "15:30:00", with optional
	// fractional seconds and an optional trailing Z. The fraction separator is
	// `\.` (a literal dot); the previous `\\.` matched a backslash followed by
	// any character, which rejected valid fractional seconds.
	TIME = `"(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]+)?(Z)?"`
	// UUID matches a quoted, lower-case, hyphenated UUID.
	UUID = `"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"`
)
return handleRef(instance.Ref, rootSchema, whitespacePattern) + case len(instance.Types) > 0: + return handleType(instance, whitespacePattern, rootSchema) + case len(instance.Types) == 0: + return handleEmptySchema(whitespacePattern, rootSchema) + } + + return "", fmt.Errorf("unsupported schema type") +} + +func handleEmptySchema(whitespacePattern string, rootSchema *jsonschema.Schema) (string, error) { + types := []string{"boolean", "null", "number", "integer", "string", "array", "object"} + regExps := make([]string, len(types)) + + for i, t := range types { + schema := &jsonschema.Schema{Types: []string{t}} + + regex, err := toRegex(schema, whitespacePattern, rootSchema) + if err != nil { + return "", err + } + + regExps[i] = fmt.Sprintf("(%s)", regex) + } + + return strings.Join(regExps, "|"), nil +} + +func handleProperties(properties map[string]*jsonschema.Schema, instance *jsonschema.Schema, whitespacePattern string, rootSchema *jsonschema.Schema) (string, error) { + if len(properties) == 0 { + return `\{\s*\}`, nil + } + + // Get all property names and sort them + propertyNames := make([]string, 0, len(properties)) + for name := range properties { + propertyNames = append(propertyNames, name) + } + + sort.Strings(propertyNames) + + propertyRegExps := make([]string, 0, len(propertyNames)) + isRequired := make(map[string]bool) + + for _, rp := range instance.Required { + isRequired[rp] = true + } + + for _, name := range propertyNames { + value := properties[name] + subRegex := fmt.Sprintf(`%s"%s"%s:%s`, whitespacePattern, regexp.QuoteMeta(name), whitespacePattern, whitespacePattern) + + valueRegex, err := toRegex(value, whitespacePattern, rootSchema) + if err != nil { + return "", err + } + + subRegex += valueRegex + + if isRequired[name] { + propertyRegExps = append(propertyRegExps, subRegex) + } else { + propertyRegExps = append(propertyRegExps, fmt.Sprintf("(%s)?", subRegex)) + } + } + + propertiesRegex := strings.Join(propertyRegExps, 
fmt.Sprintf(`%s,?%s`, whitespacePattern, whitespacePattern)) + objectRegex := fmt.Sprintf(`\{%s%s%s\}`, whitespacePattern, propertiesRegex, whitespacePattern) + + if instance.AdditionalProperties != nil { + if additionalProps, ok := instance.AdditionalProperties.(bool); ok && additionalProps { + // If additional properties are allowed, add a pattern for them + objectRegex = fmt.Sprintf(`\{%s(%s%s,?%s)*(%s)?%s\}`, + whitespacePattern, + propertiesRegex, + whitespacePattern, + whitespacePattern, + `"[^"]+"\s*:\s*[^,}]+`, + whitespacePattern) + } + } + + return objectRegex, nil +} + +func handleOneOf(oneOf []*jsonschema.Schema, whitespacePattern string, rootSchema *jsonschema.Schema) (string, error) { + subRegExps := make([]string, len(oneOf)) + + for i, schema := range oneOf { + subRegex, err := toRegex(schema, whitespacePattern, rootSchema) + if err != nil { + return "", err + } + subRegExps[i] = fmt.Sprintf("(?:%s)", subRegex) + } + + return fmt.Sprintf("(%s)", strings.Join(subRegExps, "|")), nil +} + +func handleAllOf(allOf []*jsonschema.Schema, whitespacePattern string, rootSchema *jsonschema.Schema) (string, error) { + mergedProperties := make(map[string]*jsonschema.Schema) + requiredProperties := make([]string, 0) + + for _, schema := range allOf { + for propName, propSchema := range schema.Properties { + mergedProperties[propName] = propSchema + } + requiredProperties = append(requiredProperties, schema.Required...) 
+ } + + regex := "\\{" + propertyRegExps := make([]string, 0) + + for key, value := range mergedProperties { + propertyRegex, err := toRegex(value, whitespacePattern, rootSchema) + if err != nil { + return "", err + } + + propPattern := fmt.Sprintf(`%s"%s"%s:%s%s`, whitespacePattern, key, whitespacePattern, whitespacePattern, propertyRegex) + if !lo.Contains(requiredProperties, key) { + propPattern = fmt.Sprintf("(%s)?", propPattern) + } + + propertyRegExps = append(propertyRegExps, propPattern) + } + + regex += strings.Join(propertyRegExps, fmt.Sprintf("%s,", whitespacePattern)) + regex += fmt.Sprintf("%s\\}", whitespacePattern) + + return regex, nil +} + +func handleAnyOf(anyOf []*jsonschema.Schema, whitespacePattern string, rootSchema *jsonschema.Schema) (string, error) { + subRegExps := make([]string, len(anyOf)) + var err error + + for i, schema := range anyOf { + subRegExps[i], err = toRegex(schema, whitespacePattern, rootSchema) + if err != nil { + return "", err + } + } + + // Generate all possible combinations + combinations := []string{} + + for i := 1; i <= len(subRegExps); i++ { + combos := getCombinations(subRegExps, i) + for _, combo := range combos { + combinations = append(combinations, combineSchemas(combo, whitespacePattern)) + } + } + + return fmt.Sprintf("(%s)", strings.Join(combinations, "|")), nil +} + +func combineSchemas(schemas []string, whitespacePattern string) string { + // Remove outer curly braces from each schema + for i, schema := range schemas { + schemas[i] = strings.TrimPrefix(strings.TrimSuffix(schema, `\s*\}`), `\{\s*`) + } + + // Join schemas with optional comma and whitespace + combined := strings.Join(schemas, fmt.Sprintf(`%s,?%s`, whitespacePattern, whitespacePattern)) + + // Add back the outer curly braces + return fmt.Sprintf(`\{%s%s%s\}`, whitespacePattern, combined, whitespacePattern) +} + +func getCombinations(arr []string, k int) [][]string { + if k == 1 { + result := make([][]string, len(arr)) + for i, v := range arr 
{ + result[i] = []string{v} + } + + return result + } + + result := [][]string{} + + for i := 0; i <= len(arr)-k; i++ { + subCombos := getCombinations(arr[i+1:], k-1) + for _, subCombo := range subCombos { + result = append(result, append([]string{arr[i]}, subCombo...)) + } + } + + return result +} + +func handlePrefixItems(prefixItems []*jsonschema.Schema, instance *jsonschema.Schema, whitespacePattern string, rootSchema *jsonschema.Schema) (string, error) { + elementPatterns := make([]string, len(prefixItems)) + + for i, item := range prefixItems { + pattern, err := toRegex(item, whitespacePattern, rootSchema) + if err != nil { + return "", err + } + elementPatterns[i] = pattern + } + + commaSplitPattern := fmt.Sprintf("%s,%s", whitespacePattern, whitespacePattern) + tupleInner := strings.Join(elementPatterns, commaSplitPattern) + regex := fmt.Sprintf("\\[%s%s", whitespacePattern, tupleInner) + + if items, ok := instance.Items.(*jsonschema.Schema); ok { + additionalItemsRegex, err := toRegex(items, whitespacePattern, rootSchema) + if err != nil { + return "", err + } + regex += fmt.Sprintf("(%s%s)*", commaSplitPattern, additionalItemsRegex) + } + + regex += fmt.Sprintf("%s\\]", whitespacePattern) + + return regex, nil +} + +func handleEnum(enum []interface{}) (string, error) { + choices := make([]string, len(enum)) + + for i, choice := range enum { + var stringified string + + switch v := choice.(type) { + case string: + // For strings, use JSON marshaling to ensure proper quoting + bytes, err := json.Marshal(v) + if err != nil { + return "", fmt.Errorf("failed to marshal string enum value: %w", err) + } + stringified = string(bytes) + case float64, int, int64, float32: + // For numbers, use fmt.Sprintf without quotes + stringified = fmt.Sprintf("%v", v) + case bool: + // For booleans, use strings.ToLower to ensure "true" or "false" + stringified = strings.ToLower(fmt.Sprintf("%v", v)) + case nil: + stringified = "null" + case json.Number: + stringified = 
// handleConst returns a regular expression that matches exactly the JSON
// encoding of a schema's "const" value.
//
// The value is serialised as JSON (strings keep their quotes, nil becomes
// null, json.Number keeps its literal text) and then escaped with
// regexp.QuoteMeta. The previous implementation formatted the value with %v,
// which dropped the quotes around strings and rendered a slice-wrapped
// constant as "[value]".
//
// constValue may be the constant itself or a one-element []interface{}
// wrapping it; exactly one wrapping layer is removed.
// NOTE(review): toRegex passes instance.Constant straight through;
// santhosh-tekuri/jsonschema v5 is understood to store it as a one-element
// slice so that a null constant can be told apart from "no const" — confirm
// against the library version in go.mod. A caller that passes a bare
// one-element array constant must wrap it once more.
//
// HTML escaping is disabled so that characters such as '<' are matched
// literally. Composite values are matched only in the compact form the
// encoder produces.
//
// An error is returned when the value cannot be encoded as JSON.
func handleConst(constValue interface{}) (string, error) {
	if wrapped, ok := constValue.([]interface{}); ok && len(wrapped) == 1 {
		constValue = wrapped[0]
	}

	var encoded strings.Builder

	encoder := json.NewEncoder(&encoded)
	encoder.SetEscapeHTML(false)

	if err := encoder.Encode(constValue); err != nil {
		return "", fmt.Errorf("encoding const value: %w", err)
	}

	// Encode terminates its output with a newline that is not part of the value.
	literal := strings.TrimSuffix(encoded.String(), "\n")

	return regexp.QuoteMeta(literal), nil
}
instance.MaxLength), nil + } else if instance.MinLength > 0 { + return fmt.Sprintf(`"%s{%d,}"`, STRING_INNER, instance.MinLength), nil + } else if instance.Pattern != nil { + pattern := instance.Pattern.String() + if len(pattern) >= 2 && pattern[0] == '^' && pattern[len(pattern)-1] == '$' { + return fmt.Sprintf(`("%s")`, pattern[1:len(pattern)-1]), nil + } + + return fmt.Sprintf(`("%s")`, pattern), nil + } else if instance.Format != "" { + if regex, ok := formatToRegex[instance.Format]; ok { + return regex, nil + } + + return "", fmt.Errorf("format %s is not supported", instance.Format) + } + + // Default case: any string + return STRING, nil +} +func handleNumberType(instance *jsonschema.Schema) (string, error) { + if lo.Contains(instance.Types, "integer") { + return typeToRegex["integer"], nil + } + + return typeToRegex["number"], nil +} + +func handleArrayType(instance *jsonschema.Schema, whitespacePattern string, rootSchema *jsonschema.Schema) (string, error) { + if instance.Items2020 == nil { + return `\[\s*([^,\]]*\s*,?\s*)*\s*\]`, nil + } + + itemsRegex, err := toRegex(instance.Items2020, whitespacePattern, rootSchema) + if err != nil { + return "", err + } + + if instance.MaxItems > 0 { + minItems := 0 + if instance.MinItems != -1 { + minItems = instance.MinItems + } + + return fmt.Sprintf(`\[\s*(%s\s*,?\s*){%d,%d}\s*\]`, itemsRegex, minItems, instance.MaxItems), nil + } else if instance.MinItems != -1 { + return fmt.Sprintf(`\[\s*(%s\s*,?\s*){%d,}\s*\]`, itemsRegex, instance.MinItems), nil + } + + return fmt.Sprintf(`\[\s*(%s\s*,?\s*){%d,}\s*\]`, itemsRegex, 0), nil +} + +func handleObjectType(instance *jsonschema.Schema, whitespacePattern string, rootSchema *jsonschema.Schema) (string, error) { + if len(instance.Properties) == 0 && instance.AdditionalProperties == nil { + return `\{\s*\}`, nil + } + + // Get all property names and sort them + propertyNames := make([]string, 0, len(instance.Properties)) + for name := range instance.Properties { + 
propertyNames = append(propertyNames, name) + } + + sort.Strings(propertyNames) + + propertyRegexes := make([]string, 0, len(propertyNames)) + + for _, name := range propertyNames { + schema := instance.Properties[name] + + propertyRegex, err := toRegex(schema, whitespacePattern, rootSchema) + if err != nil { + return "", err + } + + propPattern := fmt.Sprintf(`"%s"\s*:\s*%s`, regexp.QuoteMeta(name), propertyRegex) + if !lo.Contains(instance.Required, name) { + propPattern = fmt.Sprintf(`(%s)?`, propPattern) + } + + propertyRegexes = append(propertyRegexes, propPattern) + } + + propertiesRegex := strings.Join(propertyRegexes, `\s*,?\s*`) + objectRegex := fmt.Sprintf(`\{\s*%s\s*\}`, propertiesRegex) + + if instance.AdditionalProperties != nil { + if additionalProps, ok := instance.AdditionalProperties.(bool); ok && additionalProps { + // If additional properties are allowed, add a pattern for them + objectRegex = fmt.Sprintf(`\{\s*(%s\s*,?\s*)*(%s)?\s*\}`, propertiesRegex, `"[^"]+"\s*:\s*[^,}]+`) + } + } + + return objectRegex, nil +} + +func handleMultipleTypes(types []interface{}, whitespacePattern string, rootSchema *jsonschema.Schema) (string, error) { + typesStr := lo.Map(types, func(t interface{}, _ int) string { + str, _ := t.(string) + return str + }) + + regExps := make([]string, 0) + + for _, t := range typesStr { + schema := &jsonschema.Schema{Types: []string{t}} + + regex, err := toRegex(schema, whitespacePattern, rootSchema) + if err != nil { + return "", err + } + + regExps = append(regExps, regex) + } + + return fmt.Sprintf("(%s)", strings.Join(regExps, "|")), nil +} + +func ExtractBySchema(schema *jsonschema.Schema, extractFrom string) (string, error) { + regex, err := BuildRegexFromSchema(schema, "") + if err != nil { + return "", err + } + + re, err := regexp.Compile(regex) + if err != nil { + return "", err + } + + match := re.FindString(extractFrom) + if match == "" { + return "", fmt.Errorf("no match found") + } + + return 
strings.TrimSpace(match), nil +} + +func ExtractStructBySchema[T any](schema *jsonschema.Schema, extractFrom string) (*T, error) { + extracted, err := ExtractBySchema(schema, extractFrom) + if err != nil { + return nil, err + } + + var result T + + err = json.Unmarshal([]byte(extracted), &result) + if err != nil { + return nil, err + } + + return &result, nil +} diff --git a/pkg/neuri/formats/jsonschema/jsonschema_test.go b/pkg/neuri/formats/jsonschema/jsonschema_test.go new file mode 100644 index 0000000..c2afc70 --- /dev/null +++ b/pkg/neuri/formats/jsonschema/jsonschema_test.go @@ -0,0 +1,493 @@ +package jsonschema + +import ( + "encoding/json" + "fmt" + "regexp" + "testing" + + "github.com/samber/lo" + jsonschema "github.com/santhosh-tekuri/jsonschema/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func toJSONString(v any) string { + return string(lo.Must(json.Marshal(v))) +} + +func createDistortions(str string, primitiveType bool) []string { + distortions := []string{ + fmt.Sprintf("Sure, here's the JSON: %s", str), + fmt.Sprintf("The JSON object is: \n```json\n%s\n```", str), + fmt.Sprintf("Here's what you asked for:\n%s\nIs there anything else?", str), + fmt.Sprintf("[%s]", str), + } + if !primitiveType { + distortions = append(distortions, fmt.Sprintf(`{"result": %s}`, str)) + } + + return distortions +} + +func createInvalidDistortions(str string, primitiveType bool) []string { + distortions := []string{ + fmt.Sprintf("This is invalid: %s", str), + fmt.Sprintf("```json\n%s\n```", str), + } + if !primitiveType { + distortions = append(distortions, fmt.Sprintf(`{"invalid": %s}`, str)) + } + + return distortions +} + +func TestBuildRegexFromSchemaString(t *testing.T) { + t.Run("handles basic object schema", func(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "number"}, + }, + } + + schemaJSON, 
err := json.Marshal(schema) + require.NoError(t, err) + + regexString, err := BuildRegexFromSchemaString(string(schemaJSON), "") + require.NoError(t, err) + + regex := regexp.MustCompile(regexString) + + validJSON := `{"age":30,"name":"John"}` + for _, distorted := range createDistortions(validJSON, false) { + match := regex.FindString(distorted) + assert.NotEmpty(t, match) + assert.Contains(t, match, validJSON) + } + + invalidJSON := `{"age":30,"name":123}` + for _, distorted := range createInvalidDistortions(invalidJSON, false) { + match := regex.FindString(distorted) + assert.Empty(t, match) + } + }) + + t.Run("handles arrays", func(t *testing.T) { + schema := map[string]any{ + "type": "array", + "items": map[string]any{"type": "number"}, + } + + schemaJSON, err := json.Marshal(schema) + require.NoError(t, err) + + regexString, err := BuildRegexFromSchemaString(string(schemaJSON), "") + require.NoError(t, err) + + regex := regexp.MustCompile(regexString) + + validJSON := `[1,2,3]` + for _, distorted := range createDistortions(validJSON, false) { + match := regex.FindString(distorted) + assert.NotEmpty(t, match) + assert.Contains(t, match, validJSON) + } + + invalidJSON := `["abcd","abcd","abcd"]` + for _, distorted := range createInvalidDistortions(invalidJSON, false) { + match := regex.FindString(distorted) + assert.Empty(t, match) + } + }) + + t.Run("handles string formats", func(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "string", "format": "uuid"}, + "date": map[string]any{"type": "string", "format": "date-time"}, + }, + } + + schemaJSON, err := json.Marshal(schema) + require.NoError(t, err) + + regexString, err := BuildRegexFromSchemaString(string(schemaJSON), "") + require.NoError(t, err) + + regex := regexp.MustCompile(regexString) + + validJSON := `{"date":"2023-06-13T15:30:00Z","id":"123e4567-e89b-12d3-a456-426614174000"}` + for _, distorted := range 
createDistortions(validJSON, false) { + match := regex.FindString(distorted) + assert.NotEmpty(t, match) + assert.Contains(t, match, validJSON) + } + + invalidJSON := `{"date":"2023-06-13","id":"not-a-uuid"}` + for _, distorted := range createInvalidDistortions(invalidJSON, false) { + match := regex.FindString(distorted) + assert.Empty(t, match) + } + }) + + t.Run("handles number constraints", func(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "integer": map[string]any{"type": "integer", "minimum": 0, "maximum": 100}, + "float": map[string]any{"type": "number", "exclusiveMinimum": 0, "exclusiveMaximum": 1}, + }, + } + + schemaJSON, err := json.Marshal(schema) + require.NoError(t, err) + + regexString, err := BuildRegexFromSchemaString(string(schemaJSON), "") + require.NoError(t, err) + + regex := regexp.MustCompile(regexString) + + validJSON := `{"float":0.5,"integer":50}` + for _, distorted := range createDistortions(validJSON, false) { + match := regex.FindString(distorted) + assert.NotEmpty(t, match) + assert.Contains(t, match, validJSON) + } + + invalidJSON := `{"float":"0.5","integer":"50"}` + for _, distorted := range createInvalidDistortions(invalidJSON, false) { + match := regex.FindString(distorted) + assert.Empty(t, match) + } + }) + + t.Run("handles required properties", func(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "number"}, + "name": map[string]any{"type": "string"}, + }, + "required": []string{"id"}, + } + + schemaJSON, err := json.Marshal(schema) + require.NoError(t, err) + + regexString, err := BuildRegexFromSchemaString(string(schemaJSON), "") + require.NoError(t, err) + + regex := regexp.MustCompile(regexString) + + validJSON := `{"id":1,"name":"John"}` + for _, distorted := range createDistortions(validJSON, false) { + match := regex.FindString(distorted) + assert.NotEmpty(t, match) + assert.Contains(t, match, 
validJSON) + } + + invalidJSON := `{"name":"John"}` + for _, distorted := range createInvalidDistortions(invalidJSON, false) { + match := regex.FindString(distorted) + assert.Empty(t, match) + } + }) + + t.Run("handles string enums", func(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "color": map[string]any{"type": "string", "enum": []string{"red", "green", "blue"}}, + }, + } + + schemaJSON, err := json.Marshal(schema) + require.NoError(t, err) + + regexString, err := BuildRegexFromSchemaString(string(schemaJSON), "") + require.NoError(t, err) + + regex := regexp.MustCompile(regexString) + + validJSON := `{"color":"red"}` + for _, distorted := range createDistortions(validJSON, false) { + match := regex.FindString(distorted) + assert.NotEmpty(t, match) + assert.Contains(t, match, validJSON) + } + + invalidJSON := `{"color":"yellow"}` + for _, distorted := range createInvalidDistortions(invalidJSON, false) { + match := regex.FindString(distorted) + assert.Empty(t, match) + } + }) + + t.Run("handles number enums", func(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "status": map[string]any{"type": "string", "enum": []int{1, 2, 3}}, + }, + } + + schemaJSON, err := json.Marshal(schema) + require.NoError(t, err) + + regexString, err := BuildRegexFromSchemaString(string(schemaJSON), "") + require.NoError(t, err) + + regex := regexp.MustCompile(regexString) + + validJSON := `{"status":1}` + for _, distorted := range createDistortions(validJSON, false) { + match := regex.FindString(distorted) + assert.NotEmpty(t, match) + assert.Contains(t, match, validJSON) + } + + invalidJSON := `{"status":10}` + for _, distorted := range createInvalidDistortions(invalidJSON, false) { + match := regex.FindString(distorted) + assert.Empty(t, match) + } + }) + + t.Run("handles nested objects", func(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": 
map[string]any{ + "person": map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "number"}, + }, + }, + }, + } + + schemaJSON, err := json.Marshal(schema) + require.NoError(t, err) + + regexString, err := BuildRegexFromSchemaString(string(schemaJSON), "") + require.NoError(t, err) + + regex := regexp.MustCompile(regexString) + + validJSON := `{"person":{"age":30,"name":"John"}}` + for _, distorted := range createDistortions(validJSON, false) { + match := regex.FindString(distorted) + assert.NotEmpty(t, match) + assert.Contains(t, match, validJSON) + } + + invalidJSON := `{"person":{"age":"30","name":"John"}}` + for _, distorted := range createInvalidDistortions(invalidJSON, false) { + match := regex.FindString(distorted) + assert.Empty(t, match) + } + }) + + t.Run("handles oneOf", func(t *testing.T) { + schema := map[string]any{ + "oneOf": []any{ + map[string]any{ + "type": "object", + "properties": map[string]any{"value": map[string]any{"type": "string"}}, + }, + map[string]any{ + "type": "object", + "properties": map[string]any{"value": map[string]any{"type": "number"}}, + }, + }, + } + + schemaJSON, err := json.Marshal(schema) + require.NoError(t, err) + + regexString, err := BuildRegexFromSchemaString(string(schemaJSON), "") + require.NoError(t, err) + + regex := regexp.MustCompile(regexString) + + validJSONs := []string{`{"value":"text"}`, `{"value":1}`} + for _, validJSON := range validJSONs { + for _, distorted := range createDistortions(validJSON, true) { + match := regex.FindString(distorted) + assert.NotEmpty(t, match) + assert.Contains(t, match, validJSON) + } + } + + invalidJSON := `true` + for _, distorted := range createInvalidDistortions(invalidJSON, true) { + match := regex.FindString(distorted) + assert.Empty(t, match) + } + }) + + t.Run("handles allOf", func(t *testing.T) { + schema := map[string]any{ + "allOf": []any{ + map[string]any{ + "type": "object", + 
"properties": map[string]any{"a": map[string]any{"type": "number"}}, + "required": []string{"a"}, + }, + map[string]any{ + "type": "object", + "properties": map[string]any{"b": map[string]any{"type": "string"}}, + "required": []string{"b"}, + }, + }, + } + + schemaJSON, err := json.Marshal(schema) + require.NoError(t, err) + + regexString, err := BuildRegexFromSchemaString(string(schemaJSON), "") + require.NoError(t, err) + + regex := regexp.MustCompile(regexString) + + validJSON := `{"a":1,"b":"text"}` + for _, distorted := range createDistortions(validJSON, false) { + match := regex.FindString(distorted) + assert.NotEmpty(t, match) + assert.Contains(t, match, validJSON) + } + + invalidJSONs := []string{`{"a":1}`, `{"b":"text"}`, `{"a":"1","b":"text"}`} + for _, invalidJSON := range invalidJSONs { + for _, distorted := range createInvalidDistortions(invalidJSON, false) { + match := regex.FindString(distorted) + assert.Empty(t, match) + } + } + }) + + t.Run("handles anyOf", func(t *testing.T) { + schema := map[string]any{ + "anyOf": []any{ + map[string]any{ + "type": "object", + "properties": map[string]any{"a": map[string]any{"type": "number"}}, + "required": []string{"a"}, + }, + map[string]any{ + "type": "object", + "properties": map[string]any{"b": map[string]any{"type": "string"}}, + "required": []string{"b"}, + }, + }, + } + + schemaJSON, err := json.Marshal(schema) + require.NoError(t, err) + + regexString, err := BuildRegexFromSchemaString(string(schemaJSON), "") + require.NoError(t, err) + + regex := regexp.MustCompile(regexString) + + validJSONs := []string{`{"a":1}`, `{"b":"text"}`, `{"a":1,"b":"text"}`} + for _, validJSON := range validJSONs { + for _, distorted := range createDistortions(validJSON, false) { + match := regex.FindString(distorted) + assert.NotEmpty(t, match) + assert.Contains(t, match, validJSON) + } + } + + invalidJSONs := []string{`{"a":"1"}`, `{"b":2}`, `{}`} + for _, invalidJSON := range invalidJSONs { + for _, distorted := range 
createInvalidDistortions(invalidJSON, false) { + match := regex.FindString(distorted) + assert.Empty(t, match) + } + } + }) +} + +func TestBuildRegexFromSchema(t *testing.T) { + t.Run("handles basic object schema", func(t *testing.T) { + schemaMap := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "number"}, + }, + } + + schema, err := jsonschema.CompileString("schema.json", toJSONString(schemaMap)) + require.NoError(t, err) + + regexString, err := BuildRegexFromSchema(schema, "") + require.NoError(t, err) + + regex := regexp.MustCompile(regexString) + + validJSON := `{"age":30,"name":"John"}` + for _, distorted := range createDistortions(validJSON, false) { + match := regex.FindString(distorted) + assert.NotEmpty(t, match) + assert.Contains(t, match, validJSON) + } + + invalidJSON := `{"age":30,"name":123}` + for _, distorted := range createInvalidDistortions(invalidJSON, false) { + match := regex.FindString(distorted) + assert.Empty(t, match) + } + }) +} + +func TestExtractBySchema(t *testing.T) { + t.Run("extracts object properties", func(t *testing.T) { + schemaMap := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "number"}, + }, + } + + schema, err := jsonschema.CompileString("schema.json", toJSONString(schemaMap)) + require.NoError(t, err) + + input := `Sure, here's the JSON: {"age":30,"name":"John"}` + extracted, err := ExtractBySchema(schema, input) + require.NoError(t, err) + assert.Equal(t, `{"age":30,"name":"John"}`, extracted) + }) +} + +func TestExtractObjectBySchema(t *testing.T) { + t.Run("extracts object properties", func(t *testing.T) { + schemaMap := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "number"}, + }, + } + + schema, err := 
jsonschema.CompileString("schema.json", toJSONString(schemaMap)) + require.NoError(t, err) + + input := "Sure, here's the JSON: ```json{\"age\":30,\"name\":\"John\"}```" + + type Person struct { + Name string `json:"name"` + Age int `json:"age"` + } + + extracted, err := ExtractStructBySchema[Person](schema, input) + require.NoError(t, err) + assert.Equal(t, Person{Name: "John", Age: 30}, *extracted) + }) +}