Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add table statistics and automatic Join pushdown for SQL Server connector #11637

Merged
merged 3 commits into from
Apr 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions plugin/trino-sqlserver/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@

<properties>
<air.main.basedir>${project.parent.basedir}</air.main.basedir>

<!--
Project's default for air.test.parallel is 'methods'. By design, 'classes' makes TestNG run tests from one class in a single thread.
As a side effect, it prevents TestNG from initializing multiple test instances upfront, which happens with 'methods'.
A potential downside can be long tail single-threaded execution of a single long test class.
TODO (https://github.com/trinodb/trino/issues/11294) remove when we upgrade to surefire with https://issues.apache.org/jira/browse/SUREFIRE-1967
-->
<air.test.parallel>classes</air.test.parallel>
</properties>

<dependencies>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@
import io.trino.plugin.jdbc.JdbcOutputTableHandle;
import io.trino.plugin.jdbc.JdbcSortItem;
import io.trino.plugin.jdbc.JdbcSplit;
import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.JdbcTableHandle;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.LongReadFunction;
import io.trino.plugin.jdbc.LongWriteFunction;
import io.trino.plugin.jdbc.ObjectReadFunction;
import io.trino.plugin.jdbc.ObjectWriteFunction;
import io.trino.plugin.jdbc.PreparedQuery;
import io.trino.plugin.jdbc.QueryBuilder;
import io.trino.plugin.jdbc.RemoteTableName;
import io.trino.plugin.jdbc.SliceWriteFunction;
Expand All @@ -58,7 +60,13 @@
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.statistics.ColumnStatistics;
import io.trino.spi.statistics.Estimate;
import io.trino.spi.statistics.TableStatistics;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
Expand All @@ -77,6 +85,7 @@

import javax.inject.Inject;

import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
Expand All @@ -89,6 +98,7 @@
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -97,11 +107,15 @@
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Throwables.throwIfInstanceOf;
import static com.google.common.collect.MoreCollectors.toOptional;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static com.microsoft.sqlserver.jdbc.SQLServerConnection.TRANSACTION_SNAPSHOT;
import static io.airlift.slice.Slices.wrappedBuffer;
import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR;
import static io.trino.plugin.jdbc.JdbcJoinPushdownUtil.implementJoinCostAware;
import static io.trino.plugin.jdbc.PredicatePushdownController.DISABLE_PUSHDOWN;
import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction;
Expand Down Expand Up @@ -191,18 +205,27 @@ public class SqlServerClient
.maximumSize(1)
.expireAfterWrite(ofMinutes(5)));

private final boolean statisticsEnabled;

private final ConnectorExpressionRewriter<String> connectorExpressionRewriter;
private final AggregateFunctionRewriter<JdbcExpression, String> aggregateFunctionRewriter;

private static final int MAX_SUPPORTED_TEMPORAL_PRECISION = 7;

@Inject
public SqlServerClient(BaseJdbcConfig config, SqlServerConfig sqlServerConfig, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping)
public SqlServerClient(
BaseJdbcConfig config,
SqlServerConfig sqlServerConfig,
JdbcStatisticsConfig statisticsConfig,
ConnectionFactory connectionFactory,
QueryBuilder queryBuilder,
IdentifierMapping identifierMapping)
{
super(config, "\"", connectionFactory, queryBuilder, identifierMapping);

requireNonNull(sqlServerConfig, "sqlServerConfig is null");
snapshotIsolationDisabled = sqlServerConfig.isSnapshotIsolationDisabled();
this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled();

this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder()
.addStandardRules(this::quoted)
Expand Down Expand Up @@ -452,6 +475,165 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type)
throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName());
}

@Override
public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, TupleDomain<ColumnHandle> tupleDomain)
{
    // Statistics are reported only when enabled via config, and only for named
    // tables (not for handles wrapping arbitrary queries).
    if (!statisticsEnabled || !handle.isNamedRelation()) {
        return TableStatistics.empty();
    }
    try {
        return readTableStatistics(session, handle);
    }
    catch (SQLException | RuntimeException e) {
        // Propagate TrinoExceptions untouched; wrap anything else as a JDBC error
        throwIfInstanceOf(e, TrinoException.class);
        throw new TrinoException(JDBC_ERROR, "Failed fetching statistics for table: " + handle, e);
    }
}

