Skip to content

Commit

Permalink
Merge #120504
Browse files Browse the repository at this point in the history
120504: sqlsmith: Various usability and bug fixes. r=michae2 a=miretskiy

This PR adds various command line flags and
sqlsmith options to improve usability and correctness:
 * Add `--schema` command line argument to sqlsmith so that the running database instance is not required in order to populate existing table information.
 * Add `--prefix` command line argument which will add prefix to every statement/expression generated by sqlsmith (this can be used for example to generalte sqllogic format)
 * Add a new smither option `SimpleScalarTypes` which eschews "complex types" -- such as GEOMETRY, GEOGRAPHY, and other less common types.
 * Fix a bug where `DisableIndexHints` option was not respected (all tables had index flags).
 * Fix a bug where pretty printing (query or expression) did not respect `PostgresMode`, thus producing queries that are not compatible w/ postgres (e.g. they always included type annotations).

Epic: None 
Release note: None

Co-authored-by: Yevgeniy Miretskiy <yevgeniy@datadoghq.com>
  • Loading branch information
craig[bot] and Yevgeniy Miretskiy committed May 2, 2024
2 parents 1f6e966 + 19d693f commit d15210e
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 49 deletions.
8 changes: 8 additions & 0 deletions pkg/cmd/smith/BUILD.bazel
Expand Up @@ -7,7 +7,15 @@ go_library(
visibility = ["//visibility:private"],
deps = [
"//pkg/internal/sqlsmith",
"//pkg/keys",
"//pkg/settings/cluster",
"//pkg/sql/catalog/bootstrap",
"//pkg/sql/catalog/descpb",
"//pkg/sql/importer",
"//pkg/sql/parser",
"//pkg/sql/sem/tree",
"//pkg/util/randutil",
"//pkg/util/timeutil",
"@com_github_cockroachdb_errors//:errors",
],
)
Expand Down
159 changes: 119 additions & 40 deletions pkg/cmd/smith/main.go
Expand Up @@ -11,14 +11,25 @@
package main

import (
"context"
gosql "database/sql"
"flag"
"fmt"
"io"
"os"
"sort"
"strings"

"github.com/cockroachdb/cockroach/pkg/internal/sqlsmith"
"github.com/cockroachdb/cockroach/pkg/keys"
"github.com/cockroachdb/cockroach/pkg/settings/cluster"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/bootstrap"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb"
"github.com/cockroachdb/cockroach/pkg/sql/importer"
"github.com/cockroachdb/cockroach/pkg/sql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/util/randutil"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/cockroachdb/errors"
)

