Skip to content

Commit 02343d9

Browse files
authored
Add Format function to reconstruct SQL from AST (#50)
1 parent 0085168 commit 02343d9

File tree

5 files changed

+342
-0
lines changed

5 files changed

+342
-0
lines changed

internal/format/expressions.go

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
package format
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
"github.com/sqlc-dev/doubleclick/ast"
8+
)
9+
10+
// Expression formats an expression.
11+
func Expression(sb *strings.Builder, expr ast.Expression) {
12+
if expr == nil {
13+
return
14+
}
15+
16+
switch e := expr.(type) {
17+
case *ast.Literal:
18+
formatLiteral(sb, e)
19+
case *ast.Identifier:
20+
formatIdentifier(sb, e)
21+
case *ast.TableIdentifier:
22+
formatTableIdentifier(sb, e)
23+
case *ast.FunctionCall:
24+
formatFunctionCall(sb, e)
25+
case *ast.BinaryExpr:
26+
formatBinaryExpr(sb, e)
27+
case *ast.UnaryExpr:
28+
formatUnaryExpr(sb, e)
29+
case *ast.Asterisk:
30+
formatAsterisk(sb, e)
31+
case *ast.AliasedExpr:
32+
formatAliasedExpr(sb, e)
33+
default:
34+
// Fallback for unhandled expressions
35+
sb.WriteString(fmt.Sprintf("%v", expr))
36+
}
37+
}
38+
39+
// formatLiteral formats a literal value.
40+
func formatLiteral(sb *strings.Builder, lit *ast.Literal) {
41+
switch lit.Type {
42+
case ast.LiteralString:
43+
sb.WriteString("'")
44+
// Escape single quotes in the string
45+
s := lit.Value.(string)
46+
s = strings.ReplaceAll(s, "'", "''")
47+
sb.WriteString(s)
48+
sb.WriteString("'")
49+
case ast.LiteralInteger:
50+
switch v := lit.Value.(type) {
51+
case int64:
52+
sb.WriteString(fmt.Sprintf("%d", v))
53+
case uint64:
54+
sb.WriteString(fmt.Sprintf("%d", v))
55+
default:
56+
sb.WriteString(fmt.Sprintf("%v", lit.Value))
57+
}
58+
case ast.LiteralFloat:
59+
sb.WriteString(fmt.Sprintf("%v", lit.Value))
60+
case ast.LiteralBoolean:
61+
if lit.Value.(bool) {
62+
sb.WriteString("true")
63+
} else {
64+
sb.WriteString("false")
65+
}
66+
case ast.LiteralNull:
67+
sb.WriteString("NULL")
68+
case ast.LiteralArray:
69+
formatArrayLiteral(sb, lit.Value)
70+
case ast.LiteralTuple:
71+
formatTupleLiteral(sb, lit.Value)
72+
default:
73+
sb.WriteString(fmt.Sprintf("%v", lit.Value))
74+
}
75+
}
76+
77+
// formatArrayLiteral formats an array literal.
78+
func formatArrayLiteral(sb *strings.Builder, val interface{}) {
79+
sb.WriteString("[")
80+
exprs, ok := val.([]ast.Expression)
81+
if ok {
82+
for i, e := range exprs {
83+
if i > 0 {
84+
sb.WriteString(", ")
85+
}
86+
Expression(sb, e)
87+
}
88+
}
89+
sb.WriteString("]")
90+
}
91+
92+
// formatTupleLiteral formats a tuple literal.
93+
func formatTupleLiteral(sb *strings.Builder, val interface{}) {
94+
sb.WriteString("(")
95+
exprs, ok := val.([]ast.Expression)
96+
if ok {
97+
for i, e := range exprs {
98+
if i > 0 {
99+
sb.WriteString(", ")
100+
}
101+
Expression(sb, e)
102+
}
103+
}
104+
sb.WriteString(")")
105+
}
106+
107+
// formatIdentifier formats an identifier.
108+
func formatIdentifier(sb *strings.Builder, id *ast.Identifier) {
109+
sb.WriteString(id.Name())
110+
}
111+
112+
// formatTableIdentifier formats a table identifier.
113+
func formatTableIdentifier(sb *strings.Builder, t *ast.TableIdentifier) {
114+
if t.Database != "" {
115+
sb.WriteString(t.Database)
116+
sb.WriteString(".")
117+
}
118+
sb.WriteString(t.Table)
119+
}
120+
121+
// formatFunctionCall formats a function call.
122+
func formatFunctionCall(sb *strings.Builder, fn *ast.FunctionCall) {
123+
sb.WriteString(fn.Name)
124+
sb.WriteString("(")
125+
if fn.Distinct {
126+
sb.WriteString("DISTINCT ")
127+
}
128+
for i, arg := range fn.Arguments {
129+
if i > 0 {
130+
sb.WriteString(", ")
131+
}
132+
Expression(sb, arg)
133+
}
134+
sb.WriteString(")")
135+
}
136+
137+
// formatBinaryExpr formats a binary expression.
138+
func formatBinaryExpr(sb *strings.Builder, expr *ast.BinaryExpr) {
139+
Expression(sb, expr.Left)
140+
sb.WriteString(" ")
141+
sb.WriteString(expr.Op)
142+
sb.WriteString(" ")
143+
Expression(sb, expr.Right)
144+
}
145+
146+
// formatUnaryExpr formats a unary expression.
147+
func formatUnaryExpr(sb *strings.Builder, expr *ast.UnaryExpr) {
148+
sb.WriteString(expr.Op)
149+
Expression(sb, expr.Operand)
150+
}
151+
152+
// formatAsterisk formats an asterisk.
153+
func formatAsterisk(sb *strings.Builder, a *ast.Asterisk) {
154+
if a.Table != "" {
155+
sb.WriteString(a.Table)
156+
sb.WriteString(".")
157+
}
158+
sb.WriteString("*")
159+
}
160+
161+
// formatAliasedExpr formats an aliased expression.
162+
func formatAliasedExpr(sb *strings.Builder, a *ast.AliasedExpr) {
163+
Expression(sb, a.Expr)
164+
sb.WriteString(" AS ")
165+
sb.WriteString(a.Alias)
166+
}