/**
 * Reads row count and per-column statistics for a named table from SQL Server
 * catalog views and {@code DBCC SHOW_STATISTICS}.
 * Returns {@link TableStatistics#empty()} when the table cannot be found,
 * disappears mid-read, or is empty.
 */
private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table)
        throws SQLException
{
    checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table);

    try (Connection connection = connectionFactory.openConnection(session);
            Handle handle = Jdbi.open(connection)) {
        String catalog = table.getCatalogName();
        String schema = table.getSchemaName();
        String tableName = table.getTableName();

        StatisticsDao statisticsDao = new StatisticsDao(handle);
        Long tableObjectId = statisticsDao.getTableObjectId(catalog, schema, tableName);
        if (tableObjectId == null) {
            // Table not found
            return TableStatistics.empty();
        }

        Long rowCount = statisticsDao.getRowCount(tableObjectId);
        if (rowCount == null) {
            // Table disappeared
            return TableStatistics.empty();
        }

        if (rowCount == 0) {
            return TableStatistics.empty();
        }

        TableStatistics.Builder tableStatistics = TableStatistics.builder();
        tableStatistics.setRowCount(Estimate.of(rowCount));

        // Only single-column statistics objects are usable for per-column estimates
        Map<String, String> columnNameToStatisticsName = getColumnNameToStatisticsName(table, statisticsDao, tableObjectId);

        for (JdbcColumnHandle column : this.getColumns(session, table)) {
            String statisticName = columnNameToStatisticsName.get(column.getColumnName());
            if (statisticName == null) {
                // No statistic for column
                continue;
            }

            double averageColumnLength;
            long notNullValues = 0;
            long nullValues = 0;
            long distinctValues = 0;

            // DBCC SHOW_STATISTICS produces three result sets, consumed in order below:
            // the statistics header, the density vector (skipped), and the histogram.
            try (CallableStatement showStatistics = handle.getConnection().prepareCall("DBCC SHOW_STATISTICS (?, ?)")) {
                showStatistics.setString(1, format("%s.%s.%s", catalog, schema, tableName));
                showStatistics.setString(2, statisticName);

                boolean isResultSet = showStatistics.execute();
                checkState(isResultSet, "Expected SHOW_STATISTICS to return a result set");
                try (ResultSet resultSet = showStatistics.getResultSet()) {
                    // First result set (header): expected to contain exactly one row
                    checkState(resultSet.next(), "No rows in result set");

                    averageColumnLength = resultSet.getDouble("Average Key Length"); // NULL values are accounted for with length 0

                    checkState(!resultSet.next(), "More than one row in result set");
                }

                // Second result set (density vector): not used, but must be consumed
                // to reach the histogram result set
                isResultSet = showStatistics.getMoreResults();
                checkState(isResultSet, "Expected SHOW_STATISTICS to return second result set");
                showStatistics.getResultSet().close();

                // Third result set: the histogram, one row per step keyed by RANGE_HI_KEY
                isResultSet = showStatistics.getMoreResults();
                checkState(isResultSet, "Expected SHOW_STATISTICS to return third result set");
                try (ResultSet resultSet = showStatistics.getResultSet()) {
                    while (resultSet.next()) {
                        // Read RANGE_HI_KEY only to learn whether it is NULL (wasNull below)
                        resultSet.getObject("RANGE_HI_KEY");
                        if (resultSet.wasNull()) {
                            // Null fraction: a NULL step should carry only EQ_ROWS,
                            // and should appear at most once in the histogram
                            checkState(resultSet.getLong("RANGE_ROWS") == 0, "Unexpected RANGE_ROWS for null fraction");
                            checkState(resultSet.getLong("DISTINCT_RANGE_ROWS") == 0, "Unexpected DISTINCT_RANGE_ROWS for null fraction");
                            checkState(nullValues == 0, "Multiple null fraction entries");
                            nullValues += resultSet.getLong("EQ_ROWS");
                        }
                        else {
                            // TODO discover min/max from resultSet.getXxx("RANGE_HI_KEY")
                            notNullValues += resultSet.getLong("RANGE_ROWS") // rows strictly within a bucket
                                    + resultSet.getLong("EQ_ROWS"); // rows equal to RANGE_HI_KEY
                            distinctValues += resultSet.getLong("DISTINCT_RANGE_ROWS") // NDV strictly within a bucket
                                    + (resultSet.getLong("EQ_ROWS") > 0 ? 1 : 0); // RANGE_HI_KEY itself, when present
                        }
                    }
                }
            }

            ColumnStatistics statistics = ColumnStatistics.builder()
                    // When the histogram counted no rows at all, fall back to nulls fraction 1
                    .setNullsFraction(Estimate.of(
                            (notNullValues + nullValues == 0)
                                    ? 1
                                    : (1.0 * nullValues / (notNullValues + nullValues))))
                    .setDistinctValuesCount(Estimate.of(distinctValues))
                    .setDataSize(Estimate.of(rowCount * averageColumnLength))
                    .build();

            tableStatistics.setColumnStatistics(column, statistics);
        }

        return tableStatistics.build();
    }
}

