Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions internal/compiler/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,18 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
}
errors = append(errors, errs...)
}
refs = uniqueParamRefs(refs, dollar)
refs = numberParamRefs(refs, dollar)
if c.conf.Engine == config.EngineMySQL || !dollar {
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Location < refs[j].ref.Location })
sort.SliceStable(refs, func(i, j int) bool {
return refs[i].ref.Location < refs[j].ref.Location
})
} else {
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
sort.SliceStable(refs, func(i, j int) bool {
if refs[i].ref.Number == refs[j].ref.Number {
return refs[i].ref.Location < refs[j].ref.Location
}
return refs[i].ref.Number < refs[j].ref.Number
})
}
raw, embeds := rewrite.Embeds(raw)
qc, err := c.buildQueryCatalog(c.catalog, raw.Stmt, embeds)
Expand Down
1 change: 1 addition & 0 deletions internal/compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Parser interface {
Parse(io.Reader) ([]ast.Statement, error)
CommentSyntax() source.CommentSyntax
IsReservedKeyword(string) bool
TypeName(ns, name string) string
}

func (c *Compiler) parseCatalog(schemas []string) error {
Expand Down
23 changes: 16 additions & 7 deletions internal/compiler/find_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,16 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
}

case *ast.UpdateStmt:
var targetRV *ast.RangeVar
for _, relation := range n.Relations.Items {
rv, ok := relation.(*ast.RangeVar)
if !ok {
continue
}
targetRV = rv
break
}

for _, item := range n.TargetList.Items {
target, ok := item.(*ast.ResTarget)
if !ok {
Expand All @@ -120,13 +130,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
if !ok {
continue
}
for _, relation := range n.Relations.Items {
rv, ok := relation.(*ast.RangeVar)
if !ok {
continue
}
*p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv})
}
*p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: targetRV})
p.seen[ref.Location] = struct{}{}
}
if n.LimitCount != nil {
Expand All @@ -140,6 +144,11 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
p.parent = node

case *ast.SelectStmt:
if n.FromClause != nil && len(n.FromClause.Items) > 0 {
if rv, ok := n.FromClause.Items[0].(*ast.RangeVar); ok {
p.rangeVar = rv
}
}
if n.LimitCount != nil {
p.limitCount = n.LimitCount
}
Expand Down
32 changes: 32 additions & 0 deletions internal/compiler/find_params_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package compiler

import (
"testing"

"github.com/sqlc-dev/sqlc/internal/sql/ast"
)

func TestFindParametersSelectStmtUsesFromRangeVarForWhereParams(t *testing.T) {
t.Parallel()

tableName := "solar_commcard_mapping"
refs, errs := findParameters(&ast.SelectStmt{
FromClause: &ast.List{Items: []ast.Node{&ast.RangeVar{Relname: &tableName}}},
WhereClause: &ast.A_Expr{
Lexpr: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "deviceId"}}}},
Rexpr: &ast.ParamRef{Number: 1, Location: 1},
},
})
if len(errs) > 0 {
t.Fatalf("findParameters returned errors: %v", errs)
}
if len(refs) != 1 {
t.Fatalf("expected 1 ref, got %d", len(refs))
}
if refs[0].rv == nil || refs[0].rv.Relname == nil {
t.Fatal("expected ref to carry range var")
}
if got := *refs[0].rv.Relname; got != tableName {
t.Fatalf("expected ref range var %q, got %q", tableName, got)
}
}
47 changes: 24 additions & 23 deletions internal/compiler/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,28 +193,29 @@ func rangeVars(root ast.Node) []*ast.RangeVar {
return vars
}

func uniqueParamRefs(in []paramRef, dollar bool) []paramRef {
m := make(map[int]bool, len(in))
o := make([]paramRef, 0, len(in))
func numberParamRefs(in []paramRef, dollar bool) []paramRef {
if dollar {
return in
}

used := make(map[int]bool, len(in))
for _, v := range in {
if !m[v.ref.Number] {
m[v.ref.Number] = true
if v.ref.Number != 0 {
o = append(o, v)
}
}
}
if !dollar {
start := 1
for _, v := range in {
if v.ref.Number == 0 {
for m[start] {
start++
}
v.ref.Number = start
o = append(o, v)
}
}
}
return o
if v.ref.Number != 0 {
used[v.ref.Number] = true
}
}

start := 1
for i := range in {
if in[i].ref.Number != 0 {
continue
}
for used[start] {
start++
}
in[i].ref.Number = start
used[start] = true
}

return in
}
Loading
Loading