diff --git a/pkg/cmd/smith/BUILD.bazel b/pkg/cmd/smith/BUILD.bazel index d52355c70bbf..735c88df12b2 100644 --- a/pkg/cmd/smith/BUILD.bazel +++ b/pkg/cmd/smith/BUILD.bazel @@ -7,7 +7,15 @@ go_library( visibility = ["//visibility:private"], deps = [ "//pkg/internal/sqlsmith", + "//pkg/keys", + "//pkg/settings/cluster", + "//pkg/sql/catalog/bootstrap", + "//pkg/sql/catalog/descpb", + "//pkg/sql/importer", + "//pkg/sql/parser", + "//pkg/sql/sem/tree", "//pkg/util/randutil", + "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", ], ) diff --git a/pkg/cmd/smith/main.go b/pkg/cmd/smith/main.go index be0f176846a4..cce6d53a5847 100644 --- a/pkg/cmd/smith/main.go +++ b/pkg/cmd/smith/main.go @@ -11,14 +11,25 @@ package main import ( + "context" gosql "database/sql" "flag" "fmt" + "io" "os" "sort" + "strings" "github.com/cockroachdb/cockroach/pkg/internal/sqlsmith" + "github.com/cockroachdb/cockroach/pkg/keys" + "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/bootstrap" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" + "github.com/cockroachdb/cockroach/pkg/sql/importer" + "github.com/cockroachdb/cockroach/pkg/sql/parser" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/util/randutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" ) @@ -43,45 +54,54 @@ Options: ` var ( - flags = flag.NewFlagSet(os.Args[0], flag.ContinueOnError) - expr = flags.Bool("expr", false, "generate expressions instead of statements") - num = flags.Int("num", 1, "number of statements / expressions to generate") - url = flags.String("url", "", "database to fetch schema from") - execStmts = flags.Bool("exec-stmts", false, "execute each generated statement against the db specified by url") + flags = flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + expr = flags.Bool("expr", false, "generate expressions instead of statements") + num = flags.Int("num", 1, "number of statements / expressions to generate") + url = flags.String("url", "", "database to fetch schema from") + execStmts = flags.Bool("exec-stmts", false, "execute each generated statement against the db specified by url") + schemaPath = flags.String("schema", "", "path containing schema definitions") + prefix = flags.String("prefix", "", "prefix each statement or expression") + smitherOptMap = map[string]sqlsmith.SmitherOption{ - "AvoidConsts": sqlsmith.AvoidConsts(), - "CompareMode": sqlsmith.CompareMode(), - "DisableAggregateFuncs": sqlsmith.DisableAggregateFuncs(), - "DisableCRDBFns": sqlsmith.DisableCRDBFns(), - "DisableCrossJoins": sqlsmith.DisableCrossJoins(), - "DisableDDLs": sqlsmith.DisableDDLs(), - "DisableDecimals": sqlsmith.DisableDecimals(), - "DisableDivision": sqlsmith.DisableDivision(), - "DisableEverything": sqlsmith.DisableEverything(), - "DisableIndexHints": sqlsmith.DisableIndexHints(), - "DisableInsertSelect": sqlsmith.DisableInsertSelect(), - "DisableJoins": sqlsmith.DisableJoins(), - "DisableLimits": sqlsmith.DisableLimits(), - "DisableMutations": sqlsmith.DisableMutations(), - "DisableNondeterministicFns": sqlsmith.DisableNondeterministicFns(), - "DisableNondeterministicLimits": sqlsmith.DisableNondeterministicLimits(), - "DisableWindowFuncs": sqlsmith.DisableWindowFuncs(), - "DisableWith": sqlsmith.DisableWith(), - "EnableAlters": sqlsmith.EnableAlters(), - "FavorCommonData": sqlsmith.FavorCommonData(), - "InsUpdOnly": sqlsmith.InsUpdOnly(), + "AvoidConsts": sqlsmith.AvoidConsts(), + "CompareMode": sqlsmith.CompareMode(), + "DisableAggregateFuncs": sqlsmith.DisableAggregateFuncs(), + "DisableCRDBFns": sqlsmith.DisableCRDBFns(), + "DisableCrossJoins": sqlsmith.DisableCrossJoins(), + "DisableDDLs": sqlsmith.DisableDDLs(), + "DisableDecimals": sqlsmith.DisableDecimals(), + "DisableDivision": sqlsmith.DisableDivision(), + "DisableEverything": sqlsmith.DisableEverything(), + "DisableIndexHints": sqlsmith.DisableIndexHints(), + "DisableInsertSelect": sqlsmith.DisableInsertSelect(), + "DisableJoins": sqlsmith.DisableJoins(), + "DisableLimits": sqlsmith.DisableLimits(), + "DisableMutations": sqlsmith.DisableMutations(), + "DisableNondeterministicFns": sqlsmith.DisableNondeterministicFns(), + "DisableWindowFuncs": sqlsmith.DisableWindowFuncs(), + "DisableWith": sqlsmith.DisableWith(), + "DisableUDFs": sqlsmith.DisableUDFs(), + "EnableAlters": sqlsmith.EnableAlters(), + "EnableLimits": sqlsmith.EnableLimits(), + "EnableWith": sqlsmith.EnableWith(), + "FavorCommonData": sqlsmith.FavorCommonData(), + "IgnoreFNs": strArgOpt(sqlsmith.IgnoreFNs), + "InsUpdOnly": sqlsmith.InsUpdOnly(), + "MaybeSortOutput": sqlsmith.MaybeSortOutput(), + "MultiRegionDDLs": sqlsmith.MultiRegionDDLs(), + "MutatingMode": sqlsmith.MutatingMode(), + "MutationsOnly": sqlsmith.MutationsOnly(), + "OnlyNoDropDDLs": sqlsmith.OnlyNoDropDDLs(), + "OnlySingleDMLs": sqlsmith.OnlySingleDMLs(), + "OutputSort": sqlsmith.OutputSort(), + "PostgresMode": sqlsmith.PostgresMode(), + "SimpleDatums": sqlsmith.SimpleDatums(), + "SimpleScalarTypes": sqlsmith.SimpleScalarTypes(), + "SimpleNames": sqlsmith.SimpleNames(), + "UnlikelyConstantPredicate": sqlsmith.UnlikelyConstantPredicate(), + "UnlikelyRandomNulls": sqlsmith.UnlikelyRandomNulls(), + "LowProbabilityWhereClauseWithJoinTables": sqlsmith.LowProbabilityWhereClauseWithJoinTables(), - "MultiRegionDDLs": sqlsmith.MultiRegionDDLs(), - "MutatingMode": sqlsmith.MutatingMode(), - "MutationsOnly": sqlsmith.MutationsOnly(), - "OnlyNoDropDDLs": sqlsmith.OnlyNoDropDDLs(), - "OnlySingleDMLs": sqlsmith.OnlySingleDMLs(), - "OutputSort": sqlsmith.OutputSort(), - "PostgresMode": sqlsmith.PostgresMode(), - "SimpleDatums": sqlsmith.SimpleDatums(), - "SimpleNames": sqlsmith.SimpleNames(), - "UnlikelyConstantPredicate": sqlsmith.UnlikelyConstantPredicate(), - "UnlikelyRandomNulls": sqlsmith.UnlikelyRandomNulls(), } smitherOpts []string ) @@ -119,9 +139,14 @@ func main() { // Gather our sqlsmith options from command-line arguments. var smitherOpts []sqlsmith.SmitherOption for _, arg := range flags.Args() { - if opt, ok := smitherOptMap[arg]; ok { + argKV := strings.SplitN(arg, "=", 2) + if opt, ok := smitherOptMap[argKV[0]]; ok { fmt.Print("-- ", arg, ": ", opt, "\n") - smitherOpts = append(smitherOpts, opt) + if len(argKV) == 2 { + smitherOpts = append(smitherOpts, opt.(strArgOpt)(argKV[1])) + } else { + smitherOpts = append(smitherOpts, opt) + } } else { fmt.Fprintf(flags.Output(), "unrecognized sqlsmith-go option: %v\n", arg) usage() @@ -146,6 +171,15 @@ func main() { fmt.Println("-- connected to", *url) } + if *schemaPath != "" { + opts, err := parseSchemaDefinition(*schemaPath) + if err != nil { + fmt.Fprintf(flags.Output(), "could not parse schema file %s: %s", *schemaPath, err) + os.Exit(2) + } + smitherOpts = append(smitherOpts, opts...) + } + // Create our smither. smither, err := sqlsmith.NewSmither(db, rng, smitherOpts...) if err != nil { @@ -156,10 +190,15 @@ func main() { // Finally, generate num statements (or expressions). fmt.Println("-- num", *num) + sep := "\n" + if *prefix != "" { + sep = fmt.Sprintf("\n%s\n", *prefix) + } + if *expr { fmt.Println("-- expr") for i := 0; i < *num; i++ { - fmt.Print("\n", smither.GenerateExpr(), "\n") + fmt.Print(sep, smither.GenerateExpr(), "\n") } } else { for i := 0; i < *num; i++ { @@ -168,6 +207,46 @@ func main() { if db != nil && *execStmts { _, _ = db.Exec(stmt) } + fmt.Print(sep, smither.Generate(), ";\n") } } } + +func parseSchemaDefinition(schemaPath string) (opts []sqlsmith.SmitherOption, _ error) { + f, err := os.Open(schemaPath) + if err != nil { + return nil, errors.Wrapf(err, "could not open schema file %s for reading", schemaPath) + } + schema, err := io.ReadAll(f) + if err != nil { + return nil, errors.Wrap(err, "failed to read schema definition data") + } + stmts, err := parser.Parse(string(schema)) + if err != nil { + return nil, errors.Wrap(err, "Could not parse schema definition") + } + semaCtx := tree.MakeSemaContext(nil /* resolver */) + st := cluster.MakeTestingClusterSettings() + wall := timeutil.Now().UnixNano() + parentID := descpb.ID(bootstrap.TestingUserDescID(0)) + for i, s := range stmts { + switch t := s.AST.(type) { + default: + return nil, errors.AssertionFailedf("only CreateTable statements supported, found %T", t) + case *tree.CreateTable: + tableID := descpb.ID(int(parentID) + i + 1) + desc, err := importer.MakeTestingSimpleTableDescriptor( + context.Background(), &semaCtx, st, t, parentID, keys.PublicSchemaID, tableID, importer.NoFKs, wall) + if err != nil { + return nil, errors.Wrapf(err, "failed to create table descriptor for statement %s", t) + } + opts = append(opts, sqlsmith.WithTableDescriptor(t.Table, desc.TableDescriptor)) + } + } + return opts, nil +} + +type strArgOpt func(v string) sqlsmith.SmitherOption + +func (o strArgOpt) Apply(s *sqlsmith.Smither) {} +func (o strArgOpt) String() string { return "" } diff --git a/pkg/internal/sqlsmith/BUILD.bazel b/pkg/internal/sqlsmith/BUILD.bazel index f4b79f610fd1..f6538ff59ebf 100644 --- a/pkg/internal/sqlsmith/BUILD.bazel +++ b/pkg/internal/sqlsmith/BUILD.bazel @@ -42,6 +42,7 @@ go_library( "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", "@com_github_lib_pq//oid", + "@org_golang_x_exp//slices", ], ) diff --git a/pkg/internal/sqlsmith/scalar.go b/pkg/internal/sqlsmith/scalar.go index a48df587bfc8..92c9187d40eb 100644 --- a/pkg/internal/sqlsmith/scalar.go +++ b/pkg/internal/sqlsmith/scalar.go @@ -336,8 +336,12 @@ func makeBinOp(s *Smither, typ *types.T, refs colRefs) (tree.TypedExpr, bool) { if len(ops) == 0 { return nil, false } - n := s.rnd.Intn(len(ops)) - op := ops[n] + op := ops[s.rnd.Intn(len(ops))] + for s.simpleScalarTypes && !(isSimpleSeedType(op.LeftType) && isSimpleSeedType(op.RightType)) { + // We must work harder to pick some other op. + op = ops[s.rnd.Intn(len(ops))] + } + if s.postgres { if ignorePostgresBinOps[binOpTriple{ op.LeftType.Family(), @@ -440,6 +444,12 @@ func makeFunc(s *Smither, ctx Context, typ *types.T, refs colRefs) (tree.TypedEx args := make(tree.TypedExprs, 0) for _, argTyp := range fn.overload.Types.Types() { + // Skip this function if we want simple scalar types, but this + // function argument is not. + if s.simpleScalarTypes && !isSimpleSeedType(argTyp) { + return nil, false + } + // Postgres is picky about having Int4 arguments instead of Int8. if s.postgres && argTyp.Family() == types.IntFamily { argTyp = types.Int4 diff --git a/pkg/internal/sqlsmith/schema.go b/pkg/internal/sqlsmith/schema.go index 49679ee96a26..bfab80d0498b 100644 --- a/pkg/internal/sqlsmith/schema.go +++ b/pkg/internal/sqlsmith/schema.go @@ -27,6 +27,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/lib/pq/oid" + "golang.org/x/exp/slices" ) // tableRef represents a table and its columns. @@ -52,6 +53,48 @@ type aliasedTableRef struct { type tableRefs []*tableRef +func WithTableDescriptor(tn tree.TableName, desc descpb.TableDescriptor) SmitherOption { + return option{ + name: fmt.Sprintf("inject table %s", tn.FQString()), + apply: func(s *Smither) { + if tn.SchemaName != "" { + if !slices.ContainsFunc(s.schemas, func(ref *schemaRef) bool { + return ref.SchemaName == tn.SchemaName + }) { + s.schemas = append(s.schemas, &schemaRef{SchemaName: tn.SchemaName}) + } + } + + var cols []*tree.ColumnTableDef + for _, col := range desc.Columns { + column := tree.ColumnTableDef{ + Name: tree.Name(col.Name), + Type: col.Type, + } + if col.Nullable { + column.Nullable.Nullability = tree.Null + } + if col.IsComputed() { + column.Computed.Computed = true + } + cols = append(cols, &column) + } + + s.tables = append(s.tables, &tableRef{ + TableName: &tn, + Columns: cols, + }) + if s.columns == nil { + s.columns = make(map[tree.TableName]map[tree.Name]*tree.ColumnTableDef) + } + s.columns[tn] = make(map[tree.Name]*tree.ColumnTableDef) + for _, col := range cols { + s.columns[tn][col.Name] = col + } + }, + } +} + // ReloadSchemas loads tables from the database. func (s *Smither) ReloadSchemas() error { if s.db == nil { @@ -135,9 +178,13 @@ func (s *Smither) getRandTable() (*aliasedTableRef, bool) { return nil, false } table := s.tables[s.rnd.Intn(len(s.tables))] - var indexFlags tree.IndexFlags + aliased := &aliasedTableRef{ + tableRef: table, + } + if !s.disableIndexHints && s.coin() { indexes := s.getAllIndexesForTableRLocked(*table.TableName) + var indexFlags tree.IndexFlags indexNames := make([]tree.Name, 0, len(indexes)) for _, index := range indexes { if !index.Inverted { @@ -147,10 +194,7 @@ func (s *Smither) getRandTable() (*aliasedTableRef, bool) { if len(indexNames) > 0 { indexFlags.Index = tree.UnrestrictedName(indexNames[s.rnd.Intn(len(indexNames))]) } - } - aliased := &aliasedTableRef{ - tableRef: table, - indexFlags: &indexFlags, + aliased.indexFlags = &indexFlags } return aliased, true } diff --git a/pkg/internal/sqlsmith/sqlsmith.go b/pkg/internal/sqlsmith/sqlsmith.go index 9ae36302ad99..b60c64fb66a6 100644 --- a/pkg/internal/sqlsmith/sqlsmith.go +++ b/pkg/internal/sqlsmith/sqlsmith.go @@ -107,6 +107,7 @@ type Smither struct { ignoreFNs []*regexp.Regexp complexity float64 scalarComplexity float64 + simpleScalarTypes bool unlikelyConstantPredicate bool favorCommonData bool unlikelyRandomNulls bool @@ -210,10 +211,17 @@ func (s *Smither) Generate() string { continue } i = 0 - p, err := prettyCfg.Pretty(stmt) + + printCfg := prettyCfg + fl := tree.FmtParsable + if s.postgres { + printCfg.FmtFlags = tree.FmtPGCatalog + fl = tree.FmtPGCatalog + } + p, err := printCfg.Pretty(stmt) if err != nil { // Use simple printing if pretty-printing fails. - p = tree.AsStringWithFlags(stmt, tree.FmtParsable) + p = tree.AsStringWithFlags(stmt, fl) } return p } @@ -397,6 +405,11 @@ var DisableWith = simpleOption("disable WITH", func(s *Smither) { s.disableWith = true }) +// EnableWith causes the Smither to probabilistically emit WITH clauses. +var EnableWith = simpleOption("enable WITH", func(s *Smither) { + s.disableWith = false +}) + // DisableNondeterministicFns causes the Smither to disable nondeterministic functions. var DisableNondeterministicFns = simpleOption("disable nondeterministic funcs", func(s *Smither) { s.disableNondeterministicFns = true @@ -412,6 +425,11 @@ var SimpleDatums = simpleOption("simple datums", func(s *Smither) { s.simpleDatums = true }) +// SimpleScalarTypes causes the Smither to use simpler scalar types (e.g. avoid Geometry) +var SimpleScalarTypes = simpleOption("simple scalar types", func(s *Smither) { + s.simpleScalarTypes = true +}) + // SimpleNames specifies that complex name generation should be disabled. var SimpleNames = simpleOption("simple names", func(s *Smither) { s.simpleNames = true @@ -457,6 +475,11 @@ var DisableNondeterministicLimits = simpleOption("disable non-deterministic LIMI s.disableNondeterministicLimits = true }) +// EnableLimits causes the Smither to probabilistically emit LIMIT clauses. +var EnableLimits = simpleOption("enable LIMIT", func(s *Smither) { + s.disableLimits = false +}) + // AvoidConsts causes the Smither to prefer column references over generating // constants. var AvoidConsts = simpleOption("avoid consts", func(s *Smither) { @@ -478,6 +501,11 @@ var OutputSort = simpleOption("output sort", func(s *Smither) { s.outputSort = true }) +// MaybeSortOutput probabilistically adds ORDER by clause +var MaybeSortOutput = simpleOption("maybe output sort", func(s *Smither) { + s.outputSort = false +}) + // UnlikelyConstantPredicate causes the Smither to make generation of constant // WHERE clause, ON clause or HAVING clause predicates which only contain // constant boolean expressions such as `TRUE` or `FALSE OR TRUE` much less diff --git a/pkg/internal/sqlsmith/type.go b/pkg/internal/sqlsmith/type.go index 763b7bf98265..41edf298ba66 100644 --- a/pkg/internal/sqlsmith/type.go +++ b/pkg/internal/sqlsmith/type.go @@ -55,10 +55,45 @@ func (s *Smither) pickAnyType(typ *types.T) *types.T { return typ } +var simpleScalarTypes = func() (typs []*types.T) { + for _, t := range types.Scalar { + switch t { + case types.Box2D, types.Geography, types.Geometry, types.INet, types.PGLSN, + types.RefCursor, types.TSQuery, types.TSVector: + // Skip fancy types. + default: + typs = append(typs, t) + } + } + return typs +}() + +func isSimpleSeedType(typ *types.T) bool { + switch typ.Family() { + case types.BoolFamily, types.IntFamily, types.DecimalFamily, types.FloatFamily, types.StringFamily, + types.BytesFamily, types.DateFamily, types.TimestampFamily, types.IntervalFamily, types.TimeFamily, types.TimeTZFamily: + return true + case types.ArrayFamily: + return isSimpleSeedType(typ.ArrayContents()) + case types.TupleFamily: + for _, t := range typ.TupleContents() { + if !isSimpleSeedType(t) { + return false + } + } + return true + default: + return false + } +} + func (s *Smither) randScalarType() *types.T { s.lock.RLock() defer s.lock.RUnlock() scalarTypes := types.Scalar + if s.simpleScalarTypes { + scalarTypes = simpleScalarTypes + } if s.types != nil { scalarTypes = s.types.scalarTypes } @@ -80,6 +115,9 @@ func (s *Smither) randScalarType() *types.T { func (s *Smither) isScalarType(t *types.T) bool { s.lock.AssertRHeld() scalarTypes := types.Scalar + if s.simpleScalarTypes { + scalarTypes = simpleScalarTypes + } if s.types != nil { scalarTypes = s.types.scalarTypes } @@ -110,6 +148,9 @@ func (s *Smither) randType() *types.T { // which compare CRDB behavior to Postgres. continue } + if s.simpleScalarTypes && !isSimpleSeedType(typ) { + continue + } break } return typ diff --git a/pkg/sql/sem/tree/pretty.go b/pkg/sql/sem/tree/pretty.go index 7d1586b8c1c7..3bf10774a91e 100644 --- a/pkg/sql/sem/tree/pretty.go +++ b/pkg/sql/sem/tree/pretty.go @@ -62,6 +62,8 @@ type PrettyCfg struct { JSONFmt bool // ValueRedaction, when set, surrounds literal values with redaction markers. ValueRedaction bool + // FmtFlags specifies FmtFlags to use when formatting expressions. + FmtFlags FmtFlags } // DefaultPrettyCfg returns a PrettyCfg with the default @@ -184,6 +186,10 @@ func (p *PrettyCfg) docAsString(f NodeFormatter) pretty.Doc { } func (p *PrettyCfg) fmtFlags() FmtFlags { + if p.FmtFlags != FmtFlags(0) { + return p.FmtFlags + } + prettyFlags := FmtShowPasswords | FmtParsable | FmtTagDollarQuotes if p.ValueRedaction { prettyFlags |= FmtMarkRedactionNode | FmtOmitNameRedaction