/**
 * Maps each column name to the name of a single-column statistics object covering it.
 * When several single-column statistics exist for one column, the first one returned
 * by the DAO wins and the duplicate is logged at debug level.
 */
private static Map<String, String> getColumnNameToStatisticsName(JdbcTableHandle table, StatisticsDao statisticsDao, Long tableObjectId)
{
    Map<String, String> columnNameToStatisticsName = new HashMap<>();
    for (String statisticsObjectName : statisticsDao.getSingleColumnStatistics(tableObjectId)) {
        String columnName = statisticsDao.getSingleColumnStatisticsColumnName(tableObjectId, statisticsObjectName);
        if (columnName == null) {
            // Table or statistics disappeared
            continue;
        }

        String existing = columnNameToStatisticsName.putIfAbsent(columnName, statisticsObjectName);
        if (existing != null) {
            log.debug("Multiple statistics for %s in %s: %s and %s", columnName, table, existing, statisticsObjectName);
        }
    }
    return columnNameToStatisticsName;
}

/**
 * Implements join pushdown by delegating SQL generation to the base class,
 * wrapped in {@code implementJoinCostAware} which decides — presumably based on
 * the supplied {@link JoinStatistics} (see JdbcJoinPushdownUtil) — whether the
 * pushed-down join is worthwhile; returns empty when it is not.
 */
@Override
public Optional<PreparedQuery> implementJoin(
        ConnectorSession session,
        JoinType joinType,
        PreparedQuery leftSource,
        PreparedQuery rightSource,
        List<JdbcJoinCondition> joinConditions,
        Map<JdbcColumnHandle, String> rightAssignments,
        Map<JdbcColumnHandle, String> leftAssignments,
        JoinStatistics statistics)
{
    return implementJoinCostAware(
            session,
            joinType,
            leftSource,
            rightSource,
            statistics,
            // Deferred: the base implementation is only invoked if the cost check passes
            () -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics));
}

private LongWriteFunction sqlServerTimeWriteFunction(int precision)
{
return new LongWriteFunction()
Expand Down Expand Up @@ -833,4 +1015,65 @@ private enum SnapshotIsolationEnabledCacheKey
// database, so from our perspective, this is a global property.
INSTANCE
}

/**
 * Thin JDBI-based accessor for the SQL Server catalog views used to read
 * table and statistics metadata. Not thread-safe; bound to a single {@link Handle}.
 */
private static class StatisticsDao
{
    private final Handle handle;

    public StatisticsDao(Handle handle)
    {
        this.handle = requireNonNull(handle, "handle is null");
    }

    /**
     * Resolves the object id of the given table via {@code object_id()}.
     * The query returns exactly one row; the value is NULL (mapped to Java null)
     * when the table does not exist.
     */
    Long getTableObjectId(String catalog, String schema, String tableName)
    {
        return handle.createQuery("SELECT object_id(:table)")
                .bind("table", format("%s.%s.%s", catalog, schema, tableName))
                .mapTo(Long.class)
                .findOnly();
    }

