Skip to content

Commit

Permalink
sqlsmith: Various usability and bug fixes.
Browse files Browse the repository at this point in the history
This PR adds various command line flags and
sqlsmith options to improve usability and correctness:
 * Add `--schema` command line argument to sqlsmith
   so that the running database instance is not required
   in order to populate existing table information.
 * Add `--prefix` command line argument which will add
   prefix to every statement/expression generated by sqlsmith
   (this can be used for example to generalte sqllogic format)
 * Add a new smither option `SimpleScalarTypes` which eschews
   "complex types" -- such as GEOMETRY, GEOGRAPHY, and other less
   common types.
 * Fix a bug where `DisableIndexHints` option was not
   respected (all tables had index flags).
 * Fix a bug where pretty printing (query or expression) did
   not respect `PostgresMode`, thus producing queries that are
   not compatible w/ postgres (e.g. they always included type
   annotations).

Epic: None
Release note: None
  • Loading branch information
Yevgeniy Miretskiy authored and michae2 committed May 2, 2024
1 parent 592e709 commit 19d693f
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 49 deletions.
8 changes: 8 additions & 0 deletions pkg/cmd/smith/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@ go_library(
visibility = ["//visibility:private"],
deps = [
"//pkg/internal/sqlsmith",
"//pkg/keys",
"//pkg/settings/cluster",
"//pkg/sql/catalog/bootstrap",
"//pkg/sql/catalog/descpb",
"//pkg/sql/importer",
"//pkg/sql/parser",
"//pkg/sql/sem/tree",
"//pkg/util/randutil",
"//pkg/util/timeutil",
"@com_github_cockroachdb_errors//:errors",
],
)
Expand Down
159 changes: 119 additions & 40 deletions pkg/cmd/smith/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,25 @@
package main

import (
"context"
gosql "database/sql"
"flag"
"fmt"
"io"
"os"
"sort"
"strings"

"github.com/cockroachdb/cockroach/pkg/internal/sqlsmith"
"github.com/cockroachdb/cockroach/pkg/keys"
"github.com/cockroachdb/cockroach/pkg/settings/cluster"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/bootstrap"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb"
"github.com/cockroachdb/cockroach/pkg/sql/importer"
"github.com/cockroachdb/cockroach/pkg/sql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/util/randutil"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/cockroachdb/errors"
)

