Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature][connectors][PIP-193] Support Transform Function with LocalRunner #17445

Merged
merged 1 commit into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,11 @@ public void testPulsarSourceLocalRunMultipleInstances() throws Throwable {
}

/**
 * Runs the sink local-run scenario without a transform function.
 *
 * <p>Delegates to the five-argument overload, passing {@code null} for both the
 * transform function and the transform function class name.
 */
private void testPulsarSinkLocalRun(String jarFilePathUrl, int parallelism, String className) throws Exception {
    // Named placeholders make it explicit which optional settings are left unset.
    final String noTransformFunction = null;
    final String noTransformClassName = null;
    testPulsarSinkLocalRun(
            jarFilePathUrl, parallelism, className, noTransformFunction, noTransformClassName);
}

private void testPulsarSinkLocalRun(String jarFilePathUrl, int parallelism, String className,
String transformFunction, String transformFunctionClassName) throws Exception {
final String namespacePortion = "io";
final String replNamespace = tenant + "/" + namespacePortion;
final String sourceTopic = "persistent://" + replNamespace + "/input";
Expand All @@ -921,6 +926,9 @@ private void testPulsarSinkLocalRun(String jarFilePathUrl, int parallelism, Stri

sinkConfig.setArchive(jarFilePathUrl);
sinkConfig.setParallelism(parallelism);
sinkConfig.setTransformFunction(transformFunction);
sinkConfig.setTransformFunctionClassName(transformFunctionClassName);

int metricsPort = FunctionCommon.findAvailablePort();
@Cleanup
LocalRunner localRunner = LocalRunner.builder()
Expand All @@ -933,6 +941,7 @@ private void testPulsarSinkLocalRun(String jarFilePathUrl, int parallelism, Stri
.tlsHostNameVerificationEnabled(false)
.brokerServiceUrl(pulsar.getBrokerServiceUrlTls())
.connectorsDirectory(workerConfig.getConnectorsDirectory())
.functionsDirectory(workerConfig.getFunctionsDirectory())
.metricsPortStart(metricsPort)
.build();

Expand Down Expand Up @@ -1083,6 +1092,12 @@ public void close() throws Exception {
public void testPulsarSinkStatsByteBufferType() throws Throwable {
runWithNarClassLoader(() -> testPulsarSinkLocalRun(null, 1, StatsNullSink.class.getName()));
}

/**
 * Runs the sink local-run scenario with a transform function attached: the sink is
 * {@code StatsNullSink} and the transform function is referenced as {@code builtin://exclamation}.
 */
// NOTE(review): a variant of the annotation with timeOut = 20000 was left commented out here,
// so this test currently runs with no timeout — confirm whether that is intentional.
@Test(groups = "builtin")
public void testPulsarSinkWithFunction() throws Throwable {
    final String sinkClassName = StatsNullSink.class.getName();
    final String transformFunction = "builtin://exclamation";
    final String transformFunctionClassName = "org.apache.pulsar.functions.api.examples.RecordFunction";
    testPulsarSinkLocalRun(null, 1, sinkClassName, transformFunction, transformFunctionClassName);
}

public static class TestErrorSink implements Sink<byte[]> {
private Map config;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.pulsar.functions;

import static org.apache.commons.lang3.StringUtils.isNotEmpty;
import static org.apache.pulsar.common.functions.Utils.inferMissingArguments;
import com.beust.jcommander.IStringConverter;
import com.beust.jcommander.JCommander;
Expand Down Expand Up @@ -48,6 +49,7 @@
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.Builder;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.apache.pulsar.common.functions.FunctionConfig;
import org.apache.pulsar.common.functions.Utils;
Expand Down Expand Up @@ -92,8 +94,8 @@ public class LocalRunner implements AutoCloseable {
private final String functionsDir;
private final Thread shutdownHook;
private final int instanceLivenessCheck;
private ClassLoader userCodeClassLoader;
private boolean userCodeClassLoaderCreated;
private UserCodeClassLoader userCodeClassLoader;
private UserCodeClassLoader transformFunctionCodeClassLoader;
private RuntimeFactory runtimeFactory;
private HTTPServer metricsServer;

Expand All @@ -102,6 +104,12 @@ public enum RuntimeEnv {
PROCESS
}

/**
 * Pairs a class loader with a flag recording whether this runner created it.
 *
 * <p>Declared as a Lombok {@code @Value} class; the rest of this file reads it through
 * {@code getClassLoader()} and {@code isClassLoaderCreated()}.
 */
@Value
private static class UserCodeClassLoader {
// May be null; callers in this file then fall back to the current thread's context class loader.
ClassLoader classLoader;
// Set to true by extractClassLoader when the loader comes from
// FunctionCommon.getClassLoaderFromPackage; closeClassLoaderIfneeded only closes loaders
// that have this flag set.
boolean classLoaderCreated;
}

public static class FunctionConfigConverter implements IStringConverter<FunctionConfig> {
@Override
public FunctionConfig convert(String value) {
Expand Down Expand Up @@ -310,16 +318,21 @@ public synchronized void stop() {
runtimeFactory = null;
}

if (userCodeClassLoaderCreated) {
if (userCodeClassLoader instanceof Closeable) {
try {
((Closeable) userCodeClassLoader).close();
} catch (IOException e) {
log.warn("Error closing classloader", e);
}
closeClassLoaderIfneeded(userCodeClassLoader);
userCodeClassLoader = null;
closeClassLoaderIfneeded(transformFunctionCodeClassLoader);
transformFunctionCodeClassLoader = null;
}
}

/**
 * Closes the class loader held by the given wrapper when there is something to close.
 *
 * <p>Nothing happens unless the wrapper is non-null, reports {@code isClassLoaderCreated()},
 * and its class loader implements {@link Closeable}. An {@link IOException} thrown by
 * {@code close()} is logged at warn level and not rethrown.
 *
 * @param userCodeClassLoader wrapper whose class loader should be closed; may be {@code null}
 */
private static void closeClassLoaderIfneeded(UserCodeClassLoader userCodeClassLoader) {
if (userCodeClassLoader != null && userCodeClassLoader.isClassLoaderCreated()) {
if (userCodeClassLoader.getClassLoader() instanceof Closeable) {
try {
((Closeable) userCodeClassLoader.getClassLoader()).close();
} catch (IOException e) {
// Closing is best-effort: log the failure and carry on.
log.warn("Error closing classloader", e);
}
// NOTE(review): the two assignments below look like removed-side residue of this diff hunk
// (they mirror the old inline cleanup shown just before this method) — confirm they are not
// part of the merged method.
userCodeClassLoaderCreated = false;
userCodeClassLoader = null;
}
}
}
Expand All @@ -333,16 +346,18 @@ public void start(boolean blocking) throws Exception {
Runtime.getRuntime().addShutdownHook(shutdownHook);
Function.FunctionDetails functionDetails = null;
String userCodeFile;
String transformFunctionFile = null;
int parallelism;
if (functionConfig != null) {
FunctionConfigUtils.inferMissingArguments(functionConfig, true);
parallelism = functionConfig.getParallelism();
if (functionConfig.getRuntime() == FunctionConfig.Runtime.JAVA) {
userCodeFile = functionConfig.getJar();
ClassLoader functionClassLoader = extractClassLoader(
userCodeClassLoader = extractClassLoader(
userCodeFile, ComponentType.FUNCTION, functionConfig.getClassName());
functionDetails = FunctionConfigUtils.convert(
functionConfig, FunctionConfigUtils.validateJavaFunction(functionConfig, functionClassLoader));
functionConfig,
FunctionConfigUtils.validateJavaFunction(functionConfig, getCurrentOrUserCodeClassLoader()));
} else if (functionConfig.getRuntime() == FunctionConfig.Runtime.GO) {
userCodeFile = functionConfig.getGo();
} else if (functionConfig.getRuntime() == FunctionConfig.Runtime.PYTHON) {
Expand All @@ -352,26 +367,42 @@ public void start(boolean blocking) throws Exception {
}

if (functionDetails == null) {
functionDetails = FunctionConfigUtils.convert(functionConfig,
userCodeClassLoader != null ? userCodeClassLoader :
Thread.currentThread().getContextClassLoader());
functionDetails = FunctionConfigUtils.convert(functionConfig, getCurrentOrUserCodeClassLoader());
}
} else if (sourceConfig != null) {
inferMissingArguments(sourceConfig);
userCodeFile = sourceConfig.getArchive();
parallelism = sourceConfig.getParallelism();
ClassLoader sourceClassLoader = extractClassLoader(
userCodeClassLoader = extractClassLoader(
userCodeFile, ComponentType.SOURCE, sourceConfig.getClassName());
functionDetails = SourceConfigUtils.convert(
sourceConfig, SourceConfigUtils.validateAndExtractDetails(sourceConfig, sourceClassLoader, true));
sourceConfig,
SourceConfigUtils.validateAndExtractDetails(sourceConfig, getCurrentOrUserCodeClassLoader(), true));
} else if (sinkConfig != null) {
inferMissingArguments(sinkConfig);
userCodeFile = sinkConfig.getArchive();
transformFunctionFile = sinkConfig.getTransformFunction();
parallelism = sinkConfig.getParallelism();
ClassLoader sinkClassLoader = extractClassLoader(
userCodeClassLoader = extractClassLoader(
userCodeFile, ComponentType.SINK, sinkConfig.getClassName());
if (isNotEmpty(sinkConfig.getTransformFunction())) {
transformFunctionCodeClassLoader = extractClassLoader(
sinkConfig.getTransformFunction(),
ComponentType.FUNCTION,
sinkConfig.getTransformFunctionClassName());
}

ClassLoader functionClassLoader = null;
if (transformFunctionCodeClassLoader != null) {
functionClassLoader = transformFunctionCodeClassLoader.getClassLoader() == null
? Thread.currentThread().getContextClassLoader()
: transformFunctionCodeClassLoader.getClassLoader();
}

functionDetails = SinkConfigUtils.convert(
sinkConfig, SinkConfigUtils.validateAndExtractDetails(sinkConfig, sinkClassLoader, null, true));
sinkConfig,
SinkConfigUtils.validateAndExtractDetails(sinkConfig, getCurrentOrUserCodeClassLoader(),
functionClassLoader, true));
} else {
throw new IllegalArgumentException("Must specify Function, Source or Sink config");
}
Expand Down Expand Up @@ -401,10 +432,10 @@ public void start(boolean blocking) throws Exception {
&& (runtimeEnv == null || runtimeEnv == RuntimeEnv.THREAD)) {
// By default run java functions as threads
startThreadedMode(functionDetails, parallelism, instanceIdOffset, serviceUrl,
stateStorageServiceUrl, authConfig, userCodeFile);
stateStorageServiceUrl, authConfig, userCodeFile, transformFunctionFile);
} else {
startProcessMode(functionDetails, parallelism, instanceIdOffset, serviceUrl,
stateStorageServiceUrl, authConfig, userCodeFile);
stateStorageServiceUrl, authConfig, userCodeFile, transformFunctionFile);
}
local.addAll(spawners);
}
Expand All @@ -426,15 +457,22 @@ public void start(boolean blocking) throws Exception {
}
}

private ClassLoader extractClassLoader(String userCodeFile, ComponentType componentType, String className)
/**
 * Picks the class loader to use for user code.
 *
 * @return the class loader held by {@code userCodeClassLoader} when both the wrapper and
 *         its class loader are non-null; otherwise the current thread's context class loader
 */
private ClassLoader getCurrentOrUserCodeClassLoader() {
    if (userCodeClassLoader != null) {
        ClassLoader loader = userCodeClassLoader.getClassLoader();
        if (loader != null) {
            return loader;
        }
    }
    // No wrapper yet, or the wrapper carries no loader: use the thread's context class loader.
    return Thread.currentThread().getContextClassLoader();
}

private UserCodeClassLoader extractClassLoader(String userCodeFile, ComponentType componentType, String className)
throws IOException, URISyntaxException {
userCodeClassLoader = userCodeFile != null ? isBuiltIn(userCodeFile, componentType) : null;
if (userCodeClassLoader == null) {
ClassLoader classLoader = userCodeFile != null ? isBuiltIn(userCodeFile, componentType) : null;
boolean classLoaderCreated = false;
if (classLoader == null) {
if (userCodeFile != null && Utils.isFunctionPackageUrlSupported(userCodeFile)) {
File file = FunctionCommon.extractFileFromPkgURL(userCodeFile);
userCodeClassLoader = FunctionCommon.getClassLoaderFromPackage(
classLoader = FunctionCommon.getClassLoaderFromPackage(
componentType, className, file, narExtractionDirectory);
userCodeClassLoaderCreated = true;
classLoaderCreated = true;
} else if (userCodeFile != null) {
File file = new File(userCodeFile);
if (!file.exists()) {
Expand All @@ -454,9 +492,9 @@ private ClassLoader extractClassLoader(String userCodeFile, ComponentType compon
}
throw new RuntimeException(errorMsg + " (" + userCodeFile + ") does not exist");
}
userCodeClassLoader = FunctionCommon.getClassLoaderFromPackage(
classLoader = FunctionCommon.getClassLoaderFromPackage(
componentType, className, file, narExtractionDirectory);
userCodeClassLoaderCreated = true;
classLoaderCreated = true;
} else {
if (!(runtimeEnv == null || runtimeEnv == RuntimeEnv.THREAD)) {
String errorMsg;
Expand All @@ -477,15 +515,13 @@ private ClassLoader extractClassLoader(String userCodeFile, ComponentType compon
}
}
}
return userCodeClassLoader == null
? Thread.currentThread().getContextClassLoader()
: userCodeClassLoader;
return new UserCodeClassLoader(classLoader, classLoaderCreated);
}

private void startProcessMode(org.apache.pulsar.functions.proto.Function.FunctionDetails functionDetails,
int parallelism, int instanceIdOffset, String serviceUrl,
String stateStorageServiceUrl, AuthenticationConfig authConfig,
String userCodeFile) throws Exception {
String userCodeFile, String transformFunctionFile) throws Exception {
SecretsProviderConfigurator secretsProviderConfigurator = getSecretsProviderConfigurator();
runtimeFactory = new ProcessRuntimeFactory(
serviceUrl,
Expand Down Expand Up @@ -532,7 +568,7 @@ private void startProcessMode(org.apache.pulsar.functions.proto.Function.Functio
instanceConfig,
userCodeFile,
null,
null,
transformFunctionFile,
null,
runtimeFactory,
instanceLivenessCheck);
Expand Down Expand Up @@ -568,7 +604,7 @@ public void run() {
private void startThreadedMode(org.apache.pulsar.functions.proto.Function.FunctionDetails functionDetails,
int parallelism, int instanceIdOffset, String serviceUrl,
String stateStorageServiceUrl, AuthenticationConfig authConfig,
String userCodeFile) throws Exception {
String userCodeFile, String transformFunctionFile) throws Exception {

if (metricsPortStart != null) {
if (metricsPortStart < 0 || metricsPortStart > 65535) {
Expand Down Expand Up @@ -599,8 +635,8 @@ private void startThreadedMode(org.apache.pulsar.functions.proto.Function.Functi

ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader();
try {
if (userCodeClassLoader != null) {
Thread.currentThread().setContextClassLoader(userCodeClassLoader);
if (userCodeClassLoader != null && userCodeClassLoader.getClassLoader() != null) {
Thread.currentThread().setContextClassLoader(userCodeClassLoader.getClassLoader());
}
runtimeFactory = new ThreadRuntimeFactory("LocalRunnerThreadGroup",
serviceUrl,
Expand All @@ -620,6 +656,7 @@ private void startThreadedMode(org.apache.pulsar.functions.proto.Function.Functi
// TODO: correctly implement function version and id
instanceConfig.setFunctionVersion(UUID.randomUUID().toString());
instanceConfig.setFunctionId(UUID.randomUUID().toString());
instanceConfig.setTransformFunctionId(UUID.randomUUID().toString());
instanceConfig.setInstanceId(i + instanceIdOffset);
instanceConfig.setMaxBufferedTuples(1024);
if (metricsPortStart != null) {
Expand All @@ -638,7 +675,7 @@ private void startThreadedMode(org.apache.pulsar.functions.proto.Function.Functi
instanceConfig,
userCodeFile,
null,
null,
transformFunctionFile,
null,
runtimeFactory,
instanceLivenessCheck);
Expand Down