    /**
     * Total row count summed over the table's heap or clustered-index partitions.
     * Null when the table's partitions are gone (table dropped concurrently).
     */
    Long getRowCount(long tableObjectId)
    {
        return handle.createQuery("" +
                "SELECT sum(rows) row_count " +
                "FROM sys.partitions " +
                "WHERE object_id = :object_id " +
                "AND index_id IN (0, 1)") // 0 = heap, 1 = clustered index, 2 or greater = non-clustered index
                .bind("object_id", tableObjectId)
                .mapTo(Long.class)
                .findOnly();
    }

    /**
     * Names of statistics objects that cover exactly one column
     * (the HAVING count(*) = 1 filter), sorted by name for deterministic order.
     */
    List<String> getSingleColumnStatistics(long tableObjectId)
    {
        return handle.createQuery("" +
                "SELECT s.name " +
                "FROM sys.stats AS s " +
                "JOIN sys.stats_columns AS sc ON s.object_id = sc.object_id AND s.stats_id = sc.stats_id " +
                "WHERE s.object_id = :object_id " +
                "GROUP BY s.name " +
                "HAVING count(*) = 1 " +
                "ORDER BY s.name")
                .bind("object_id", tableObjectId)
                .mapTo(String.class)
                .list();
    }

    /**
     * Name of the single column covered by the given statistics object,
     * or null when the table or statistics object no longer exists.
     */
    String getSingleColumnStatisticsColumnName(long tableObjectId, String statisticsName)
    {
        return handle.createQuery("" +
                "SELECT c.name " +
                "FROM sys.stats AS s " +
                "JOIN sys.stats_columns AS sc ON s.object_id = sc.object_id AND s.stats_id = sc.stats_id " +
                "JOIN sys.columns AS c ON sc.object_id = c.object_id AND c.column_id = sc.column_id " +
                "WHERE s.object_id = :object_id " +
                "AND s.name = :statistics_name")
                .bind("object_id", tableObjectId)
                .bind("statistics_name", statisticsName)
                .mapTo(String.class)
                .collect(toOptional()) // verify there is no more than 1 column name returned
                .orElse(null);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@

import com.google.inject.Binder;
import com.google.inject.Key;
import com.google.inject.Module;
import com.google.inject.Provides;
import com.google.inject.Scopes;
import com.google.inject.Singleton;
import com.microsoft.sqlserver.jdbc.SQLServerDriver;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.ConnectionFactory;
import io.trino.plugin.jdbc.DriverConnectionFactory;
import io.trino.plugin.jdbc.ForBaseJdbc;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcJoinPushdownSupportModule;
import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.MaxDomainCompactionThreshold;
import io.trino.plugin.jdbc.credential.CredentialProvider;

Expand All @@ -34,15 +36,17 @@
import static io.trino.plugin.sqlserver.SqlServerClient.SQL_SERVER_MAX_LIST_EXPRESSIONS;

public class SqlServerClientModule
implements Module
extends AbstractConfigurationAwareModule
{
@Override
public void configure(Binder binder)
protected void setup(Binder binder)
{
configBinder(binder).bindConfig(SqlServerConfig.class);
configBinder(binder).bindConfig(JdbcStatisticsConfig.class);
binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(SqlServerClient.class).in(Scopes.SINGLETON);
bindTablePropertiesProvider(binder, SqlServerTableProperties.class);
newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)).setBinding().toInstance(SQL_SERVER_MAX_LIST_EXPRESSIONS);
install(new JdbcJoinPushdownSupportModule());
}

@Provides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -539,4 +539,13 @@ private String getLongInClause(int start, int length)
.collect(joining(", "));
return "orderkey IN (" + longValues + ")";
}

@Override
protected Session joinPushdownEnabled(Session session)
{
    // The default AUTOMATIC strategy would not push down joins in certain test cases
    // (even when statistics are collected), so force the EAGER strategy here.
    String catalogName = session.getCatalog().orElseThrow();
    return Session.builder(super.joinPushdownEnabled(session))
            .setCatalogSessionProperty(catalogName, "join_pushdown_strategy", "EAGER")
            .build();
}
}