Expand All @@ -43,45 +54,54 @@ Options:
`

var (
flags = flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
expr = flags.Bool("expr", false, "generate expressions instead of statements")
num = flags.Int("num", 1, "number of statements / expressions to generate")
url = flags.String("url", "", "database to fetch schema from")
execStmts = flags.Bool("exec-stmts", false, "execute each generated statement against the db specified by url")
flags = flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
expr = flags.Bool("expr", false, "generate expressions instead of statements")
num = flags.Int("num", 1, "number of statements / expressions to generate")
url = flags.String("url", "", "database to fetch schema from")
execStmts = flags.Bool("exec-stmts", false, "execute each generated statement against the db specified by url")
schemaPath = flags.String("schema", "", "path containing schema definitions")
prefix = flags.String("prefix", "", "prefix each statement or expression")

smitherOptMap = map[string]sqlsmith.SmitherOption{
"AvoidConsts": sqlsmith.AvoidConsts(),
"CompareMode": sqlsmith.CompareMode(),
"DisableAggregateFuncs": sqlsmith.DisableAggregateFuncs(),
"DisableCRDBFns": sqlsmith.DisableCRDBFns(),
"DisableCrossJoins": sqlsmith.DisableCrossJoins(),
"DisableDDLs": sqlsmith.DisableDDLs(),
"DisableDecimals": sqlsmith.DisableDecimals(),
"DisableDivision": sqlsmith.DisableDivision(),
"DisableEverything": sqlsmith.DisableEverything(),
"DisableIndexHints": sqlsmith.DisableIndexHints(),
"DisableInsertSelect": sqlsmith.DisableInsertSelect(),
"DisableJoins": sqlsmith.DisableJoins(),
"DisableLimits": sqlsmith.DisableLimits(),
"DisableMutations": sqlsmith.DisableMutations(),
"DisableNondeterministicFns": sqlsmith.DisableNondeterministicFns(),
"DisableNondeterministicLimits": sqlsmith.DisableNondeterministicLimits(),
"DisableWindowFuncs": sqlsmith.DisableWindowFuncs(),
"DisableWith": sqlsmith.DisableWith(),
"EnableAlters": sqlsmith.EnableAlters(),
"FavorCommonData": sqlsmith.FavorCommonData(),
"InsUpdOnly": sqlsmith.InsUpdOnly(),
"AvoidConsts": sqlsmith.AvoidConsts(),
"CompareMode": sqlsmith.CompareMode(),
"DisableAggregateFuncs": sqlsmith.DisableAggregateFuncs(),
"DisableCRDBFns": sqlsmith.DisableCRDBFns(),
"DisableCrossJoins": sqlsmith.DisableCrossJoins(),
"DisableDDLs": sqlsmith.DisableDDLs(),
"DisableDecimals": sqlsmith.DisableDecimals(),
"DisableDivision": sqlsmith.DisableDivision(),
"DisableEverything": sqlsmith.DisableEverything(),
"DisableIndexHints": sqlsmith.DisableIndexHints(),
"DisableInsertSelect": sqlsmith.DisableInsertSelect(),
"DisableJoins": sqlsmith.DisableJoins(),
"DisableLimits": sqlsmith.DisableLimits(),
"DisableMutations": sqlsmith.DisableMutations(),
"DisableNondeterministicFns": sqlsmith.DisableNondeterministicFns(),
"DisableWindowFuncs": sqlsmith.DisableWindowFuncs(),
"DisableWith": sqlsmith.DisableWith(),
"DisableUDFs": sqlsmith.DisableUDFs(),
"EnableAlters": sqlsmith.EnableAlters(),
"EnableLimits": sqlsmith.EnableLimits(),
"EnableWith": sqlsmith.EnableWith(),
"FavorCommonData": sqlsmith.FavorCommonData(),
"IgnoreFNs": strArgOpt(sqlsmith.IgnoreFNs),
"InsUpdOnly": sqlsmith.InsUpdOnly(),
"MaybeSortOutput": sqlsmith.MaybeSortOutput(),
"MultiRegionDDLs": sqlsmith.MultiRegionDDLs(),
"MutatingMode": sqlsmith.MutatingMode(),
"MutationsOnly": sqlsmith.MutationsOnly(),
"OnlyNoDropDDLs": sqlsmith.OnlyNoDropDDLs(),
"OnlySingleDMLs": sqlsmith.OnlySingleDMLs(),
"OutputSort": sqlsmith.OutputSort(),
"PostgresMode": sqlsmith.PostgresMode(),
"SimpleDatums": sqlsmith.SimpleDatums(),
"SimpleScalarTypes": sqlsmith.SimpleScalarTypes(),
"SimpleNames": sqlsmith.SimpleNames(),
"UnlikelyConstantPredicate": sqlsmith.UnlikelyConstantPredicate(),
"UnlikelyRandomNulls": sqlsmith.UnlikelyRandomNulls(),

"LowProbabilityWhereClauseWithJoinTables": sqlsmith.LowProbabilityWhereClauseWithJoinTables(),
"MultiRegionDDLs": sqlsmith.MultiRegionDDLs(),
"MutatingMode": sqlsmith.MutatingMode(),
"MutationsOnly": sqlsmith.MutationsOnly(),
"OnlyNoDropDDLs": sqlsmith.OnlyNoDropDDLs(),
"OnlySingleDMLs": sqlsmith.OnlySingleDMLs(),
"OutputSort": sqlsmith.OutputSort(),
"PostgresMode": sqlsmith.PostgresMode(),
"SimpleDatums": sqlsmith.SimpleDatums(),
"SimpleNames": sqlsmith.SimpleNames(),
"UnlikelyConstantPredicate": sqlsmith.UnlikelyConstantPredicate(),
"UnlikelyRandomNulls": sqlsmith.UnlikelyRandomNulls(),
}
smitherOpts []string
)
Expand Down Expand Up @@ -119,9 +139,14 @@ func main() {
// Gather our sqlsmith options from command-line arguments.
var smitherOpts []sqlsmith.SmitherOption
for _, arg := range flags.Args() {
if opt, ok := smitherOptMap[arg]; ok {
argKV := strings.SplitN(arg, "=", 2)
if opt, ok := smitherOptMap[argKV[0]]; ok {
fmt.Print("-- ", arg, ": ", opt, "\n")
smitherOpts = append(smitherOpts, opt)
if len(argKV) == 2 {
smitherOpts = append(smitherOpts, opt.(strArgOpt)(argKV[1]))
} else {
smitherOpts = append(smitherOpts, opt)
}
} else {
fmt.Fprintf(flags.Output(), "unrecognized sqlsmith-go option: %v\n", arg)
usage()
Expand All @@ -146,6 +171,15 @@ func main() {
fmt.Println("-- connected to", *url)
}

if *schemaPath != "" {
opts, err := parseSchemaDefinition(*schemaPath)
if err != nil {
fmt.Fprintf(flags.Output(), "could not parse schema file %s: %s", *schemaPath, err)
os.Exit(2)
}
smitherOpts = append(smitherOpts, opts...)
}

// Create our smither.
smither, err := sqlsmith.NewSmither(db, rng, smitherOpts...)
if err != nil {
Expand All @@ -156,10 +190,15 @@ func main() {

// Finally, generate num statements (or expressions).
fmt.Println("-- num", *num)
sep := "\n"
if *prefix != "" {
sep = fmt.Sprintf("\n%s\n", *prefix)
}

if *expr {
fmt.Println("-- expr")
for i := 0; i < *num; i++ {
fmt.Print("\n", smither.GenerateExpr(), "\n")
fmt.Print(sep, smither.GenerateExpr(), "\n")
}
} else {
for i := 0; i < *num; i++ {
Expand All @@ -168,6 +207,46 @@ func main() {
if db != nil && *execStmts {
_, _ = db.Exec(stmt)
}
fmt.Print(sep, smither.Generate(), ";\n")
}
}
}

func parseSchemaDefinition(schemaPath string) (opts []sqlsmith.SmitherOption, _ error) {
f, err := os.Open(schemaPath)
if err != nil {
return nil, errors.Wrapf(err, "could not open schema file %s for reading", schemaPath)
}
schema, err := io.ReadAll(f)
if err != nil {
return nil, errors.Wrap(err, "failed to read schema definition data")
}
stmts, err := parser.Parse(string(schema))
if err != nil {
return nil, errors.Wrap(err, "Could not parse schema definition")
}
semaCtx := tree.MakeSemaContext(nil /* resolver */)
st := cluster.MakeTestingClusterSettings()
wall := timeutil.Now().UnixNano()
parentID := descpb.ID(bootstrap.TestingUserDescID(0))
for i, s := range stmts {
switch t := s.AST.(type) {
default:
return nil, errors.AssertionFailedf("only CreateTable statements supported, found %T", t)
case *tree.CreateTable:
tableID := descpb.ID(int(parentID) + i + 1)
desc, err := importer.MakeTestingSimpleTableDescriptor(
context.Background(), &semaCtx, st, t, parentID, keys.PublicSchemaID, tableID, importer.NoFKs, wall)
if err != nil {
return nil, errors.Wrapf(err, "failed to create table descriptor for statement %s", t)
}
opts = append(opts, sqlsmith.WithTableDescriptor(t.Table, desc.TableDescriptor))
}
}
return opts, nil
}

type strArgOpt func(v string) sqlsmith.SmitherOption

func (o strArgOpt) Apply(s *sqlsmith.Smither) {}
func (o strArgOpt) String() string { return "" }
1 change: 1 addition & 0 deletions pkg/internal/sqlsmith/BUILD.bazel
Expand Up @@ -42,6 +42,7 @@ go_library(
"//pkg/util/timeutil",
"@com_github_cockroachdb_errors//:errors",
"@com_github_lib_pq//oid",
"@org_golang_x_exp//slices",
],
)

Expand Down
14 changes: 12 additions & 2 deletions pkg/internal/sqlsmith/scalar.go
Expand Up @@ -336,8 +336,12 @@ func makeBinOp(s *Smither, typ *types.T, refs colRefs) (tree.TypedExpr, bool) {
if len(ops) == 0 {
return nil, false
}
n := s.rnd.Intn(len(ops))
op := ops[n]
op := ops[s.rnd.Intn(len(ops))]
for s.simpleScalarTypes && !(isSimpleSeedType(op.LeftType) && isSimpleSeedType(op.RightType)) {
// We must work harder to pick some other op.
op = ops[s.rnd.Intn(len(ops))]
}

if s.postgres {
if ignorePostgresBinOps[binOpTriple{
op.LeftType.Family(),
Expand Down Expand Up @@ -440,6 +444,12 @@ func makeFunc(s *Smither, ctx Context, typ *types.T, refs colRefs) (tree.TypedEx

args := make(tree.TypedExprs, 0)
for _, argTyp := range fn.overload.Types.Types() {
// Skip this function if we want simple scalar types, but this
// function argument is not.
if s.simpleScalarTypes && !isSimpleSeedType(argTyp) {
return nil, false
}

// Postgres is picky about having Int4 arguments instead of Int8.
if s.postgres && argTyp.Family() == types.IntFamily {
argTyp = types.Int4
Expand Down
54 changes: 49 additions & 5 deletions pkg/internal/sqlsmith/schema.go
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/syncutil"
"github.com/lib/pq/oid"
"golang.org/x/exp/slices"
)

// tableRef represents a table and its columns.
Expand All @@ -52,6 +53,48 @@ type aliasedTableRef struct {

type tableRefs []*tableRef

func WithTableDescriptor(tn tree.TableName, desc descpb.TableDescriptor) SmitherOption {
return option{
name: fmt.Sprintf("inject table %s", tn.FQString()),
apply: func(s *Smither) {
if tn.SchemaName != "" {
if !slices.ContainsFunc(s.schemas, func(ref *schemaRef) bool {
return ref.SchemaName == tn.SchemaName
}) {
s.schemas = append(s.schemas, &schemaRef{SchemaName: tn.SchemaName})
}
}

var cols []*tree.ColumnTableDef
for _, col := range desc.Columns {
column := tree.ColumnTableDef{
Name: tree.Name(col.Name),
Type: col.Type,
}
if col.Nullable {
column.Nullable.Nullability = tree.Null
}
if col.IsComputed() {
column.Computed.Computed = true
}
cols = append(cols, &column)
}

s.tables = append(s.tables, &tableRef{
TableName: &tn,
Columns: cols,
})
if s.columns == nil {
s.columns = make(map[tree.TableName]map[tree.Name]*tree.ColumnTableDef)
}
s.columns[tn] = make(map[tree.Name]*tree.ColumnTableDef)
for _, col := range cols {
s.columns[tn][col.Name] = col
}
},
}
}

// ReloadSchemas loads tables from the database.
func (s *Smither) ReloadSchemas() error {
if s.db == nil {
Expand Down Expand Up @@ -135,9 +178,13 @@ func (s *Smither) getRandTable() (*aliasedTableRef, bool) {
return nil, false
}
table := s.tables[s.rnd.Intn(len(s.tables))]
var indexFlags tree.IndexFlags
aliased := &aliasedTableRef{
tableRef: table,
}

if !s.disableIndexHints && s.coin() {
indexes := s.getAllIndexesForTableRLocked(*table.TableName)
var indexFlags tree.IndexFlags
indexNames := make([]tree.Name, 0, len(indexes))
for _, index := range indexes {
if !index.Inverted {
Expand All @@ -147,10 +194,7 @@ func (s *Smither) getRandTable() (*aliasedTableRef, bool) {
if len(indexNames) > 0 {
indexFlags.Index = tree.UnrestrictedName(indexNames[s.rnd.Intn(len(indexNames))])
}
}
aliased := &aliasedTableRef{
tableRef: table,
indexFlags: &indexFlags,
aliased.indexFlags = &indexFlags
}
return aliased, true
}
Expand Down

0 comments on commit d15210e

Please sign in to comment.