diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 0d7d507575..002a963c89 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -169,11 +169,18 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) } errors = append(errors, errs...) } - refs = uniqueParamRefs(refs, dollar) + refs = numberParamRefs(refs, dollar) if c.conf.Engine == config.EngineMySQL || !dollar { - sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Location < refs[j].ref.Location }) + sort.SliceStable(refs, func(i, j int) bool { + return refs[i].ref.Location < refs[j].ref.Location + }) } else { - sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) + sort.SliceStable(refs, func(i, j int) bool { + if refs[i].ref.Number == refs[j].ref.Number { + return refs[i].ref.Location < refs[j].ref.Location + } + return refs[i].ref.Number < refs[j].ref.Number + }) } raw, embeds := rewrite.Embeds(raw) qc, err := c.buildQueryCatalog(c.catalog, raw.Stmt, embeds) diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index 1a95b586f4..e1997a3baa 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -24,6 +24,7 @@ type Parser interface { Parse(io.Reader) ([]ast.Statement, error) CommentSyntax() source.CommentSyntax IsReservedKeyword(string) bool + TypeName(ns, name string) string } func (c *Compiler) parseCatalog(schemas []string) error { diff --git a/internal/compiler/find_params.go b/internal/compiler/find_params.go index 8199addd33..066f63c9d2 100644 --- a/internal/compiler/find_params.go +++ b/internal/compiler/find_params.go @@ -111,6 +111,16 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { } case *ast.UpdateStmt: + var targetRV *ast.RangeVar + for _, relation := range n.Relations.Items { + rv, ok := relation.(*ast.RangeVar) + if !ok { + continue + } + targetRV = rv + break + } + for _, item := range n.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok { @@ -120,13 +130,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { if !ok { continue } - for _, relation := range n.Relations.Items { - rv, ok := relation.(*ast.RangeVar) - if !ok { - continue - } - *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv}) - } + *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: targetRV}) p.seen[ref.Location] = struct{}{} } if n.LimitCount != nil { @@ -140,6 +144,11 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { p.parent = node case *ast.SelectStmt: + if n.FromClause != nil && len(n.FromClause.Items) > 0 { + if rv, ok := n.FromClause.Items[0].(*ast.RangeVar); ok { + p.rangeVar = rv + } + } if n.LimitCount != nil { p.limitCount = n.LimitCount } diff --git a/internal/compiler/find_params_test.go b/internal/compiler/find_params_test.go new file mode 100644 index 0000000000..a31e72ccda --- /dev/null +++ b/internal/compiler/find_params_test.go @@ -0,0 +1,32 @@ +package compiler + +import ( + "testing" + + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func TestFindParametersSelectStmtUsesFromRangeVarForWhereParams(t *testing.T) { + t.Parallel() + + tableName := "solar_commcard_mapping" + refs, errs := findParameters(&ast.SelectStmt{ + FromClause: &ast.List{Items: []ast.Node{&ast.RangeVar{Relname: &tableName}}}, + WhereClause: &ast.A_Expr{ + Lexpr: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "deviceId"}}}}, + Rexpr: &ast.ParamRef{Number: 1, Location: 1}, + }, + }) + if len(errs) > 0 { + t.Fatalf("findParameters returned errors: %v", errs) + } + if len(refs) != 1 { + t.Fatalf("expected 1 ref, got %d", len(refs)) + } + if refs[0].rv == nil || refs[0].rv.Relname == nil { + t.Fatal("expected ref to carry range var") + } + if got := *refs[0].rv.Relname; got != tableName { + t.Fatalf("expected ref range var %q, got %q", tableName, got) + } +} diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 751cb3271a..374bf4f0c9 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -193,28 +193,29 @@ func rangeVars(root ast.Node) []*ast.RangeVar { return vars } -func uniqueParamRefs(in []paramRef, dollar bool) []paramRef { - m := make(map[int]bool, len(in)) - o := make([]paramRef, 0, len(in)) +func numberParamRefs(in []paramRef, dollar bool) []paramRef { + if dollar { + return in + } + + used := make(map[int]bool, len(in)) for _, v := range in { - if !m[v.ref.Number] { - m[v.ref.Number] = true - if v.ref.Number != 0 { - o = append(o, v) - } - } - } - if !dollar { - start := 1 - for _, v := range in { - if v.ref.Number == 0 { - for m[start] { - start++ - } - v.ref.Number = start - o = append(o, v) - } - } - } - return o + if v.ref.Number != 0 { + used[v.ref.Number] = true + } + } + + start := 1 + for i := range in { + if in[i].ref.Number != 0 { + continue + } + for used[start] { + start++ + } + in[i].ref.Number = start + used[start] = true + } + + return in } diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index b1fbb1990e..1fc661a26c 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -4,6 +4,7 @@ import ( "fmt" "log/slog" "strconv" + "strings" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/astutils" @@ -21,6 +22,219 @@ func dataType(n *ast.TypeName) string { } } +func hasConcreteParamType(col *Column) bool { + return col != nil && col.DataType != "" && col.DataType != "any" +} + +func (comp *Compiler) paramTypeString(col *Column) string { + if !hasConcreteParamType(col) { + return "any" + } + + arraySuffix := strings.Repeat("[]", col.ArrayDims) + if col.Type != nil && col.Type.Name != "" { + return comp.parser.TypeName(col.Type.Schema, col.Type.Name) + arraySuffix + } + + if rel, err := ParseRelationString(col.DataType); err == nil && rel.Catalog == "" { + return comp.parser.TypeName(rel.Schema, rel.Name) + arraySuffix + } + + return col.DataType + arraySuffix +} + +func compatibleParamTypes(a, b *Column) bool { + if !hasConcreteParamType(a) || !hasConcreteParamType(b) { + return true + } + return a.DataType == b.DataType && + a.Unsigned == b.Unsigned && + a.IsArray == b.IsArray && + a.ArrayDims == b.ArrayDims +} + +func sameTypeName(a, b *ast.TypeName) bool { + if a == nil || b == nil { + return a == nil && b == nil + } + return a.Catalog == b.Catalog && a.Schema == b.Schema && a.Name == b.Name +} + +func matchingFuncCallOverloads(c *catalog.Catalog, call *ast.FuncCall) []catalog.Function { + funs, err := c.ListFuncsByName(call.Func) + if err != nil { + return nil + } + + var positional []ast.Node + var named []*ast.NamedArgExpr + if call.Args != nil { + for _, arg := range call.Args.Items { + if narg, ok := arg.(*ast.NamedArgExpr); ok { + named = append(named, narg) + continue + } + if len(named) > 0 { + return nil + } + positional = append(positional, arg) + } + } + + var matches []catalog.Function + for _, fun := range funs { + args := fun.InArgs() + var defaults int + var variadic bool + known := map[string]struct{}{} + for _, arg := range args { + if arg.HasDefault { + defaults += 1 + } + if arg.Mode == ast.FuncParamVariadic { + variadic = true + defaults += 1 + } + if arg.Name != "" { + known[arg.Name] = struct{}{} + } + } + + argc := len(named) + len(positional) + if variadic { + if argc < (len(args) - defaults) { + continue + } + } else { + if argc > len(args) || argc < (len(args)-defaults) { + continue + } + } + + var unknownArgName bool + for _, expr := range named { + if expr.Name != nil { + if _, found := known[*expr.Name]; !found { + unknownArgName = true + } + } + } + if unknownArgName { + continue + } + + matches = append(matches, fun) + } + + return matches +} + +func stableFuncCallArgType(c *catalog.Catalog, call *ast.FuncCall, argIndex int, argName string) *ast.TypeName { + var stable *ast.TypeName + var seen bool + + for _, fun := range matchingFuncCallOverloads(c, call) { + args := fun.InArgs() + var current *ast.TypeName + if argName == "" { + if argIndex >= len(args) { + return nil + } + current = args[argIndex].Type + } else { + for _, arg := range args { + if arg.Name == argName { + current = arg.Type + break + } + } + if current == nil { + return nil + } + } + + if !seen { + stable = current + seen = true + continue + } + if !sameTypeName(stable, current) { + return nil + } + } + + return stable +} + +func resolvedFuncCallArgType(fun *catalog.Function, argIndex int, argName string) *ast.TypeName { + if fun == nil { + return nil + } + if argName == "" { + if argIndex < len(fun.Args) { + return fun.Args[argIndex].Type + } + return nil + } + for _, arg := range fun.Args { + if arg.Name == argName { + return arg.Type + } + } + return nil +} + +func mergeResolvedParam(existing, incoming Parameter) Parameter { + if existing.Column == nil { + return incoming + } + if incoming.Column == nil { + return existing + } + + base := existing + other := incoming + if hasConcreteParamType(incoming.Column) && !hasConcreteParamType(existing.Column) { + base = incoming + other = existing + } + + col := *base.Column + if col.Name == "" { + col.Name = other.Column.Name + } + if col.OriginalName == "" { + col.OriginalName = other.Column.OriginalName + } + if col.Table == nil { + col.Table = other.Column.Table + } + if col.Type == nil { + col.Type = other.Column.Type + } + if col.Length == nil { + col.Length = other.Column.Length + } + col.IsNamedParam = col.IsNamedParam || other.Column.IsNamedParam + col.IsSqlcSlice = col.IsSqlcSlice || other.Column.IsSqlcSlice + + base.Column = &col + return base +} + +func (comp *Compiler) incompatibleParamRefError(ref paramRef, existing, incoming Parameter) error { + return &sqlerr.Error{ + Code: "42P08", + Message: fmt.Sprintf( + "parameter $%d has incompatible types: %s, %s", + ref.ref.Number, + comp.paramTypeString(existing.Column), + comp.paramTypeString(incoming.Column), + ), + Location: ref.ref.Location, + } +} + func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) { c := comp.catalog @@ -98,11 +312,29 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } var a []Parameter + seen := map[int]int{} + paramCounts := map[int]int{} + for _, ref := range args { + paramCounts[ref.ref.Number] += 1 + } + + addParam := func(ref paramRef, p Parameter) error { + if idx, ok := seen[p.Number]; ok { + if !compatibleParamTypes(a[idx].Column, p.Column) { + return comp.incompatibleParamRefError(ref, a[idx], p) + } + a[idx] = mergeResolvedParam(a[idx], p) + return nil + } + seen[p.Number] = len(a) + a = append(a, p) + return nil + } - addUnknownParam := func(ref paramRef) { + addUnknownParam := func(ref paramRef) error { defaultP := named.NewInferredParam(ref.name, false) p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) - a = append(a, Parameter{ + return addParam(ref, Parameter{ Number: ref.ref.Number, Column: &Column{ Name: p.Name(), @@ -112,13 +344,64 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, }) } + addColumnParam := func(ref paramRef, key string, location int) error { + var schema, rel string + if defaultTable != nil { + schema = defaultTable.Schema + rel = defaultTable.Name + } + if ref.rv != nil { + fqn, err := ParseTableName(ref.rv) + if err != nil { + return err + } + schema = fqn.Schema + rel = fqn.Name + } + if schema == "" { + schema = c.DefaultSchema + } + + tableMap, ok := typeMap[schema][rel] + if !ok { + return sqlerr.RelationNotFound(rel) + } + + if c, ok := tableMap[key]; ok { + defaultP := named.NewInferredParam(key, c.IsNotNull) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) + return addParam(ref, Parameter{ + Number: ref.ref.Number, + Column: &Column{ + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Table: &ast.TableName{Schema: schema, Name: rel}, + Length: c.Length, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + }, + }) + } + + return &sqlerr.Error{ + Code: "42703", + Message: fmt.Sprintf("column %q does not exist", key), + Location: location, + } + } + for _, ref := range args { switch n := ref.parent.(type) { case *limitOffset: defaultP := named.NewInferredParam("offset", true) p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) - a = append(a, Parameter{ + if err := addParam(ref, Parameter{ Number: ref.ref.Number, Column: &Column{ Name: p.Name(), @@ -126,12 +409,14 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, NotNull: p.NotNull(), IsNamedParam: isNamed, }, - }) + }); err != nil { + return nil, err + } case *limitCount: defaultP := named.NewInferredParam("limit", true) p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) - a = append(a, Parameter{ + if err := addParam(ref, Parameter{ Number: ref.ref.Number, Column: &Column{ Name: p.Name(), @@ -139,7 +424,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, NotNull: p.NotNull(), IsNamedParam: isNamed, }, - }) + }); err != nil { + return nil, err + } case *ast.A_Expr: // TODO: While this works for a wide range of simple expressions, @@ -164,7 +451,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, defaultP := named.NewParam("") p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) - a = append(a, Parameter{ + if err := addParam(ref, Parameter{ Number: ref.ref.Number, Column: &Column{ Name: p.Name(), @@ -173,7 +460,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, NotNull: p.NotNull(), IsSqlcSlice: p.IsSqlcSlice(), }, - }) + }); err != nil { + return nil, err + } continue } @@ -196,7 +485,13 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } search := tables - if alias != "" { + if alias == "" && ref.rv != nil { + fqn, err := ParseTableName(ref.rv) + if err != nil { + return nil, err + } + search = []*ast.TableName{fqn} + } else if alias != "" { if original, ok := aliasMap[alias]; ok { search = []*ast.TableName{original} } else { @@ -231,7 +526,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, defaultP := named.NewInferredParam(key, c.IsNotNull) p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) - a = append(a, Parameter{ + if err := addParam(ref, Parameter{ Number: ref.ref.Number, Column: &Column{ Name: p.Name(), @@ -246,7 +541,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, IsNamedParam: isNamed, IsSqlcSlice: p.IsSqlcSlice(), }, - }) + }); err != nil { + return nil, err + } } } @@ -291,14 +588,15 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) var namePrefix string if !isNamed { - if ref.ref == n.Left { + switch ref.ref { + case n.Left: namePrefix = "from_" - } else if ref.ref == n.Right { + case n.Right: namePrefix = "to_" } } - a = append(a, Parameter{ + if err := addParam(ref, Parameter{ Number: ref.ref.Number, Column: &Column{ Name: namePrefix + p.Name(), @@ -311,13 +609,15 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, IsNamedParam: isNamed, IsSqlcSlice: p.IsSqlcSlice(), }, - }) + }); err != nil { + return nil, err + } } } case *ast.FuncCall: - fun, err := c.ResolveFuncCall(n) - if err != nil { + fun, resolveErr := c.ResolveFuncCall(n) + if resolveErr != nil { // Synthesize a function on the fly to avoid returning with an error // for an unknown Postgres function (e.g. defined in an extension) var args []*catalog.Argument @@ -374,7 +674,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, defaultP := named.NewInferredParam(defaultName, false) p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) added = true - a = append(a, Parameter{ + if err := addParam(ref, Parameter{ Number: ref.ref.Number, Column: &Column{ Name: p.Name(), @@ -383,7 +683,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, NotNull: p.NotNull(), IsSqlcSlice: p.IsSqlcSlice(), }, - }) + }); err != nil { + return nil, err + } continue } @@ -393,22 +695,20 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if argName == "" { if i < len(fun.Args) { paramName = fun.Args[i].Name - paramType = fun.Args[i].Type } } else { paramName = argName - for _, arg := range fun.Args { - if arg.Name == argName { - paramType = arg.Type - } - } - if paramType == nil { - panic(fmt.Sprintf("named argument %s has no type", paramName)) - } } if paramName == "" { paramName = funcName } + if resolveErr == nil { + if paramCounts[ref.ref.Number] > 1 { + paramType = stableFuncCallArgType(c, n, i, argName) + } else { + paramType = resolvedFuncCallArgType(fun, i, argName) + } + } if paramType == nil { paramType = &ast.TypeName{Name: ""} } @@ -416,7 +716,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, defaultP := named.NewInferredParam(paramName, true) p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) added = true - a = append(a, Parameter{ + if err := addParam(ref, Parameter{ Number: ref.ref.Number, Column: &Column{ Name: p.Name(), @@ -425,12 +725,16 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, IsNamedParam: isNamed, IsSqlcSlice: p.IsSqlcSlice(), }, - }) + }); err != nil { + return nil, err + } } if fun.ReturnType == nil { if !added { - addUnknownParam(ref) + if err := addUnknownParam(ref); err != nil { + return nil, err + } } continue } @@ -442,7 +746,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, }) if err != nil { if !added { - addUnknownParam(ref) + if err := addUnknownParam(ref); err != nil { + return nil, err + } } continue } @@ -455,56 +761,13 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if n.Name == nil { return nil, fmt.Errorf("*ast.ResTarget has nil name") } - key := *n.Name - - var schema, rel string - // TODO: Deprecate defaultTable - if defaultTable != nil { - schema = defaultTable.Schema - rel = defaultTable.Name - } - if ref.rv != nil { - fqn, err := ParseTableName(ref.rv) - if err != nil { - return nil, err - } - schema = fqn.Schema - rel = fqn.Name - } - if schema == "" { - schema = c.DefaultSchema - } - - tableMap, ok := typeMap[schema][rel] - if !ok { - return nil, sqlerr.RelationNotFound(rel) + if err := addColumnParam(ref, *n.Name, n.Location); err != nil { + return nil, err } - if c, ok := tableMap[key]; ok { - defaultP := named.NewInferredParam(key, c.IsNotNull) - p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) - a = append(a, Parameter{ - Number: ref.ref.Number, - Column: &Column{ - Name: p.Name(), - OriginalName: c.Name, - DataType: dataType(&c.Type), - NotNull: p.NotNull(), - Unsigned: c.IsUnsigned, - IsArray: c.IsArray, - ArrayDims: c.ArrayDims, - Table: &ast.TableName{Schema: schema, Name: rel}, - Length: c.Length, - IsNamedParam: isNamed, - IsSqlcSlice: p.IsSqlcSlice(), - }, - }) - } else { - return nil, &sqlerr.Error{ - Code: "42703", - Message: fmt.Sprintf("column %q does not exist", key), - Location: n.Location, - } + case *ast.String: + if err := addColumnParam(ref, n.Str, n.Pos()); err != nil { + return nil, err } case *ast.TypeCast: @@ -517,13 +780,17 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, col.Name = p.Name() col.NotNull = p.NotNull() - a = append(a, Parameter{ + if err := addParam(ref, Parameter{ Number: ref.ref.Number, Column: col, - }) + }); err != nil { + return nil, err + } case *ast.ParamRef: - a = append(a, Parameter{Number: ref.ref.Number}) + if err := addParam(ref, Parameter{Number: ref.ref.Number}); err != nil { + return nil, err + } case *ast.In: if n == nil || n.List == nil { @@ -531,11 +798,6 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, continue } - number := 0 - if pr, ok := n.List[0].(*ast.ParamRef); ok { - number = pr.Number - } - location := 0 var key, alias string var items []string @@ -594,8 +856,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } defaultP := named.NewInferredParam(key, c.IsNotNull) p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) - a = append(a, Parameter{ - Number: number, + if err := addParam(ref, Parameter{ + Number: ref.ref.Number, Column: &Column{ Name: p.Name(), OriginalName: c.Name, @@ -608,7 +870,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, IsNamedParam: isNamed, IsSqlcSlice: p.IsSqlcSlice(), }, - }) + }); err != nil { + return nil, err + } } } } @@ -630,7 +894,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, default: slog.Debug("unsupported reference type", "type", fmt.Sprintf("%T", n)) - addUnknownParam(ref) + if err := addUnknownParam(ref); err != nil { + return nil, err + } } } return a, nil diff --git a/internal/compiler/resolve_test.go b/internal/compiler/resolve_test.go new file mode 100644 index 0000000000..b2fc72d60d --- /dev/null +++ b/internal/compiler/resolve_test.go @@ -0,0 +1,209 @@ +package compiler + +import ( + "testing" + + "github.com/sqlc-dev/sqlc/internal/engine/postgresql" + "github.com/sqlc-dev/sqlc/internal/engine/sqlite" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" + "github.com/sqlc-dev/sqlc/internal/sql/named" + "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" +) + +func TestParamTypeString(t *testing.T) { + t.Parallel() + + t.Run("postgresql type aliases", func(t *testing.T) { + t.Parallel() + comp := &Compiler{parser: postgresql.NewParser()} + + got := comp.paramTypeString(&Column{DataType: "pg_catalog.int4", ArrayDims: 2}) + if got != "integer[][]" { + t.Fatalf("expected integer[][], got %q", got) + } + }) + + t.Run("structured type metadata is preferred", func(t *testing.T) { + t.Parallel() + comp := &Compiler{parser: postgresql.NewParser()} + + got := comp.paramTypeString(&Column{ + DataType: "catalog.pg_catalog.int4", + Type: &ast.TypeName{Schema: "pg_catalog", Name: "bpchar"}, + }) + if got != "character" { + t.Fatalf("expected character, got %q", got) + } + }) + + t.Run("sqlite keeps names unchanged", func(t *testing.T) { + t.Parallel() + comp := &Compiler{parser: sqlite.NewParser()} + + got := comp.paramTypeString(&Column{DataType: "custom_type", ArrayDims: 1}) + if got != "custom_type[]" { + t.Fatalf("expected custom_type[], got %q", got) + } + }) +} + +func TestIncompatibleParamRefErrorFormatsTypeNames(t *testing.T) { + t.Parallel() + + comp := &Compiler{parser: postgresql.NewParser()} + err := comp.incompatibleParamRefError(paramRef{ref: &ast.ParamRef{Number: 1}}, Parameter{ + Number: 1, + Column: &Column{DataType: "text"}, + }, Parameter{ + Number: 1, + Column: &Column{DataType: "pg_catalog.int4"}, + }) + + sqlErr, ok := err.(*sqlerr.Error) + if !ok { + t.Fatalf("expected *sqlerr.Error, got %T", err) + } + if sqlErr.Message != "parameter $1 has incompatible types: text, integer" { + t.Fatalf("unexpected message: %q", sqlErr.Message) + } +} + +func TestMergeResolvedParamKeepsFirstNameForCompatibleTypes(t *testing.T) { + t.Parallel() + + merged := mergeResolvedParam( + Parameter{Number: 1, Column: &Column{Name: "user", DataType: "text"}}, + Parameter{Number: 1, Column: &Column{Name: "student_user", DataType: "text"}}, + ) + + if merged.Column == nil { + t.Fatal("expected merged column") + } + if merged.Column.Name != "user" { + t.Fatalf("expected first inferred name to win, got %q", merged.Column.Name) + } +} + +func TestResolvedFuncCallArgType(t *testing.T) { + t.Parallel() + + fun := &catalog.Function{Args: []*catalog.Argument{ + {Name: "lhs", Type: &ast.TypeName{Name: "int8"}}, + {Name: "rhs", Type: &ast.TypeName{Name: "text"}}, + }} + + if got := resolvedFuncCallArgType(fun, 0, ""); got == nil || got.Name != "int8" { + t.Fatalf("expected positional arg type int8, got %#v", got) + } + if got := resolvedFuncCallArgType(fun, 0, "rhs"); got == nil || got.Name != "text" { + t.Fatalf("expected named arg type text, got %#v", got) + } + if got := resolvedFuncCallArgType(fun, 2, ""); got != nil { + t.Fatalf("expected nil for out-of-range positional arg, got %#v", got) + } +} + +func TestResolveCatalogRefsInsertTargetStringInfersColumnName(t *testing.T) { + t.Parallel() + + comp := &Compiler{parser: postgresql.NewParser(), catalog: postgresql.NewCatalog()} + + var schema *catalog.Schema + for _, s := range comp.catalog.Schemas { + if s.Name == comp.catalog.DefaultSchema { + schema = s + break + } + } + if schema == nil { + t.Fatal("default schema not found") + } + + tableName := "solar_commcard_mapping" + schema.Tables = append(schema.Tables, &catalog.Table{ + Rel: &ast.TableName{Schema: schema.Name, Name: tableName}, + Columns: []*catalog.Column{&catalog.Column{ + Name: "deviceId", + Type: ast.TypeName{Schema: "pg_catalog", Name: "int8"}, + IsNotNull: true, + }}, + }) + + rv := &ast.RangeVar{Relname: &tableName} + params, err := comp.resolveCatalogRefs(nil, []*ast.RangeVar{rv}, []paramRef{{ + parent: &ast.String{Str: "deviceId"}, + rv: rv, + ref: &ast.ParamRef{Number: 1}, + }}, named.NewParamSet(nil, true), nil) + if err != nil { + t.Fatalf("resolveCatalogRefs returned error: %v", err) + } + if len(params) != 1 { + t.Fatalf("expected 1 param, got %d", len(params)) + } + if params[0].Column == nil { + t.Fatal("expected resolved column metadata") + } + if params[0].Column.Name != "deviceId" { + t.Fatalf("expected inferred name deviceId, got %q", params[0].Column.Name) + } + if params[0].Column.OriginalName != "deviceId" { + t.Fatalf("expected original name deviceId, got %q", params[0].Column.OriginalName) + } + if params[0].Column.DataType != "pg_catalog.int8" { + t.Fatalf("expected data type pg_catalog.int8, got %q", params[0].Column.DataType) + } + if params[0].Column.Table == nil || params[0].Column.Table.Name != tableName { + t.Fatalf("expected table %q, got %#v", tableName, params[0].Column.Table) + } +} + +func TestResolveCatalogRefsAExprUsesScopedRangeVar(t *testing.T) { + t.Parallel() + + comp := &Compiler{parser: postgresql.NewParser(), catalog: postgresql.NewCatalog()} + + var schema *catalog.Schema + for _, s := range comp.catalog.Schemas { + if s.Name == comp.catalog.DefaultSchema { + schema = s + break + } + } + if schema == nil { + t.Fatal("default schema not found") + } + + tableName := "solar_commcard_mapping" + schema.Tables = append(schema.Tables, &catalog.Table{ + Rel: &ast.TableName{Schema: schema.Name, Name: tableName}, + Columns: []*catalog.Column{{ + Name: "deviceId", + Type: ast.TypeName{Schema: "pg_catalog", Name: "int8"}, + IsNotNull: true, + }}, + }) + + rv := &ast.RangeVar{Relname: &tableName} + params, err := comp.resolveCatalogRefs(nil, []*ast.RangeVar{rv, rv}, []paramRef{{ + parent: &ast.A_Expr{ + Lexpr: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "deviceId"}}}}, + Rexpr: &ast.ParamRef{Number: 1}, + }, + rv: rv, + ref: &ast.ParamRef{Number: 1}, + }}, named.NewParamSet(nil, true), nil) + if err != nil { + t.Fatalf("resolveCatalogRefs returned error: %v", err) + } + if len(params) != 1 { + t.Fatalf("expected 1 param, got %d", len(params)) + } + if params[0].Column == nil { + t.Fatal("expected resolved column metadata") + } + if params[0].Column.Name != "deviceId" { + t.Fatalf("expected inferred name deviceId, got %q", params[0].Column.Name) + } +} diff --git a/internal/endtoend/testdata/insert_select_param/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/insert_select_param/postgresql/pgx/go/query.sql.go index 999d9b755d..2b1cad2faa 100644 --- a/internal/endtoend/testdata/insert_select_param/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/insert_select_param/postgresql/pgx/go/query.sql.go @@ -7,8 +7,6 @@ package querytest import ( "context" - - "github.com/jackc/pgx/v5/pgtype" ) const insertSelect = `-- name: InsertSelect :exec @@ -19,8 +17,8 @@ WHERE name = $2 ` type InsertSelectParams struct { - ID pgtype.Int8 - Name pgtype.Text + ID int64 + Name string } func (q *Queries) InsertSelect(ctx context.Context, arg InsertSelectParams) error { diff --git a/internal/endtoend/testdata/invalid_reused_param_type/postgresql/query.sql b/internal/endtoend/testdata/invalid_reused_param_type/postgresql/query.sql new file mode 100644 index 0000000000..f11799c673 --- /dev/null +++ b/internal/endtoend/testdata/invalid_reused_param_type/postgresql/query.sql @@ -0,0 +1,7 @@ +-- name: CreateAuthor :one +INSERT INTO authors ( + name, bio, age +) VALUES ( + $1, $1, $1 +) +RETURNING *; \ No newline at end of file diff --git a/internal/endtoend/testdata/invalid_reused_param_type/postgresql/schema.sql b/internal/endtoend/testdata/invalid_reused_param_type/postgresql/schema.sql new file mode 100644 index 0000000000..f6cc820e4c --- /dev/null +++ b/internal/endtoend/testdata/invalid_reused_param_type/postgresql/schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE authors ( + id BIGSERIAL PRIMARY KEY, + name text NOT NULL, + bio text, + age INT +); \ No newline at end of file diff --git a/internal/endtoend/testdata/invalid_reused_param_type/postgresql/sqlc.json b/internal/endtoend/testdata/invalid_reused_param_type/postgresql/sqlc.json new file mode 100644 index 0000000000..beb25b20e7 --- /dev/null +++ b/internal/endtoend/testdata/invalid_reused_param_type/postgresql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "name": "querytest", + "path": "go", + "schema": "schema.sql", + "queries": "query.sql", + "engine": "postgresql" + } + ] +} \ No newline at end of file diff --git a/internal/endtoend/testdata/invalid_reused_param_type/postgresql/stderr/base.txt b/internal/endtoend/testdata/invalid_reused_param_type/postgresql/stderr/base.txt new file mode 100644 index 0000000000..60cbd3bd4f --- /dev/null +++ b/internal/endtoend/testdata/invalid_reused_param_type/postgresql/stderr/base.txt @@ -0,0 +1,2 @@ +# package querytest +query.sql:5:10: parameter $1 has incompatible types: text, integer \ No newline at end of file diff --git a/internal/endtoend/testdata/invalid_reused_param_type/postgresql/stderr/managed-db.txt b/internal/endtoend/testdata/invalid_reused_param_type/postgresql/stderr/managed-db.txt new file mode 100644 index 0000000000..28b7d3ea9f --- /dev/null +++ b/internal/endtoend/testdata/invalid_reused_param_type/postgresql/stderr/managed-db.txt @@ -0,0 +1,2 @@ +# package querytest +query.sql:5:10: inconsistent types deduced for parameter $1 \ No newline at end of file diff --git a/internal/endtoend/testdata/nested_select/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/nested_select/postgresql/pgx/go/query.sql.go index 06698dbad4..daf563abe7 100644 --- a/internal/endtoend/testdata/nested_select/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/nested_select/postgresql/pgx/go/query.sql.go @@ -7,8 +7,6 @@ package querytest import ( "context" - - "github.com/jackc/pgx/v5/pgtype" ) const nestedSelect = `-- name: NestedSelect :one @@ -26,7 +24,7 @@ INNER JOIN test t USING (id, update_time) type NestedSelectParams struct { IDs []int64 - StartTime pgtype.Int8 + StartTime int64 } type NestedSelectRow struct { diff --git a/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/go/query.sql.go b/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/go/query.sql.go index a04cba9fe0..a6dce2d40a 100644 --- a/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/go/query.sql.go +++ b/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/go/query.sql.go @@ -7,6 +7,7 @@ package querytest import ( "context" + "database/sql" ) const findByID = `-- name: FindByID :many @@ -37,11 +38,16 @@ func (q *Queries) FindByID(ctx context.Context, id int32) ([]User, error) { } const findByIDAndName = `-- name: FindByIDAndName :many -SELECT id, name FROM users WHERE $1 = id AND $1 = name +SELECT id, name FROM users WHERE $1 = id AND $2 = name ` -func (q *Queries) FindByIDAndName(ctx context.Context, id int32) ([]User, error) { - rows, err := q.db.QueryContext(ctx, findByIDAndName, id) +type FindByIDAndNameParams struct { + ID int32 + Name sql.NullString +} + +func (q *Queries) FindByIDAndName(ctx context.Context, arg FindByIDAndNameParams) ([]User, error) { + rows, err := q.db.QueryContext(ctx, findByIDAndName, arg.ID, arg.Name) if err != nil { return nil, err } diff --git a/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/query.sql b/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/query.sql index 50b95ea32e..807c84b4f8 100644 --- a/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/query.sql +++ b/internal/endtoend/testdata/params_placeholder_in_left_expr/postgresql/query.sql @@ -2,4 +2,4 @@ SELECT * FROM users WHERE $1 = id; -- name: FindByIDAndName :many -SELECT * FROM users WHERE $1 = id AND $1 = name; +SELECT * FROM users WHERE $1 = id AND $2 = name; diff --git a/internal/endtoend/testdata/select_subquery/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/select_subquery/postgresql/stdlib/go/query.sql.go index 1e25e8a86a..7bcf9381d3 100644 --- a/internal/endtoend/testdata/select_subquery/postgresql/stdlib/go/query.sql.go +++ b/internal/endtoend/testdata/select_subquery/postgresql/stdlib/go/query.sql.go @@ -19,8 +19,8 @@ FROM FOO WHERE a = $2 ` type SubqueryParams struct { - Column1 sql.NullString - Column2 sql.NullInt32 + Alias sql.NullString + A int32 } type SubqueryRow struct { @@ -30,7 +30,7 @@ type SubqueryRow struct { } func (q *Queries) Subquery(ctx context.Context, arg SubqueryParams) ([]SubqueryRow, error) { - rows, err := q.db.QueryContext(ctx, subquery, arg.Column1, arg.Column2) + rows, err := q.db.QueryContext(ctx, subquery, arg.Alias, arg.A) if err != nil { return nil, err } diff --git a/internal/endtoend/testdata/unnest_star/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/unnest_star/postgresql/pgx/go/query.sql.go index 9cfcf1b136..b72c8b7c96 100644 --- a/internal/endtoend/testdata/unnest_star/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/unnest_star/postgresql/pgx/go/query.sql.go @@ -7,8 +7,6 @@ package querytest import ( "context" - - "github.com/jackc/pgx/v5/pgtype" ) const getPlanItems = `-- name: GetPlanItems :many @@ -27,7 +25,7 @@ LATERAL ( type GetPlanItemsParams struct { Ids []int64 - After pgtype.Int4 + After int32 LimitCount int64 }