Expand All @@ -43,45 +54,54 @@ Options:
`

var (
flags = flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
expr = flags.Bool("expr", false, "generate expressions instead of statements")
num = flags.Int("num", 1, "number of statements / expressions to generate")
url = flags.String("url", "", "database to fetch schema from")
execStmts = flags.Bool("exec-stmts", false, "execute each generated statement against the db specified by url")
flags = flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
expr = flags.Bool("expr", false, "generate expressions instead of statements")
num = flags.Int("num", 1, "number of statements / expressions to generate")
url = flags.String("url", "", "database to fetch schema from")
execStmts = flags.Bool("exec-stmts", false, "execute each generated statement against the db specified by url")
schemaPath = flags.String("schema", "", "path containing schema definitions")
prefix = flags.String("prefix", "", "prefix each statement or expression")

smitherOptMap = map[string]sqlsmith.SmitherOption{
"AvoidConsts": sqlsmith.AvoidConsts(),
"CompareMode": sqlsmith.CompareMode(),
"DisableAggregateFuncs": sqlsmith.DisableAggregateFuncs(),
"DisableCRDBFns": sqlsmith.DisableCRDBFns(),
"DisableCrossJoins": sqlsmith.DisableCrossJoins(),
"DisableDDLs": sqlsmith.DisableDDLs(),
"DisableDecimals": sqlsmith.DisableDecimals(),
"DisableDivision": sqlsmith.DisableDivision(),
"DisableEverything": sqlsmith.DisableEverything(),
"DisableIndexHints": sqlsmith.DisableIndexHints(),
"DisableInsertSelect": sqlsmith.DisableInsertSelect(),
"DisableJoins": sqlsmith.DisableJoins(),
"DisableLimits": sqlsmith.DisableLimits(),
"DisableMutations": sqlsmith.DisableMutations(),
"DisableNondeterministicFns": sqlsmith.DisableNondeterministicFns(),
"DisableNondeterministicLimits": sqlsmith.DisableNondeterministicLimits(),
"DisableWindowFuncs": sqlsmith.DisableWindowFuncs(),
"DisableWith": sqlsmith.DisableWith(),
"EnableAlters": sqlsmith.EnableAlters(),
"FavorCommonData": sqlsmith.FavorCommonData(),
"InsUpdOnly": sqlsmith.InsUpdOnly(),
"AvoidConsts": sqlsmith.AvoidConsts(),
"CompareMode": sqlsmith.CompareMode(),
"DisableAggregateFuncs": sqlsmith.DisableAggregateFuncs(),
"DisableCRDBFns": sqlsmith.DisableCRDBFns(),
"DisableCrossJoins": sqlsmith.DisableCrossJoins(),
"DisableDDLs": sqlsmith.DisableDDLs(),
"DisableDecimals": sqlsmith.DisableDecimals(),
"DisableDivision": sqlsmith.DisableDivision(),
"DisableEverything": sqlsmith.DisableEverything(),
"DisableIndexHints": sqlsmith.DisableIndexHints(),
"DisableInsertSelect": sqlsmith.DisableInsertSelect(),
"DisableJoins": sqlsmith.DisableJoins(),
"DisableLimits": sqlsmith.DisableLimits(),
"DisableMutations": sqlsmith.DisableMutations(),
"DisableNondeterministicFns": sqlsmith.DisableNondeterministicFns(),
"DisableWindowFuncs": sqlsmith.DisableWindowFuncs(),
"DisableWith": sqlsmith.DisableWith(),
"DisableUDFs": sqlsmith.DisableUDFs(),
"EnableAlters": sqlsmith.EnableAlters(),
"EnableLimits": sqlsmith.EnableLimits(),
"EnableWith": sqlsmith.EnableWith(),
"FavorCommonData": sqlsmith.FavorCommonData(),
"IgnoreFNs": strArgOpt(sqlsmith.IgnoreFNs),
"InsUpdOnly": sqlsmith.InsUpdOnly(),
"MaybeSortOutput": sqlsmith.MaybeSortOutput(),
"MultiRegionDDLs": sqlsmith.MultiRegionDDLs(),
"MutatingMode": sqlsmith.MutatingMode(),
"MutationsOnly": sqlsmith.MutationsOnly(),
"OnlyNoDropDDLs": sqlsmith.OnlyNoDropDDLs(),
"OnlySingleDMLs": sqlsmith.OnlySingleDMLs(),
"OutputSort": sqlsmith.OutputSort(),
"PostgresMode": sqlsmith.PostgresMode(),
"SimpleDatums": sqlsmith.SimpleDatums(),
"SimpleScalarTypes": sqlsmith.SimpleScalarTypes(),
"SimpleNames": sqlsmith.SimpleNames(),
"UnlikelyConstantPredicate": sqlsmith.UnlikelyConstantPredicate(),
"UnlikelyRandomNulls": sqlsmith.UnlikelyRandomNulls(),

"LowProbabilityWhereClauseWithJoinTables": sqlsmith.LowProbabilityWhereClauseWithJoinTables(),
"MultiRegionDDLs": sqlsmith.MultiRegionDDLs(),
"MutatingMode": sqlsmith.MutatingMode(),
"MutationsOnly": sqlsmith.MutationsOnly(),
"OnlyNoDropDDLs": sqlsmith.OnlyNoDropDDLs(),
"OnlySingleDMLs": sqlsmith.OnlySingleDMLs(),
"OutputSort": sqlsmith.OutputSort(),
"PostgresMode": sqlsmith.PostgresMode(),
"SimpleDatums": sqlsmith.SimpleDatums(),
"SimpleNames": sqlsmith.SimpleNames(),
"UnlikelyConstantPredicate": sqlsmith.UnlikelyConstantPredicate(),
"UnlikelyRandomNulls": sqlsmith.UnlikelyRandomNulls(),
}
smitherOpts []string
)
Expand Down Expand Up @@ -119,9 +139,14 @@ func main() {
// Gather our sqlsmith options from command-line arguments.
var smitherOpts []sqlsmith.SmitherOption
for _, arg := range flags.Args() {
if opt, ok := smitherOptMap[arg]; ok {
argKV := strings.SplitN(arg, "=", 2)
if opt, ok := smitherOptMap[argKV[0]]; ok {
fmt.Print("-- ", arg, ": ", opt, "\n")
smitherOpts = append(smitherOpts, opt)
if len(argKV) == 2 {
smitherOpts = append(smitherOpts, opt.(strArgOpt)(argKV[1]))
} else {
smitherOpts = append(smitherOpts, opt)
}
} else {
fmt.Fprintf(flags.Output(), "unrecognized sqlsmith-go option: %v\n", arg)
usage()
Expand All @@ -146,6 +171,15 @@ func main() {
fmt.Println("-- connected to", *url)
}

if *schemaPath != "" {
opts, err := parseSchemaDefinition(*schemaPath)
if err != nil {
fmt.Fprintf(flags.Output(), "could not parse schema file %s: %s", *schemaPath, err)
os.Exit(2)
}
smitherOpts = append(smitherOpts, opts...)
}

// Create our smither.
smither, err := sqlsmith.NewSmither(db, rng, smitherOpts...)
if err != nil {
Expand All @@ -156,10 +190,15 @@ func main() {

// Finally, generate num statements (or expressions).
fmt.Println("-- num", *num)
sep := "\n"
if *prefix != "" {
sep = fmt.Sprintf("\n%s\n", *prefix)
}

if *expr {
fmt.Println("-- expr")
for i := 0; i < *num; i++ {
fmt.Print("\n", smither.GenerateExpr(), "\n")
fmt.Print(sep, smither.GenerateExpr(), "\n")
}
} else {
for i := 0; i < *num; i++ {
Expand All @@ -168,6 +207,46 @@ func main() {
if db != nil && *execStmts {
_, _ = db.Exec(stmt)
}
fmt.Print(sep, smither.Generate(), ";\n")
}
}
}

func parseSchemaDefinition(schemaPath string) (opts []sqlsmith.SmitherOption, _ error) {
f, err := os.Open(schemaPath)
if err != nil {
return nil, errors.Wrapf(err, "could not open schema file %s for reading", schemaPath)
}
schema, err := io.ReadAll(f)
if err != nil {
return nil, errors.Wrap(err, "failed to read schema definition data")
}
stmts, err := parser.Parse(string(schema))
if err != nil {
return nil, errors.Wrap(err, "Could not parse schema definition")
}
semaCtx := tree.MakeSemaContext(nil /* resolver */)
st := cluster.MakeTestingClusterSettings()
wall := timeutil.Now().UnixNano()
parentID := descpb.ID(bootstrap.TestingUserDescID(0))
for i, s := range stmts {
switch t := s.AST.(type) {
default:
return nil, errors.AssertionFailedf("only CreateTable statements supported, found %T", t)
case *tree.CreateTable:
tableID := descpb.ID(int(parentID) + i + 1)
desc, err := importer.MakeTestingSimpleTableDescriptor(
context.Background(), &semaCtx, st, t, parentID, keys.PublicSchemaID, tableID, importer.NoFKs, wall)
if err != nil {
return nil, errors.Wrapf(err, "failed to create table descriptor for statement %s", t)
}
opts = append(opts, sqlsmith.WithTableDescriptor(t.Table, desc.TableDescriptor))
}
}
return opts, nil
}

type strArgOpt func(v string) sqlsmith.SmitherOption

func (o strArgOpt) Apply(s *sqlsmith.Smither) {}
func (o strArgOpt) String() string { return "" }
1 change: 1 addition & 0 deletions pkg/internal/sqlsmith/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ go_library(
"//pkg/util/timeutil",
"@com_github_cockroachdb_errors//:errors",
"@com_github_lib_pq//oid",
"@org_golang_x_exp//slices",
],
)

Expand Down
14 changes: 12 additions & 2 deletions pkg/internal/sqlsmith/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,12 @@ func makeBinOp(s *Smither, typ *types.T, refs colRefs) (tree.TypedExpr, bool) {
if len(ops) == 0 {
return nil, false
}
n := s.rnd.Intn(len(ops))
op := ops[n]
op := ops[s.rnd.Intn(len(ops))]
for s.simpleScalarTypes && !(isSimpleSeedType(op.LeftType) && isSimpleSeedType(op.RightType)) {
// We must work harder to pick some other op.
op = ops[s.rnd.Intn(len(ops))]
}

if s.postgres {
if ignorePostgresBinOps[binOpTriple{
op.LeftType.Family(),
Expand Down Expand Up @@ -440,6 +444,12 @@ func makeFunc(s *Smither, ctx Context, typ *types.T, refs colRefs) (tree.TypedEx

args := make(tree.TypedExprs, 0)
for _, argTyp := range fn.overload.Types.Types() {
// Skip this function if we want simple scalar types, but this
// function argument is not.
if s.simpleScalarTypes && !isSimpleSeedType(argTyp) {
return nil, false
}

// Postgres is picky about having Int4 arguments instead of Int8.
if s.postgres && argTyp.Family() == types.IntFamily {
argTyp = types.Int4
Expand Down
54 changes: 49 additions & 5 deletions pkg/internal/sqlsmith/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/syncutil"
"github.com/lib/pq/oid"
"golang.org/x/exp/slices"
)

// tableRef represents a table and its columns.
Expand All @@ -52,6 +53,48 @@ type aliasedTableRef struct {

type tableRefs []*tableRef

func WithTableDescriptor(tn tree.TableName, desc descpb.TableDescriptor) SmitherOption {
return option{
name: fmt.Sprintf("inject table %s", tn.FQString()),
apply: func(s *Smither) {
if tn.SchemaName != "" {
if !slices.ContainsFunc(s.schemas, func(ref *schemaRef) bool {
return ref.SchemaName == tn.SchemaName
}) {
s.schemas = append(s.schemas, &schemaRef{SchemaName: tn.SchemaName})
}
}

var cols []*tree.ColumnTableDef
for _, col := range desc.Columns {
column := tree.ColumnTableDef{
Name: tree.Name(col.Name),
Type: col.Type,
}
if col.Nullable {
column.Nullable.Nullability = tree.Null
}
if col.IsComputed() {
column.Computed.Computed = true
}
cols = append(cols, &column)
}

s.tables = append(s.tables, &tableRef{
TableName: &tn,
Columns: cols,
})
if s.columns == nil {
s.columns = make(map[tree.TableName]map[tree.Name]*tree.ColumnTableDef)
}
s.columns[tn] = make(map[tree.Name]*tree.ColumnTableDef)
for _, col := range cols {
s.columns[tn][col.Name] = col
}
},
}
}

// ReloadSchemas loads tables from the database.
func (s *Smither) ReloadSchemas() error {
if s.db == nil {
Expand Down Expand Up @@ -135,9 +178,13 @@ func (s *Smither) getRandTable() (*aliasedTableRef, bool) {
return nil, false
}
table := s.tables[s.rnd.Intn(len(s.tables))]
var indexFlags tree.IndexFlags
aliased := &aliasedTableRef{
tableRef: table,
}

if !s.disableIndexHints && s.coin() {
indexes := s.getAllIndexesForTableRLocked(*table.TableName)
var indexFlags tree.IndexFlags
indexNames := make([]tree.Name, 0, len(indexes))
for _, index := range indexes {
if !index.Inverted {
Expand All @@ -147,10 +194,7 @@ func (s *Smither) getRandTable() (*aliasedTableRef, bool) {
if len(indexNames) > 0 {
indexFlags.Index = tree.UnrestrictedName(indexNames[s.rnd.Intn(len(indexNames))])
}
}
aliased := &aliasedTableRef{
tableRef: table,
indexFlags: &indexFlags,
aliased.indexFlags = &indexFlags
}
return aliased, true
}
Expand Down

0 comments on commit 19d693f

Please sign in to comment.