Skip to content

Commit

Permalink
DuckDB sql parser (#2745)
Browse files Browse the repository at this point in the history
* Basic AST traversal

* Adding basic table rewrite

* Adding basic limit rewrite

* Use native json and support CTE and Unions

* Adding basic column extraction

* Adding basic annotation parser

* Sources sql prototype

* Adding column expression parsing

* Addressing PR comments

* Revert prototype code
  • Loading branch information
AdityaHegde committed Jul 12, 2023
1 parent be5ce9f commit b6f4ee3
Show file tree
Hide file tree
Showing 5 changed files with 1,111 additions and 42 deletions.
115 changes: 115 additions & 0 deletions runtime/pkg/duckdbsql/ast_values.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package duckdbsql

import "strings"

// astNode is one JSON object of a serialized SQL tree, kept in the generic map form that
// encoding/json produces when decoding into a map.
//
// TODO: figure out a way to cast map[string]interface{} returned by json unmarshal to map[astNodeKey]interface{} and replace string in key to astNodeKey
type astNode map[string]any

// Keys used to look up fields on astNode values. Listed alphabetically by identifier.
const (
	astKeyAlias            string = "alias"
	astKeyCTE              string = "cte_map"
	astKeyChildren         string = "children"
	astKeyClass            string = "class"
	astKeyColumnNameAlias  string = "column_name_alias"
	astKeyColumnNames      string = "column_names"
	astKeyError            string = "error"
	astKeyErrorMessage     string = "error_message"
	astKeyFromTable        string = "from_table"
	astKeyFunction         string = "function"
	astKeyFunctionName     string = "function_name"
	astKeyID               string = "id"
	astKeyKey              string = "key"
	astKeyLeft             string = "left"
	astKeyLimit            string = "limit"
	astKeyMap              string = "map"
	astKeyModifiers        string = "modifiers"
	astKeyNode             string = "node"
	astKeyQuery            string = "query"
	astKeyRight            string = "right"
	astKeySample           string = "sample"
	astKeySelectColumnList string = "select_list"
	astKeyStatements       string = "statements"
	astKeySubQuery         string = "subquery"
	astKeyTableName        string = "table_name"
	astKeyType             string = "type"
	astKeyValue            string = "value"

	// NOTE(review): the identifier below is spelled "Ket" rather than "Key". It is left
	// unchanged because other files in the package may reference it by this name.
	astKetRelationName string = "relation_name"
)

// toBoolean returns the bool stored under key k of a.
// It returns false when the key is absent or the stored value is not a bool.
func toBoolean(a astNode, k string) bool {
	// Indexing a missing key yields a nil interface, which fails the assertion.
	flag, isBool := a[k].(bool)
	return isBool && flag
}

// toString returns the string stored under key k of a.
// It returns "" when the key is absent or the stored value is not a string.
func toString(a astNode, k string) string {
	// A failed assertion leaves text at its zero value, the empty string.
	text, _ := a[k].(string)
	return text
}

// toNode returns the nested object stored under key k of a as an astNode.
// It returns nil when the key is absent or the stored value is not a map[string]interface{}.
// The returned node shares its backing map with the stored value.
func toNode(a astNode, k string) astNode {
	if child, isObject := a[k].(map[string]interface{}); isObject {
		return child
	}
	return nil
}

// toArray returns the array stored under key k of a.
// It returns an empty, non-nil slice when the key is absent or the stored value is not a
// []interface{}. When the value is an array, the stored slice itself is returned, not a copy.
func toArray(a astNode, k string) []interface{} {
	// The checked assertion matters: matching on interface{} would accept every non-nil
	// value and a following unchecked cast would panic on strings, numbers or objects.
	if arr, ok := a[k].([]interface{}); ok {
		return arr
	}
	return make([]interface{}, 0)
}

// toNodeArray returns the array stored under key k of a with each element converted to an
// astNode. The result has one entry per array element; an element that is not a
// map[string]interface{} (for example a JSON null) yields a nil astNode at its position
// instead of panicking, matching the zero-value fallback of the other accessors.
func toNodeArray(a astNode, k string) []astNode {
	arr := toArray(a, k)
	nodeArr := make([]astNode, len(arr))
	for i, e := range arr {
		if obj, ok := e.(map[string]interface{}); ok {
			nodeArr[i] = obj
		}
	}
	return nodeArr
}

// toTypedArray returns the array stored under key k of a with each element asserted to E.
// The result has one entry per array element; an element whose dynamic type is not E is left
// as the zero value of E at its position instead of panicking, matching the zero-value
// fallback of the other accessors.
func toTypedArray[E interface{}](a astNode, k string) []E {
	arr := toArray(a, k)
	typedArr := make([]E, len(arr))
	for i, e := range arr {
		if v, ok := e.(E); ok {
			typedArr[i] = v
		}
	}
	return typedArr
}

// getColumnName returns the name for a column node: its alias when one is set, otherwise the
// entries of its column_names array joined with ".".
func getColumnName(node astNode) string {
	if alias := toString(node, astKeyAlias); alias != "" {
		return alias
	}
	parts := toTypedArray[string](node, astKeyColumnNames)
	return strings.Join(parts, ".")
}
190 changes: 152 additions & 38 deletions runtime/pkg/duckdbsql/duckdbsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,157 @@ import (
"context"
databasesql "database/sql"
"database/sql/driver"
"encoding/json"
"regexp"
"sync"

"github.com/marcboeker/go-duckdb"
)

// AST holds a parsed DuckDB SQL statement: the original text, the JSON tree decoded from
// DuckDB's serializer, and the nodes collected from that tree.
type AST struct {
// sql is the statement text given to Parse; ExtractAnnotations scans it for comments
sql string
// ast is the tree decoded from the json_serialize_sql output; Format marshals it back
ast astNode
// rootNodes holds select nodes; RewriteLimit only touches the first entry
rootNodes []*selectNode
// NOTE(review): aliases and added are initialised empty by Parse; how they are filled is not visible here
aliases map[string]bool
added map[string]bool
// fromNodes holds the entries appended by newFromNode; RewriteTableRefs iterates over them
fromNodes []*fromNode
// columns holds the entries appended by newColumnNode; ExtractColumnRefs returns their refs
columns []*columnNode
}

// selectNode wraps the raw tree node of a select statement.
// AST.RewriteLimit calls its rewriteLimit method.
type selectNode struct {
// ast is the underlying raw node
ast astNode
}

// columnNode pairs a raw tree node with the ColumnRef extracted for it.
type columnNode struct {
// ast is the underlying raw node
ast astNode
// ref is the value returned to callers by AST.ExtractColumnRefs
ref *ColumnRef
}

// fromNode pairs a raw tree node for a table reference with the TableRef extracted for it.
type fromNode struct {
// ast is the underlying raw node
ast astNode
// NOTE(review): parent and childKey presumably identify where ast is attached (parent[childKey]) — confirm against the traversal code
parent astNode
childKey string
// ref is the value handed to the callback of AST.RewriteTableRefs
ref *TableRef
}

// Parse serializes sql to JSON with DuckDB's json_serialize_sql, decodes the result into a
// generic tree and traverses it to build an AST.
// It returns an error if the query, the JSON decoding or the traversal fails.
func Parse(sql string) (*AST, error) {
	serialized, err := queryString("select json_serialize_sql(?::VARCHAR)", sql)
	if err != nil {
		return nil, err
	}

	root := make(astNode)
	if err := json.Unmarshal(serialized, &root); err != nil {
		return nil, err
	}

	// All collections start empty but non-nil; traverse is responsible for filling them.
	parsed := &AST{
		sql:       sql,
		ast:       root,
		rootNodes: []*selectNode{},
		aliases:   make(map[string]bool),
		added:     make(map[string]bool),
		fromNodes: []*fromNode{},
		columns:   []*columnNode{},
	}
	if err := parsed.traverse(); err != nil {
		return nil, err
	}
	return parsed, nil
}

// Format normalizes a DuckDB SQL statement
func Format(sql string) (string, error) {
return queryString("SELECT json_deserialize_sql(json_serialize_sql(?::VARCHAR))", sql)
// Format marshals the AST's tree to JSON and asks DuckDB's json_deserialize_sql to turn it
// back into SQL text. It returns the resulting SQL, or an error from marshalling or the query.
func (a *AST) Format() (string, error) {
	serialized, err := json.Marshal(a.ast)
	if err != nil {
		return "", err
	}

	formatted, err := queryString("SELECT json_deserialize_sql(?::JSON)", string(serialized))
	return string(formatted), err
}

// Sanitize strips comments and normalizes a DuckDB SQL statement
func Sanitize(sql string) (string, error) {
panic("not implemented")
// RewriteTableRefs replaces table references in a DuckDB SQL query. Only replacing with a base table reference is supported right now.
//
// fn is called once for every collected table reference. When it returns true and a non-nil
// replacement with a non-empty Name, the reference is rewritten to that base table. A nil
// replacement or an empty Name leaves the reference untouched. The first rewrite error is
// returned and stops the iteration.
func (a *AST) RewriteTableRefs(fn func(table *TableRef) (*TableRef, bool)) error {
	for _, node := range a.fromNodes {
		newRef, shouldReplace := fn(node.ref)
		// A nil replacement has no target to rewrite to; skip it rather than dereference it.
		if !shouldReplace || newRef == nil {
			continue
		}

		// only rewriting to a base table is supported as of now.
		if newRef.Name == "" {
			continue
		}
		if err := node.rewriteToBaseTable(newRef.Name); err != nil {
			return err
		}
	}

	return nil
}

// RewriteLimit rewrites a DuckDB SQL statement to limit the result size
func RewriteLimit(sql string, limit, offset int) (string, error) {
panic("not implemented")
// RewriteLimit applies limit and offset to the first root select node and returns any error
// from that rewrite. It is a no-op when the AST has no root nodes.
func (a *AST) RewriteLimit(limit, offset int) error {
	if len(a.rootNodes) == 0 {
		return nil
	}

	// We only need to add limit to the top level query
	top := a.rootNodes[0]
	return top.rewriteLimit(limit, offset)
}

// ExtractColumnRefs returns the ColumnRef of every collected column node, in collection order.
// The result is an empty, non-nil slice when no columns were collected.
func (a *AST) ExtractColumnRefs() []*ColumnRef {
	refs := make([]*ColumnRef, len(a.columns))
	for i := range a.columns {
		refs[i] = a.columns[i].ref
	}
	return refs
}

// annotationsRegex matches a line that starts with "--", followed by "@key" and an optional
// ": value". Group 1 captures the key, group 2 the value.
var annotationsRegex = regexp.MustCompile(`(?m)^--[ \t]*@([a-zA-Z0-9_\-.]*)[ \t]*(?::[ \t]*(.*?))?\s*$`)

// ExtractAnnotations extracts annotations from comments prefixed with '@', and optionally a value after a ':'.
// Examples: "-- @materialize" and "-- @materialize: true".
// The result is keyed by annotation key; when a key occurs more than once the last occurrence wins.
func (a *AST) ExtractAnnotations() map[string]*Annotation {
	found := map[string]*Annotation{}
	for _, m := range annotationsRegex.FindAllStringSubmatch(a.sql, -1) {
		// m always has one entry per capture group; the value group is "" when it did not participate.
		found[m[1]] = &Annotation{Key: m[1], Value: m[2]}
	}
	return found
}

// newFromNode records a table-reference node on the AST by appending a fromNode built from
// the given raw node, its parent, the child key and the extracted TableRef.
func (a *AST) newFromNode(node, parent astNode, childKey string, ref *TableRef) {
	a.fromNodes = append(a.fromNodes, &fromNode{
		ast:      node,
		parent:   parent,
		childKey: childKey,
		ref:      ref,
	})
}

// newColumnNode records a column node on the AST by appending a columnNode built from the
// given raw node and the extracted ColumnRef.
func (a *AST) newColumnNode(node astNode, ref *ColumnRef) {
	a.columns = append(a.columns, &columnNode{ast: node, ref: ref})
}

// TableRef has information extracted about a DuckDB table or table function reference
Expand All @@ -30,16 +163,7 @@ type TableRef struct {
Function string
Path string
Properties map[string]any
}

// ExtractTableRefs extracts table references from a DuckDB SQL query
func ExtractTableRefs(sql string) ([]*TableRef, error) {
panic("not implemented")
}

// RewriteTableRefs replaces table references in a DuckDB SQL query
func RewriteTableRefs(sql string, fn func(table *TableRef) (*TableRef, bool)) (string, error) {
panic("not implemented")
LocalAlias bool
}

// Annotation is key-value annotation extracted from a DuckDB SQL comment
Expand All @@ -48,44 +172,34 @@ type Annotation struct {
Value string
}

// ExtractAnnotations extracts annotations from comments prefixed with '@', and optionally a value after a ':'.
// Examples: "-- @materialize" and "-- @materialize: true".
func ExtractAnnotations() ([]*Annotation, error) {
panic("not implemented")
}

// ColumnRef has information about a column in the select list of a DuckDB SQL statement
type ColumnRef struct {
Name string
Expr string
IsAggr bool
IsStar bool
IsExclude bool
}

// ExtractColumnRefs extracts column references from the outermost SELECT of a DuckDB SQL statement
func ExtractColumnRefs(sql string) ([]*ColumnRef, error) {
panic("not implemented")
Name string
RelationName string
Expr string
IsAggr bool
IsStar bool
IsExclude bool
}

// queryString runs a DuckDB query and returns the result as a scalar string
func queryString(qry string, args ...any) (string, error) {
func queryString(qry string, args ...any) ([]byte, error) {
rows, err := query(qry, args...)
if err != nil {
return "", err
return nil, err
}

var res string
var res []byte
if rows.Next() {
err := rows.Scan(&res)
if err != nil {
return "", err
return nil, err
}
}

err = rows.Close()
if err != nil {
return "", err
return nil, err
}

return res, nil
Expand Down
Loading

0 comments on commit b6f4ee3

Please sign in to comment.