Skip to content

Commit

Permalink
Add automatic JOIN pushdown support to SQL Server connector
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo authored and findepi committed Apr 8, 2022
1 parent 4393c81 commit e33ebc5
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import io.trino.plugin.jdbc.LongWriteFunction;
import io.trino.plugin.jdbc.ObjectReadFunction;
import io.trino.plugin.jdbc.ObjectWriteFunction;
import io.trino.plugin.jdbc.PreparedQuery;
import io.trino.plugin.jdbc.QueryBuilder;
import io.trino.plugin.jdbc.RemoteTableName;
import io.trino.plugin.jdbc.SliceWriteFunction;
Expand All @@ -59,6 +60,8 @@
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.statistics.ColumnStatistics;
Expand Down Expand Up @@ -112,6 +115,7 @@
import static io.airlift.slice.Slices.wrappedBuffer;
import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR;
import static io.trino.plugin.jdbc.JdbcJoinPushdownUtil.implementJoinCostAware;
import static io.trino.plugin.jdbc.PredicatePushdownController.DISABLE_PUSHDOWN;
import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction;
Expand Down Expand Up @@ -610,6 +614,26 @@ private static Map<String, String> getColumnNameToStatisticsName(JdbcTableHandle
return columnNameToStatisticsName;
}

@Override
public Optional<PreparedQuery> implementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments,
JoinStatistics statistics)
{
return implementJoinCostAware(
session,
joinType,
leftSource,
rightSource,
statistics,
() -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics));
}

private LongWriteFunction sqlServerTimeWriteFunction(int precision)
{
return new LongWriteFunction()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@

import com.google.inject.Binder;
import com.google.inject.Key;
import com.google.inject.Module;
import com.google.inject.Provides;
import com.google.inject.Scopes;
import com.google.inject.Singleton;
import com.microsoft.sqlserver.jdbc.SQLServerDriver;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.ConnectionFactory;
import io.trino.plugin.jdbc.DriverConnectionFactory;
import io.trino.plugin.jdbc.ForBaseJdbc;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcJoinPushdownSupportModule;
import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.MaxDomainCompactionThreshold;
import io.trino.plugin.jdbc.credential.CredentialProvider;
Expand All @@ -35,16 +36,17 @@
import static io.trino.plugin.sqlserver.SqlServerClient.SQL_SERVER_MAX_LIST_EXPRESSIONS;

public class SqlServerClientModule
implements Module
extends AbstractConfigurationAwareModule
{
@Override
public void configure(Binder binder)
protected void setup(Binder binder)
{
configBinder(binder).bindConfig(SqlServerConfig.class);
configBinder(binder).bindConfig(JdbcStatisticsConfig.class);
binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(SqlServerClient.class).in(Scopes.SINGLETON);
bindTablePropertiesProvider(binder, SqlServerTableProperties.class);
newOptionalBinder(binder, Key.get(int.class, MaxDomainCompactionThreshold.class)).setBinding().toInstance(SQL_SERVER_MAX_LIST_EXPRESSIONS);
install(new JdbcJoinPushdownSupportModule());
}

@Provides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -539,4 +539,13 @@ private String getLongInClause(int start, int length)
.collect(joining(", "));
return "orderkey IN (" + longValues + ")";
}

@Override
protected Session joinPushdownEnabled(Session session)
{
return Session.builder(super.joinPushdownEnabled(session))
// strategy is AUTOMATIC by default and would not work for certain test cases (even if statistics are collected)
.setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER")
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.sqlserver;

import io.trino.plugin.jdbc.BaseAutomaticJoinPushdownTest;
import io.trino.testing.QueryRunner;

import java.util.List;
import java.util.Map;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Streams.stream;
import static io.trino.plugin.sqlserver.SqlServerQueryRunner.createSqlServerQueryRunner;
import static java.lang.String.format;

public class TestSqlServerAutomaticJoinPushdown
extends BaseAutomaticJoinPushdownTest
{
private TestingSqlServer sqlServer;

@Override
protected QueryRunner createQueryRunner()
throws Exception
{
sqlServer = closeAfterClass(new TestingSqlServer());
return createSqlServerQueryRunner(sqlServer, Map.of(), Map.of(), List.of());
}

@Override
protected void gatherStats(String tableName)
{
List<String> columnNames = stream(computeActual("SHOW COLUMNS FROM " + tableName))
.map(row -> (String) row.getField(0))
.map(columnName -> "\"" + columnName + "\"")
.collect(toImmutableList());

for (String columnName : columnNames) {
sqlServer.execute(format("CREATE STATISTICS %1$s ON %2$s (%1$s)", columnName, tableName));
}

sqlServer.execute("UPDATE STATISTICS " + tableName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ protected QueryRunner createQueryRunner()
return SqlServerQueryRunner.createSqlServerQueryRunner(
sqlServer,
Map.of(),
Map.of(
"case-insensitive-name-matching", "true",
"join-pushdown.enabled", "true"),
Map.of("case-insensitive-name-matching", "true"),
List.of(ORDERS));
}

Expand Down

0 comments on commit e33ebc5

Please sign in to comment.