internal/format/format.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Package format provides SQL formatting for ClickHouse AST.
2+
package format
3+
4+
import (
5+
"strings"
6+
7+
"github.com/sqlc-dev/doubleclick/ast"
8+
)
9+
10+
// Format returns the SQL string representation of the statements.
11+
func Format(stmts []ast.Statement) string {
12+
var sb strings.Builder
13+
for i, stmt := range stmts {
14+
if i > 0 {
15+
sb.WriteString("\n")
16+
}
17+
Statement(&sb, stmt)
18+
sb.WriteString(";")
19+
}
20+
return sb.String()
21+
}
22+
23+
// Statement formats a single statement.
24+
func Statement(sb *strings.Builder, stmt ast.Statement) {
25+
if stmt == nil {
26+
return
27+
}
28+
29+
switch s := stmt.(type) {
30+
case *ast.SelectWithUnionQuery:
31+
formatSelectWithUnionQuery(sb, s)
32+
case *ast.SelectQuery:
33+
formatSelectQuery(sb, s)
34+
default:
35+
// For now, only handle SELECT statements
36+
}
37+
}

internal/format/statements.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package format
2+
3+
import (
4+
"strings"
5+
6+
"github.com/sqlc-dev/doubleclick/ast"
7+
)
8+
9+
// formatSelectWithUnionQuery formats a SELECT with UNION query.
10+
func formatSelectWithUnionQuery(sb *strings.Builder, q *ast.SelectWithUnionQuery) {
11+
for i, sel := range q.Selects {
12+
if i > 0 {
13+
sb.WriteString(" UNION ")
14+
if len(q.UnionModes) > i-1 && q.UnionModes[i-1] == "ALL" {
15+
sb.WriteString("ALL ")
16+
} else if len(q.UnionModes) > i-1 && q.UnionModes[i-1] == "DISTINCT" {
17+
sb.WriteString("DISTINCT ")
18+
}
19+
}
20+
Statement(sb, sel)
21+
}
22+
}
23+
24+
// formatSelectQuery formats a SELECT query.
25+
func formatSelectQuery(sb *strings.Builder, q *ast.SelectQuery) {
26+
sb.WriteString("SELECT ")
27+
28+
if q.Distinct {
29+
sb.WriteString("DISTINCT ")
30+
}
31+
32+
// Format columns
33+
for i, col := range q.Columns {
34+
if i > 0 {
35+
sb.WriteString(", ")
36+
}
37+
Expression(sb, col)
38+
}
39+
40+
// Format FROM clause
41+
if q.From != nil {
42+
sb.WriteString(" FROM ")
43+
formatTablesInSelectQuery(sb, q.From)
44+
}
45+
46+
// Format WHERE clause
47+
if q.Where != nil {
48+
sb.WriteString(" WHERE ")
49+
Expression(sb, q.Where)
50+
}
51+
52+
// Format GROUP BY clause
53+
if len(q.GroupBy) > 0 {
54+
sb.WriteString(" GROUP BY ")
55+
for i, expr := range q.GroupBy {
56+
if i > 0 {
57+
sb.WriteString(", ")
58+
}
59+
Expression(sb, expr)
60+
}
61+
}
62+
63+
// Format HAVING clause
64+
if q.Having != nil {
65+
sb.WriteString(" HAVING ")
66+
Expression(sb, q.Having)
67+
}
68+
69+
// Format ORDER BY clause
70+
if len(q.OrderBy) > 0 {
71+
sb.WriteString(" ORDER BY ")
72+
for i, elem := range q.OrderBy {
73+
if i > 0 {
74+
sb.WriteString(", ")
75+
}
76+
formatOrderByElement(sb, elem)
77+
}
78+
}
79+
80+
// Format LIMIT clause
81+
if q.Limit != nil {
82+
sb.WriteString(" LIMIT ")
83+
Expression(sb, q.Limit)
84+
}
85+
}
86+
87+
// formatTablesInSelectQuery formats the FROM clause tables.
88+
func formatTablesInSelectQuery(sb *strings.Builder, t *ast.TablesInSelectQuery) {
89+
for i, elem := range t.Tables {
90+
if i > 0 {
91+
// TODO: Handle JOINs properly
92+
sb.WriteString(", ")
93+
}
94+
formatTablesInSelectQueryElement(sb, elem)
95+
}
96+
}
97+
98+
// formatTablesInSelectQueryElement formats a single table element.
99+
func formatTablesInSelectQueryElement(sb *strings.Builder, t *ast.TablesInSelectQueryElement) {
100+
if t.Table != nil {
101+
formatTableExpression(sb, t.Table)
102+
}
103+
}
104+
105+
// formatTableExpression formats a table expression.
106+
func formatTableExpression(sb *strings.Builder, t *ast.TableExpression) {
107+
Expression(sb, t.Table)
108+
if t.Alias != "" {
109+
sb.WriteString(" AS ")
110+
sb.WriteString(t.Alias)
111+
}
112+
}
113+
114+
// formatOrderByElement formats an ORDER BY element.
115+
func formatOrderByElement(sb *strings.Builder, o *ast.OrderByElement) {
116+
Expression(sb, o.Expression)
117+
if o.Descending {
118+
sb.WriteString(" DESC")
119+
}
120+
}

parser/format.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package parser
2+
3+
import (
4+
"github.com/sqlc-dev/doubleclick/ast"
5+
"github.com/sqlc-dev/doubleclick/internal/format"
6+
)
7+
8+
// Format returns the SQL string representation of the statements.
9+
func Format(stmts []ast.Statement) string {
10+
return format.Format(stmts)
11+
}

parser/parser_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,14 @@ func TestParser(t *testing.T) {
202202
}
203203
}
204204

205+
// Check Format output for 00007_array test
206+
if entry.Name() == "00007_array" {
207+
formatted := parser.Format(stmts)
208+
if formatted != query {
209+
t.Errorf("Format output mismatch\nQuery: %s\nFormatted: %s", query, formatted)
210+
}
211+
}
212+
205213
// If we get here with a todo test and -check-skipped is set, the test passes!
206214
// Automatically remove the todo flag from metadata.json
207215
if metadata.Todo && *checkSkipped {

0 commit comments

Comments
 (0)