From dc27ddcb80f29a2da23ad370f138a2b5b21ae585 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Thu, 22 Sep 2022 10:00:05 -0500 Subject: [PATCH 1/5] Refactor: Allow workerd::Server::makeExternalService() to return a non-promise. The next commit will do the same for `makeWorker()`, so that services are constructed entirely synchronously, which is needed to simplify cross-linking service bindings. --- src/workerd/server/server.c++ | 72 ++++++++++++++++++++++++++++------- src/workerd/server/server.h | 2 +- 2 files changed, 60 insertions(+), 14 deletions(-) diff --git a/src/workerd/server/server.c++ b/src/workerd/server/server.c++ index 573fc988e5f..38861309f46 100644 --- a/src/workerd/server/server.c++ +++ b/src/workerd/server/server.c++ @@ -389,6 +389,57 @@ kj::Own Server::makeInvalidConfigService() { return { invalidConfigServiceSingleton.get(), kj::NullDisposer::instance }; } +class PromisedNetworkAddress final: public kj::NetworkAddress { + // A NetworkAddress whose connect() method waits for a Promise and then forwards + // to it. Used by ExternalHttpService so that we don't have to wait for DNS lookup before the + // server can start. + // + // TODO(cleanup): kj::Network should be extended with a new version of parseAddress() which does + // not do DNS lookup immediately, and therefore can return a NetworkAddress synchronously. + // In fact, this version should be designed to redo the DNS lookup periodically to see if it + // changed, which would be nice for workerd when the remote address may change over time. +public: + PromisedNetworkAddress(kj::Promise> promise) + : promise(promise.then([this](kj::Own result) { + addr = kj::mv(result); + }).fork()) {} + + kj::Promise> connect() override { + KJ_IF_MAYBE(a, addr) { + return a->get()->connect(); + } else { + return promise.addBranch().then([this]() { + return KJ_ASSERT_NONNULL(addr)->connect(); + }); + } + } + + kj::Promise connectAuthenticated() override { + KJ_IF_MAYBE(a, addr) { + return a->get()->connectAuthenticated(); + } else { + return promise.addBranch().then([this]() { + return KJ_ASSERT_NONNULL(addr)->connectAuthenticated(); + }); + } + } + + // We don't use any other methods, and they seem kinda annoying to implement. + kj::Own listen() override { + KJ_UNIMPLEMENTED("PromisedNetworkAddress::listen() not implemented"); + } + kj::Own clone() override { + KJ_UNIMPLEMENTED("PromisedNetworkAddress::clone() not implemented"); + } + kj::String toString() override { + KJ_UNIMPLEMENTED("PromisedNetworkAddress::toString() not implemented"); + } + +private: + kj::ForkedPromise promise; + kj::Maybe> addr; +}; + class Server::ExternalHttpService final: public Service { // Service used when the service's config is invalid. @@ -481,7 +532,7 @@ private: }; }; -kj::Promise> Server::makeExternalService( +kj::Own Server::makeExternalService( kj::StringPtr name, config::ExternalServer::Reader conf, kj::HttpHeaderTable::Builder& headerTableBuilder) { kj::StringPtr addrStr = nullptr; @@ -504,12 +555,9 @@ kj::Promise> Server::makeExternalService( // We have to construct the rewriter upfront before waiting on any promises, since the // HeaderTable::Builder is only available synchronously. auto rewriter = kj::heap(conf.getHttp(), headerTableBuilder); - return network.parseAddress(addrStr, 80) - .then([this, rewriter = kj::mv(rewriter)](kj::Own addr) mutable - -> kj::Own { - return kj::heap( - kj::mv(addr), kj::mv(rewriter), globalContext->headerTable, timer, entropySource); - }); + auto addr = kj::heap(network.parseAddress(addrStr, 80)); + return kj::heap( + kj::mv(addr), kj::mv(rewriter), globalContext->headerTable, timer, entropySource); } case config::ExternalServer::HTTPS: { auto httpsConf = conf.getHttps(); @@ -518,12 +566,10 @@ kj::Promise> Server::makeExternalService( certificateHost = httpsConf.getCertificateHost(); } auto rewriter = kj::heap(httpsConf.getOptions(), headerTableBuilder); - return makeTlsNetworkAddress(httpsConf.getTlsOptions(), addrStr, certificateHost, 443) - .then([this, rewriter = kj::mv(rewriter)](kj::Own addr) mutable - -> kj::Own { - return kj::heap( - kj::mv(addr), kj::mv(rewriter), globalContext->headerTable, timer, entropySource); - }); + auto addr = kj::heap( + makeTlsNetworkAddress(httpsConf.getTlsOptions(), addrStr, certificateHost, 443)); + return kj::heap( + kj::mv(addr), kj::mv(rewriter), globalContext->headerTable, timer, entropySource); } } reportConfigError(kj::str( diff --git a/src/workerd/server/server.h b/src/workerd/server/server.h index 806cfa925b7..6a782bfe84e 100644 --- a/src/workerd/server/server.h +++ b/src/workerd/server/server.h @@ -93,7 +93,7 @@ class Server: private kj::TaskSet::ErrorHandler { class HttpRewriter; kj::Own makeInvalidConfigService(); - kj::Promise> makeExternalService( + kj::Own makeExternalService( kj::StringPtr name, config::ExternalServer::Reader conf, kj::HttpHeaderTable::Builder& headerTableBuilder); kj::Own makeNetworkService(config::Network::Reader conf); From 85bfc1ad38b50931ff83204a5befaf72dbdb5a28 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Wed, 21 Sep 2022 19:28:59 -0500 Subject: [PATCH 2/5] Fix cyclic service bindings causing `workerd` to hang at startup. Previously, WorkerService's constructor took a list of `Service`s implemeting subrequest channels, which implied that those services had to be constructed before the WorkerService, and so cycles were impossible. We now split into two stages, construction and linking. We construct all services first, then we link them. This also means we no longer have any service types that need to be constructed asynchronously, which improves error handling since all errors will be reported by the time we've fininished constructing+linking. --- src/workerd/server/server-test.c++ | 90 +++++++++++++++ src/workerd/server/server.c++ | 179 +++++++++++++++++------------ src/workerd/server/server.h | 13 +-- 3 files changed, 199 insertions(+), 83 deletions(-) diff --git a/src/workerd/server/server-test.c++ b/src/workerd/server/server-test.c++ index 04586c36e00..1a4111df03b 100644 --- a/src/workerd/server/server-test.c++ +++ b/src/workerd/server/server-test.c++ @@ -1036,6 +1036,63 @@ KJ_TEST("Server: capability bindings") { )"_blockquote); } +KJ_TEST("Server: cyclic bindings") { + TestServer test(R"(( + services = [ + ( name = "service1", + worker = ( + compatibilityDate = "2022-08-17", + modules = [ + ( name = "main.js", + esModule = + `export default { + ` async fetch(request, env) { + ` if (request.url.endsWith("/done")) { + ` return new Response("!"); + ` } else { + ` let resp2 = await env.service2.fetch(request); + ` let text = await resp2.text(); + ` return new Response("Hello " + text); + ` } + ` } + `} + ) + ], + bindings = [(name = "service2", service = "service2")] + ) + ), + ( name = "service2", + worker = ( + compatibilityDate = "2022-08-17", + modules = [ + ( name = "main.js", + esModule = + `export default { + ` async fetch(request, env) { + ` let resp2 = await env.service1.fetch("http://foo/done"); + ` let text = await resp2.text(); + ` return new Response("World" + text); + ` } + `} + ) + ], + bindings = [(name = "service1", service = "service1")] + ) + ), + ], + sockets = [ + ( name = "main", + address = "test-addr", + service = "service1" + ) + ] + ))"_kj); + + test.start(); + auto conn = test.connect("test-addr"); + conn.httpGet200("/", "Hello World!"); +} + KJ_TEST("Server: named entrypoints") { TestServer test(R"(( services = [ @@ -1090,6 +1147,39 @@ KJ_TEST("Server: named entrypoints") { } } +KJ_TEST("Server: invalid entrypoint") { + TestServer test(R"(( + services = [ + ( name = "hello", + worker = ( + compatibilityDate = "2022-08-17", + modules = [ + ( name = "main.js", + esModule = + `export default { + ` async fetch(request, env) { + ` return env.svc.fetch(request); + ` } + `} + ) + ], + bindings = [(name = "svc", service = (name = "hello", entrypoint = "bar"))], + ) + ), + ], + sockets = [ + ( name = "main", address = "test-addr", service = "hello" ), + ( name = "alt1", address = "foo-addr", service = (name = "hello", entrypoint = "foo")), + ] + ))"_kj); + + test.expectErrors( + "Worker \"hello\"'s binding \"svc\" refers to service \"hello\" with a named entrypoint " + "\"bar\", but \"hello\" has no such named entrypoint.\n" + "Socket \"alt1\" refers to service \"hello\" with a named entrypoint \"foo\", but \"hello\" " + "has no such named entrypoint.\n"); +} + // ======================================================================================= // Test HttpOptions on receive diff --git a/src/workerd/server/server.c++ b/src/workerd/server/server.c++ index 38861309f46..823d512f9f5 100644 --- a/src/workerd/server/server.c++ +++ b/src/workerd/server/server.c++ @@ -133,6 +133,9 @@ struct Server::GlobalContext { class Server::Service { public: + virtual void link() {} + // Cross-links this service with other services. Must be called once before `startRequest()`. + virtual kj::Own startRequest( IoChannelFactory::SubrequestMetadata metadata) = 0; // Begin an incoming request. Returns a `WorkerInterface` object that will be used for one @@ -868,18 +871,29 @@ class Server::WorkerService final: public Service, private kj::TaskSet::ErrorHan private IoChannelFactory, private TimerChannel, private LimitEnforcer { public: + struct LinkedIoChannels { + kj::Array> subrequest; + }; + WorkerService(ThreadContext& threadContext, kj::Own worker, - kj::Vector> subrequestChannels, - kj::HashSet namedEntrypoints) + kj::HashSet namedEntrypoints, + kj::Function linkCallback) : threadContext(threadContext), worker(kj::mv(worker)), - subrequestChannels(kj::mv(subrequestChannels)), namedEntrypoints(kj::mv(namedEntrypoints)), + ioChannels(kj::mv(linkCallback)), waitUntilTasks(*this) {} bool hasEntrypoint(kj::StringPtr name) { return namedEntrypoints.contains(name); } + void link() override { + kj::Function callback = + kj::mv(KJ_REQUIRE_NONNULL(ioChannels.tryGet>(), + "already called link()")); + ioChannels = callback(); + } + kj::Own startRequest( IoChannelFactory::SubrequestMetadata metadata) override { return startRequest(kj::mv(metadata), nullptr); @@ -906,8 +920,8 @@ public: private: ThreadContext& threadContext; kj::Own worker; - kj::Vector> subrequestChannels; kj::HashSet namedEntrypoints; + kj::OneOf, LinkedIoChannels> ioChannels; kj::TaskSet waitUntilTasks; // --------------------------------------------------------------------------- @@ -921,8 +935,11 @@ private: // implements IoChannelFactory kj::Own startSubrequest(uint channel, SubrequestMetadata metadata) override { - KJ_REQUIRE(channel < subrequestChannels.size(), "invalid subrequest channel number"); - return subrequestChannels[channel]->startRequest(kj::mv(metadata)); + auto& channels = KJ_REQUIRE_NONNULL(ioChannels.tryGet(), + "link() has not been called"); + + KJ_REQUIRE(channel < channels.subrequest.size(), "invalid subrequest channel number"); + return channels.subrequest[channel]->startRequest(kj::mv(metadata)); } capnp::Capability::Client getCapability(uint channel) override { @@ -989,11 +1006,7 @@ private: void reportMetrics(RequestObserver& requestMetrics) override {} }; -kj::Promise> Server::makeWorker( - kj::StringPtr name, config::Worker::Reader conf) { - // Wait for next turn of the event loop to make sure `services` is fully initialized. - co_await kj::evalLater([]() {}); - +kj::Own Server::makeWorker(kj::StringPtr name, config::Worker::Reader conf) { struct ErrorReporter: public Worker::ValidationErrorReporter { ErrorReporter(Server& server, kj::StringPtr name): server(server), name(name) {} @@ -1080,17 +1093,11 @@ kj::Promise> Server::makeWorker( IsolateObserver::StartType::COLD, false, errorReporter); - kj::Vector> subrequestChannels; - { - auto service = co_await lookupService(conf.getGlobalOutbound(), - kj::str("Worker \"", name, "\"'s globalOutbound")); - - // Bind both "next" and "null" to the global outbound. (The difference between these is a - // legacy artifact that no one should be depending on.) Since all `subrequestChannels` will - // have the same lifetime, we can alias using a NullDisposer as a hack here. - subrequestChannels.add(kj::Own(service.get(), kj::NullDisposer::instance)); - subrequestChannels.add(kj::mv(service)); - } + struct FutureSubrequestChannel { + config::ServiceDesignator::Reader designator; + kj::String errorContext; + }; + kj::Vector subrequestChannels; auto confBindings = conf.getBindings(); using Global = WorkerdApiIsolate::Global; @@ -1231,15 +1238,17 @@ kj::Promise> Server::makeWorker( } case config::Worker::Binding::SERVICE: { - auto service = co_await lookupService(binding.getService(), kj::mv(errorContext)); - addGlobal(Global::Fetcher { - .channel = (uint)subrequestChannels.size(), + .channel = (uint)subrequestChannels.size() + + IoContext::SPECIAL_SUBREQUEST_CHANNEL_COUNT, .requiresHost = true, .isInHouse = false }); - subrequestChannels.add(kj::mv(service)); + subrequestChannels.add(FutureSubrequestChannel { + binding.getService(), + kj::mv(errorContext) + }); continue; } @@ -1247,35 +1256,41 @@ kj::Promise> Server::makeWorker( KJ_UNIMPLEMENTED("TODO(launch): durable object namespaces"); case config::Worker::Binding::KV_NAMESPACE: { - auto service = co_await lookupService(binding.getKvNamespace(), kj::mv(errorContext)); - addGlobal(Global::KvNamespace { - .subrequestChannel = (uint)subrequestChannels.size() + .subrequestChannel = (uint)subrequestChannels.size() + + IoContext::SPECIAL_SUBREQUEST_CHANNEL_COUNT }); - subrequestChannels.add(kj::mv(service)); + subrequestChannels.add(FutureSubrequestChannel { + binding.getKvNamespace(), + kj::mv(errorContext) + }); continue; } case config::Worker::Binding::R2_BUCKET: { - auto service = co_await lookupService(binding.getR2Bucket(), kj::mv(errorContext)); - addGlobal(Global::R2Bucket { - .subrequestChannel = (uint)subrequestChannels.size() + .subrequestChannel = (uint)subrequestChannels.size() + + IoContext::SPECIAL_SUBREQUEST_CHANNEL_COUNT }); - subrequestChannels.add(kj::mv(service)); + subrequestChannels.add(FutureSubrequestChannel { + binding.getR2Bucket(), + kj::mv(errorContext) + }); continue; } case config::Worker::Binding::R2_ADMIN: { - auto service = co_await lookupService(binding.getR2Admin(), kj::mv(errorContext)); - addGlobal(Global::R2Admin { - .subrequestChannel = (uint)subrequestChannels.size() + .subrequestChannel = (uint)subrequestChannels.size() + + IoContext::SPECIAL_SUBREQUEST_CHANNEL_COUNT }); - subrequestChannels.add(kj::mv(service)); + subrequestChannels.add(FutureSubrequestChannel { + binding.getR2Admin(), + kj::mv(errorContext) + }); continue; } } @@ -1301,14 +1316,38 @@ kj::Promise> Server::makeWorker( lock.validateHandlers(errorReporter); } - co_return kj::heap(globalContext->threadContext, kj::mv(worker), - kj::mv(subrequestChannels), - kj::mv(errorReporter.namedEntrypoints)); + auto linkCallback = + [this, name, conf, subrequestChannels = kj::mv(subrequestChannels)]() mutable { + auto services = kj::heapArrayBuilder>(subrequestChannels.size() + + IoContext::SPECIAL_SUBREQUEST_CHANNEL_COUNT); + + auto globalService = lookupService(conf.getGlobalOutbound(), + kj::str("Worker \"", name, "\"'s globalOutbound")); + + // Bind both "next" and "null" to the global outbound. (The difference between these is a + // legacy artifact that no one should be depending on.) Since all `subrequestChannels` will + // have the same lifetime, we can alias using a NullDisposer as a hack here. + static_assert(IoContext::SPECIAL_SUBREQUEST_CHANNEL_COUNT == 2); + services.add(kj::Own(globalService.get(), kj::NullDisposer::instance)); + services.add(kj::mv(globalService)); + + for (auto& channel: subrequestChannels) { + services.add(lookupService(channel.designator, kj::mv(channel.errorContext))); + } + + return WorkerService::LinkedIoChannels { + .subrequest = services.finish() + }; + }; + + return kj::heap(globalContext->threadContext, kj::mv(worker), + kj::mv(errorReporter.namedEntrypoints), + kj::mv(linkCallback)); } // ======================================================================================= -kj::Promise> Server::makeService( +kj::Own Server::makeService( config::Service::Reader conf, kj::HttpHeaderTable::Builder& headerTableBuilder) { kj::StringPtr name = conf.getName(); @@ -1357,41 +1396,38 @@ private: kj::String entrypoint; }; -kj::Promise> Server::lookupService( +kj::Own Server::lookupService( config::ServiceDesignator::Reader designator, kj::String errorContext) { - // Wait for next turn of the event loop to make sure `services` is fully initialized. - co_await kj::evalLater([]() {}); - kj::StringPtr targetName = designator.getName(); - Service* service = co_await KJ_UNWRAP_OR(services.find(targetName), { + Service* service = KJ_UNWRAP_OR(services.find(targetName), { reportConfigError(kj::str( errorContext, " refers to a service \"", targetName, "\", but no such service is defined.")); - co_return makeInvalidConfigService(); - }).addBranch(); + return makeInvalidConfigService(); + }); if (designator.hasEntrypoint()) { kj::StringPtr entrypointName = designator.getEntrypoint(); if (WorkerService* worker = dynamic_cast(service)) { if (worker->hasEntrypoint(entrypointName)) { - co_return kj::heap(*worker, entrypointName); + return kj::heap(*worker, entrypointName); } else { reportConfigError(kj::str( errorContext, " refers to service \"", targetName, "\" with a named entrypoint \"", entrypointName, "\", but \"", targetName, "\" has no such named entrypoint.")); - co_return makeInvalidConfigService(); + return makeInvalidConfigService(); } } else { reportConfigError(kj::str( errorContext, " refers to service \"", targetName, "\" with a named entrypoint \"", entrypointName, "\", but \"", targetName, "\" is not a Worker, so does not have any " "named entrypoints.")); - co_return makeInvalidConfigService(); + return makeInvalidConfigService(); } } else { // The service pointer we looked up is valid for the lifetime of the server, so we can wrap it // in a dummy Own. - co_return kj::Own(service, kj::NullDisposer::instance); + return kj::Own(service, kj::NullDisposer::instance); } } @@ -1570,17 +1606,11 @@ kj::Promise Server::run(jsg::V8System& v8System, config::Config::Reader co // --------------------------------------------------------------------------- // Configure services - for (auto service: config.getServices()) { - kj::StringPtr name = service.getName(); + for (auto serviceConf: config.getServices()) { + kj::StringPtr name = serviceConf.getName(); + auto service = makeService(serviceConf, headerTableBuilder); - auto promise = makeService(service, headerTableBuilder) - .then([this](kj::Own service) { - return ownServices.add(kj::mv(service)).get(); - }).fork(); - - tasks.add(promise.addBranch().ignoreResult()); - - services.upsert(kj::str(name), kj::mv(promise), [&](auto&&...) { + services.upsert(kj::str(name), kj::mv(service), [&](auto&&...) { reportConfigError(kj::str("Config defines multiple services named \"", name, "\".")); }); } @@ -1595,17 +1625,21 @@ kj::Promise Server::run(jsg::V8System& v8System, config::Config::Reader co auto tls = kj::heap(kj::mv(options)); auto tlsNetwork = tls->wrapNetwork(*publicNetwork).attach(kj::mv(tls)); - Service* ptr = ownServices.add(kj::heap( + auto service = kj::heap( globalContext->headerTable, timer, entropySource, - kj::mv(publicNetwork), kj::mv(tlsNetwork))) - .get(); + kj::mv(publicNetwork), kj::mv(tlsNetwork)); return decltype(services)::Entry { kj::str("internet"_kj), - kj::Promise(ptr).fork() + kj::mv(service) }; }); + // Now that all services are constructed, we can tell them to cross-link to each other. + for (auto& service: services) { + service.value->link(); + } + // --------------------------------------------------------------------------- // Start sockets @@ -1615,7 +1649,7 @@ kj::Promise Server::run(jsg::V8System& v8System, config::Config::Reader co kj::String ownAddrStr; kj::Maybe> listenerOverride; - auto servicePromise = lookupService(sock.getService(), kj::str("Socket \"", name, "\"")); + auto service = lookupService(sock.getService(), kj::str("Socket \"", name, "\"")); KJ_IF_MAYBE(override, socketOverrides.findEntry(name)) { KJ_SWITCH_ONEOF(override->value) { @@ -1684,14 +1718,9 @@ kj::Promise Server::run(jsg::V8System& v8System, config::Config::Reader co auto rewriter = kj::heap(httpOptions, headerTableBuilder); tasks.add(listener - .then([this, servicePromise = kj::mv(servicePromise), rewriter = kj::mv(rewriter), - physicalProtocol] + .then([this, service = kj::mv(service), rewriter = kj::mv(rewriter), physicalProtocol] (kj::Own listener) mutable { - return servicePromise - .then([this, listener = kj::mv(listener), rewriter = kj::mv(rewriter), physicalProtocol] - (kj::Own service) mutable { - return listenHttp(kj::mv(listener), kj::mv(service), physicalProtocol, kj::mv(rewriter)); - }); + return listenHttp(kj::mv(listener), kj::mv(service), physicalProtocol, kj::mv(rewriter)); })); } diff --git a/src/workerd/server/server.h b/src/workerd/server/server.h index 6a782bfe84e..f2ec0050d5c 100644 --- a/src/workerd/server/server.h +++ b/src/workerd/server/server.h @@ -72,10 +72,7 @@ class Server: private kj::TaskSet::ErrorHandler { class Service; kj::Own invalidConfigServiceSingleton; - kj::HashMap> services; - // Initialized synchronously in run() (before it returns a promise). - - kj::Vector> ownServices; + kj::HashMap> services; kj::Own> fatalFulfiller; @@ -100,14 +97,14 @@ class Server: private kj::TaskSet::ErrorHandler { kj::Own makeDiskDirectoryService( kj::StringPtr name, config::DiskDirectory::Reader conf, kj::HttpHeaderTable::Builder& headerTableBuilder); - kj::Promise> makeWorker( - kj::StringPtr name, config::Worker::Reader conf); - kj::Promise> makeService( + kj::Own makeWorker(kj::StringPtr name, config::Worker::Reader conf); + kj::Own makeService( config::Service::Reader conf, kj::HttpHeaderTable::Builder& headerTableBuilder); - kj::Promise> lookupService( + kj::Own lookupService( config::ServiceDesignator::Reader designator, kj::String errorContext); + // Can only be called in the link stage. kj::Promise listenHttp(kj::Own listener, kj::Own service, kj::StringPtr physicalProtocol, kj::Own rewriter); From a2ec9b8346a93a34896216286302c6b37f788e4d Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Thu, 22 Sep 2022 16:32:35 -0500 Subject: [PATCH 3/5] Cleanup: Server::lookupService() should return Service&, not Own. We can make it so all the service objects it returns are long-lived. This fixes a segfault on shutdown because destroying an `Own` that uses `NullDisposer` still requires the object is live if it is polymorphic, in order to `dynamic_cast` it. We can't easily ensure shutdown order here. --- src/workerd/server/server.c++ | 98 ++++++++++++++++++----------------- src/workerd/server/server.h | 5 +- 2 files changed, 52 insertions(+), 51 deletions(-) diff --git a/src/workerd/server/server.c++ b/src/workerd/server/server.c++ index 823d512f9f5..96aeed6cc84 100644 --- a/src/workerd/server/server.c++ +++ b/src/workerd/server/server.c++ @@ -872,19 +872,24 @@ class Server::WorkerService final: public Service, private kj::TaskSet::ErrorHan private LimitEnforcer { public: struct LinkedIoChannels { - kj::Array> subrequest; + kj::Array subrequest; }; WorkerService(ThreadContext& threadContext, kj::Own worker, - kj::HashSet namedEntrypoints, + kj::HashSet namedEntrypointsParam, kj::Function linkCallback) : threadContext(threadContext), worker(kj::mv(worker)), - namedEntrypoints(kj::mv(namedEntrypoints)), ioChannels(kj::mv(linkCallback)), - waitUntilTasks(*this) {} + waitUntilTasks(*this) { + namedEntrypoints.reserve(namedEntrypointsParam.size()); + for (auto& ep: namedEntrypointsParam) { + kj::StringPtr epPtr = ep; + namedEntrypoints.insert(kj::mv(ep), EntrypointService(*this, epPtr)); + } + } - bool hasEntrypoint(kj::StringPtr name) { - return namedEntrypoints.contains(name); + kj::Maybe getEntrypoint(kj::StringPtr name) { + return namedEntrypoints.find(name); } void link() override { @@ -918,9 +923,24 @@ public: } private: + class EntrypointService final: public Service { + public: + EntrypointService(WorkerService& worker, kj::StringPtr entrypoint) + : worker(worker), entrypoint(entrypoint) {} + + kj::Own startRequest( + IoChannelFactory::SubrequestMetadata metadata) override { + return worker.startRequest(kj::mv(metadata), entrypoint); + } + + private: + WorkerService& worker; + kj::StringPtr entrypoint; + }; + ThreadContext& threadContext; kj::Own worker; - kj::HashSet namedEntrypoints; + kj::HashMap namedEntrypoints; kj::OneOf, LinkedIoChannels> ioChannels; kj::TaskSet waitUntilTasks; @@ -1318,21 +1338,20 @@ kj::Own Server::makeWorker(kj::StringPtr name, config::Worker:: auto linkCallback = [this, name, conf, subrequestChannels = kj::mv(subrequestChannels)]() mutable { - auto services = kj::heapArrayBuilder>(subrequestChannels.size() + + auto services = kj::heapArrayBuilder(subrequestChannels.size() + IoContext::SPECIAL_SUBREQUEST_CHANNEL_COUNT); - auto globalService = lookupService(conf.getGlobalOutbound(), + Service& globalService = lookupService(conf.getGlobalOutbound(), kj::str("Worker \"", name, "\"'s globalOutbound")); // Bind both "next" and "null" to the global outbound. (The difference between these is a - // legacy artifact that no one should be depending on.) Since all `subrequestChannels` will - // have the same lifetime, we can alias using a NullDisposer as a hack here. + // legacy artifact that no one should be depending on.) static_assert(IoContext::SPECIAL_SUBREQUEST_CHANNEL_COUNT == 2); - services.add(kj::Own(globalService.get(), kj::NullDisposer::instance)); - services.add(kj::mv(globalService)); + services.add(&globalService); + services.add(&globalService); for (auto& channel: subrequestChannels) { - services.add(lookupService(channel.designator, kj::mv(channel.errorContext))); + services.add(&lookupService(channel.designator, kj::mv(channel.errorContext))); } return WorkerService::LinkedIoChannels { @@ -1381,53 +1400,36 @@ void Server::taskFailed(kj::Exception&& exception) { fatalFulfiller->reject(kj::mv(exception)); } -class Server::WorkerEntrypointService final: public Service { -public: - WorkerEntrypointService(WorkerService& worker, kj::StringPtr entrypoint) - : worker(worker), entrypoint(kj::str(entrypoint)) {} - - kj::Own startRequest( - IoChannelFactory::SubrequestMetadata metadata) override { - return worker.startRequest(kj::mv(metadata), entrypoint.asPtr()); - } - -private: - WorkerService& worker; - kj::String entrypoint; -}; - -kj::Own Server::lookupService( +Server::Service& Server::lookupService( config::ServiceDesignator::Reader designator, kj::String errorContext) { kj::StringPtr targetName = designator.getName(); Service* service = KJ_UNWRAP_OR(services.find(targetName), { reportConfigError(kj::str( errorContext, " refers to a service \"", targetName, "\", but no such service is defined.")); - return makeInvalidConfigService(); + return *invalidConfigServiceSingleton; }); if (designator.hasEntrypoint()) { kj::StringPtr entrypointName = designator.getEntrypoint(); if (WorkerService* worker = dynamic_cast(service)) { - if (worker->hasEntrypoint(entrypointName)) { - return kj::heap(*worker, entrypointName); + KJ_IF_MAYBE(ep, worker->getEntrypoint(entrypointName)) { + return *ep; } else { reportConfigError(kj::str( errorContext, " refers to service \"", targetName, "\" with a named entrypoint \"", entrypointName, "\", but \"", targetName, "\" has no such named entrypoint.")); - return makeInvalidConfigService(); + return *invalidConfigServiceSingleton; } } else { reportConfigError(kj::str( errorContext, " refers to service \"", targetName, "\" with a named entrypoint \"", entrypointName, "\", but \"", targetName, "\" is not a Worker, so does not have any " "named entrypoints.")); - return makeInvalidConfigService(); + return *invalidConfigServiceSingleton; } } else { - // The service pointer we looked up is valid for the lifetime of the server, so we can wrap it - // in a dummy Own. - return kj::Own(service, kj::NullDisposer::instance); + return *service; } } @@ -1435,10 +1437,10 @@ kj::Own Server::lookupService( class Server::HttpListener final: private kj::TaskSet::ErrorHandler { public: - HttpListener(kj::Own listener, kj::Own service, + HttpListener(kj::Own listener, Service& service, kj::StringPtr physicalProtocol, kj::Own rewriter, kj::HttpHeaderTable& headerTable, kj::Timer& timer) - : listener(kj::mv(listener)), service(kj::mv(service)), + : listener(kj::mv(listener)), service(service), headerTable(headerTable), timer(timer), physicalProtocol(physicalProtocol), rewriter(kj::mv(rewriter)), @@ -1492,7 +1494,7 @@ public: private: kj::Own listener; - kj::Own service; + Service& service; kj::HttpHeaderTable& headerTable; kj::Timer& timer; kj::StringPtr physicalProtocol; @@ -1559,11 +1561,11 @@ private: url, parent.physicalProtocol, headers, metadata.cfBlobJson), { return response.sendError(400, "Bad Request", parent.headerTable); }); - auto worker = parent.service->startRequest(kj::mv(metadata)); + auto worker = parent.service.startRequest(kj::mv(metadata)); return worker->request(method, url, *rewrite.headers, requestBody, *wrappedResponse) .attach(kj::mv(rewrite), kj::mv(worker), kj::mv(ownResponse)); } else { - auto worker = parent.service->startRequest(kj::mv(metadata)); + auto worker = parent.service.startRequest(kj::mv(metadata)); return worker->request(method, url, headers, requestBody, *wrappedResponse) .attach(kj::mv(worker), kj::mv(ownResponse)); } @@ -1585,9 +1587,9 @@ private: }; kj::Promise Server::listenHttp( - kj::Own listener, kj::Own service, + kj::Own listener, Service& service, kj::StringPtr physicalProtocol, kj::Own rewriter) { - auto obj = kj::heap(kj::mv(listener), kj::mv(service), + auto obj = kj::heap(kj::mv(listener), service, physicalProtocol, kj::mv(rewriter), globalContext->headerTable, timer); return obj->run().attach(kj::mv(obj)); @@ -1649,7 +1651,7 @@ kj::Promise Server::run(jsg::V8System& v8System, config::Config::Reader co kj::String ownAddrStr; kj::Maybe> listenerOverride; - auto service = lookupService(sock.getService(), kj::str("Socket \"", name, "\"")); + Service& service = lookupService(sock.getService(), kj::str("Socket \"", name, "\"")); KJ_IF_MAYBE(override, socketOverrides.findEntry(name)) { KJ_SWITCH_ONEOF(override->value) { @@ -1718,9 +1720,9 @@ kj::Promise Server::run(jsg::V8System& v8System, config::Config::Reader co auto rewriter = kj::heap(httpOptions, headerTableBuilder); tasks.add(listener - .then([this, service = kj::mv(service), rewriter = kj::mv(rewriter), physicalProtocol] + .then([this, &service, rewriter = kj::mv(rewriter), physicalProtocol] (kj::Own listener) mutable { - return listenHttp(kj::mv(listener), kj::mv(service), physicalProtocol, kj::mv(rewriter)); + return listenHttp(kj::mv(listener), service, physicalProtocol, kj::mv(rewriter)); })); } diff --git a/src/workerd/server/server.h b/src/workerd/server/server.h index f2ec0050d5c..e8a3e37eaaf 100644 --- a/src/workerd/server/server.h +++ b/src/workerd/server/server.h @@ -102,11 +102,10 @@ class Server: private kj::TaskSet::ErrorHandler { config::Service::Reader conf, kj::HttpHeaderTable::Builder& headerTableBuilder); - kj::Own lookupService( - config::ServiceDesignator::Reader designator, kj::String errorContext); + Service& lookupService(config::ServiceDesignator::Reader designator, kj::String errorContext); // Can only be called in the link stage. - kj::Promise listenHttp(kj::Own listener, kj::Own service, + kj::Promise listenHttp(kj::Own listener, Service& service, kj::StringPtr physicalProtocol, kj::Own rewriter); class InvalidConfigService; From d879226d8f2aa6e2df74fa4d6602ac88e78df2ac Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Fri, 23 Sep 2022 08:42:09 -0500 Subject: [PATCH 4/5] Support Durable Objects in workerd (in-memory only, for testing). --- src/workerd/server/server-test.c++ | 128 +++++++++++ src/workerd/server/server.c++ | 346 +++++++++++++++++++++++++++-- src/workerd/server/server.h | 9 + src/workerd/server/workerd-api.c++ | 122 ++++++++++ src/workerd/server/workerd-api.h | 19 +- 5 files changed, 604 insertions(+), 20 deletions(-) diff --git a/src/workerd/server/server-test.c++ b/src/workerd/server/server-test.c++ index 1a4111df03b..0a0fa816ca9 100644 --- a/src/workerd/server/server-test.c++ +++ b/src/workerd/server/server-test.c++ @@ -1180,6 +1180,134 @@ KJ_TEST("Server: invalid entrypoint") { "has no such named entrypoint.\n"); } +KJ_TEST("Server: Durable Objects") { + TestServer test(R"(( + services = [ + ( name = "hello", + worker = ( + compatibilityDate = "2022-08-17", + modules = [ + ( name = "main.js", + esModule = + `export default { + ` async fetch(request, env) { + ` let id = env.ns.idFromName(request.url) + ` let actor = env.ns.get(id) + ` return await actor.fetch(request) + ` } + `} + `export class MyActorClass { + ` constructor(state, env) { + ` this.storage = state.storage; + ` this.id = state.id; + ` } + ` async fetch(request) { + ` let count = (await this.storage.get("foo")) || 0; + ` this.storage.put("foo", count + 1); + ` return new Response(this.id + ": " + request.url + " " + count); + ` } + `} + ) + ], + bindings = [(name = "ns", durableObjectNamespace = "MyActorClass")], + durableObjectNamespaces = [ + ( className = "MyActorClass", + uniqueKey = "mykey", + ) + ], + durableObjectStorage = (inMemory = void) + ) + ), + ], + sockets = [ + ( name = "main", + address = "test-addr", + service = "hello" + ) + ] + ))"_kj); + + test.start(); + auto conn = test.connect("test-addr"); + conn.httpGet200("/", + "59002eb8cf872e541722977a258a12d6a93bbe8192b502e1c0cb250aa91af234: http://foo/ 0"); + conn.httpGet200("/", + "59002eb8cf872e541722977a258a12d6a93bbe8192b502e1c0cb250aa91af234: http://foo/ 1"); + conn.httpGet200("/", + "59002eb8cf872e541722977a258a12d6a93bbe8192b502e1c0cb250aa91af234: http://foo/ 2"); + conn.httpGet200("/bar", + "02b496f65dd35cbac90e3e72dc5a398ee93926ea4a3821e26677082d2e6f9b79: http://foo/bar 0"); + conn.httpGet200("/bar", + "02b496f65dd35cbac90e3e72dc5a398ee93926ea4a3821e26677082d2e6f9b79: http://foo/bar 1"); + conn.httpGet200("/", + "59002eb8cf872e541722977a258a12d6a93bbe8192b502e1c0cb250aa91af234: http://foo/ 3"); + conn.httpGet200("/bar", + "02b496f65dd35cbac90e3e72dc5a398ee93926ea4a3821e26677082d2e6f9b79: http://foo/bar 2"); +} + +KJ_TEST("Server: Ephemeral Objects") { + TestServer test(R"(( + services = [ + ( name = "hello", + worker = ( + compatibilityDate = "2022-08-17", + modules = [ + ( name = "main.js", + esModule = + `export default { + ` async fetch(request, env) { + ` let actor = env.ns.get(request.url) + ` return await actor.fetch(request) + ` } + `} + `export class MyActorClass { + ` constructor(state, env) { + ` if (state.storage) throw new Error("storage shouldn't be present"); + ` this.id = state.id; + ` this.count = 0; + ` } + ` async fetch(request) { + ` return new Response(this.id + ": " + request.url + " " + this.count++); + ` } + `} + ) + ], + bindings = [(name = "ns", durableObjectNamespace = "MyActorClass")], + durableObjectNamespaces = [ + ( className = "MyActorClass", + ephemeralLocal = void, + ) + ], + durableObjectStorage = (inMemory = void) + ) + ), + ], + sockets = [ + ( name = "main", + address = "test-addr", + service = "hello" + ) + ] + ))"_kj); + + test.start(); + auto conn = test.connect("test-addr"); + conn.httpGet200("/", + "http://foo/: http://foo/ 0"); + conn.httpGet200("/", + "http://foo/: http://foo/ 1"); + conn.httpGet200("/", + "http://foo/: http://foo/ 2"); + conn.httpGet200("/bar", + "http://foo/bar: http://foo/bar 0"); + conn.httpGet200("/bar", + "http://foo/bar: http://foo/bar 1"); + conn.httpGet200("/", + "http://foo/: http://foo/ 3"); + conn.httpGet200("/bar", + "http://foo/bar: http://foo/bar 2"); +} + // ======================================================================================= // Test HttpOptions on receive diff --git a/src/workerd/server/server.c++ b/src/workerd/server/server.c++ index 96aeed6cc84..d9284f96185 100644 --- a/src/workerd/server/server.c++ +++ b/src/workerd/server/server.c++ @@ -19,6 +19,7 @@ #include #include #include +#include #include "workerd-api.h" namespace workerd::server { @@ -103,6 +104,53 @@ static kj::Vector escapeJsonString(kj::StringPtr text) { return escaped; } +class EmptyReadOnlyActorStorageImpl final: public rpc::ActorStorage::Stage::Server { + // An ActorStorage implementation which will always respond to reads as if the state is empty, + // and will fail any writes. +public: + kj::Promise get(GetContext context) override { + return kj::READY_NOW; + } + kj::Promise getMultiple(GetMultipleContext context) override { + return context.getParams().getStream().endRequest(capnp::MessageSize {2, 0}) + .send().ignoreResult(); + } + kj::Promise list(ListContext context) override { + return context.getParams().getStream().endRequest(capnp::MessageSize {2, 0}) + .send().ignoreResult(); + } + kj::Promise getAlarm(GetAlarmContext context) override { + return kj::READY_NOW; + } + kj::Promise txn(TxnContext context) override { + auto results = context.getResults(capnp::MessageSize {2, 1}); + results.setTransaction(kj::heap()); + return kj::READY_NOW; + } + +private: + class TransactionImpl final: public rpc::ActorStorage::Stage::Transaction::Server { + protected: + kj::Promise get(GetContext context) override { + return kj::READY_NOW; + } + kj::Promise getMultiple(GetMultipleContext context) override { + return context.getParams().getStream().endRequest(capnp::MessageSize {2, 0}) + .send().ignoreResult(); + } + kj::Promise list(ListContext context) override { + return context.getParams().getStream().endRequest(capnp::MessageSize {2, 0}) + .send().ignoreResult(); + } + kj::Promise getAlarm(GetAlarmContext context) override { + return kj::READY_NOW; + } + kj::Promise commit(CommitContext context) override { + return kj::READY_NOW; + } + }; +}; + } // namespace // ======================================================================================= @@ -871,13 +919,19 @@ class Server::WorkerService final: public Service, private kj::TaskSet::ErrorHan private IoChannelFactory, private TimerChannel, private LimitEnforcer { public: + class ActorNamespace; + struct LinkedIoChannels { + // I/O channels, delivered when link() is called. kj::Array subrequest; + kj::Array> actor; // null = configuration error }; + using LinkCallback = kj::Function; WorkerService(ThreadContext& threadContext, kj::Own worker, kj::HashSet namedEntrypointsParam, - kj::Function linkCallback) + const kj::HashMap& actorClasses, + LinkCallback linkCallback) : threadContext(threadContext), worker(kj::mv(worker)), ioChannels(kj::mv(linkCallback)), waitUntilTasks(*this) { @@ -886,6 +940,12 @@ public: kj::StringPtr epPtr = ep; namedEntrypoints.insert(kj::mv(ep), EntrypointService(*this, epPtr)); } + + actorNamespaces.reserve(actorClasses.size()); + for (auto& entry: actorClasses) { + ActorNamespace ns(*this, entry.key, entry.value); + actorNamespaces.insert(entry.key, kj::mv(ns)); + } } kj::Maybe getEntrypoint(kj::StringPtr name) { @@ -893,10 +953,13 @@ public: } void link() override { - kj::Function callback = - kj::mv(KJ_REQUIRE_NONNULL(ioChannels.tryGet>(), - "already called link()")); - ioChannels = callback(); + LinkCallback callback = kj::mv(KJ_REQUIRE_NONNULL( + ioChannels.tryGet(), "already called link()")); + ioChannels = callback(*this); + } + + kj::Maybe getActorNamespace(kj::StringPtr name) { + return actorNamespaces.find(name); } kj::Own startRequest( @@ -905,12 +968,13 @@ public: } kj::Own startRequest( - IoChannelFactory::SubrequestMetadata metadata, kj::Maybe entrypointName) { + IoChannelFactory::SubrequestMetadata metadata, kj::Maybe entrypointName, + kj::Maybe> actor = nullptr) { return WorkerEntrypoint::construct( threadContext, kj::atomicAddRef(*worker), entrypointName, - nullptr, // actor -- TODO(launch): support preview actors + kj::mv(actor), kj::Own(this, kj::NullDisposer::instance), {}, // ioContextDependency kj::Own(this, kj::NullDisposer::instance), @@ -922,6 +986,69 @@ public: kj::mv(metadata.cfBlobJson)); } + class ActorNamespace { + public: + ActorNamespace(WorkerService& service, kj::StringPtr className, const ActorConfig& config) + : service(service), className(className), config(config) {} + + const ActorConfig& getConfig() { return config; } + + kj::Own getActor(Worker::Actor::Id id) { + // `getActor()` is often called with the calling isolate's lock held. We need to drop that + // lock and take a lock on the target isolate before constructing the actor. Even if these + // are the same isolate (as is commonly the case), we really don't want to do this stuff + // synchronously, so this has the effect of pushing off to a later turn of the event loop. + auto promise = service.worker->takeAsyncLockWithoutRequest(nullptr) + .then([this, id = kj::mv(id)](Worker::AsyncLock lock) mutable -> kj::Own { + kj::String idStr; + KJ_SWITCH_ONEOF(id) { + KJ_CASE_ONEOF(obj, kj::Own) { + KJ_REQUIRE(config.is()); + idStr = obj->toString(); + } + KJ_CASE_ONEOF(str, kj::String) { + KJ_REQUIRE(config.is()); + idStr = kj::str(str); + } + } + + auto actor = kj::addRef(*actors.findOrCreate(idStr, [&]() { + auto persistent = config.tryGet().map([&](const Durable& d) { + // TODO(someday): Implement some sort of actual durable storage. For now we force + // `ActorCache` into `neverFlush` mode so that all state is kept in-memory. + return rpc::ActorStorage::Stage::Client(kj::heap()); + }); + + auto makeStorage = [](jsg::Lock& js, const Worker::ApiIsolate& apiIsolate, + ActorCache& actorCache) + -> jsg::Ref { + return jsg::alloc(IoContext::current().addObject(actorCache)); + }; + + TimerChannel& timerChannel = service; + auto newActor = kj::refcounted( + *service.worker, kj::mv(id), true, kj::mv(persistent), + className, kj::mv(makeStorage), lock, + timerChannel, kj::refcounted()); + + return kj::HashMap>::Entry { + kj::mv(idStr), kj::mv(newActor) + }; + })); + + return kj::heap(service, className, kj::mv(actor)); + }); + + return kj::heap(service.waitUntilTasks, kj::mv(promise)); + } + + private: + WorkerService& service; + kj::StringPtr className; + const ActorConfig& config; + kj::HashMap> actors; + }; + private: class EntrypointService final: public Service { public: @@ -941,9 +1068,52 @@ private: ThreadContext& threadContext; kj::Own worker; kj::HashMap namedEntrypoints; - kj::OneOf, LinkedIoChannels> ioChannels; + kj::HashMap actorNamespaces; + kj::OneOf ioChannels; kj::TaskSet waitUntilTasks; + class ActorChannelImpl final: public IoChannelFactory::ActorChannel { + public: + ActorChannelImpl(WorkerService& service, kj::StringPtr className, kj::Own actor) + : service(service), className(className), actor(kj::mv(actor)) {} + + kj::Own startRequest( + IoChannelFactory::SubrequestMetadata metadata) override { + return service.startRequest(kj::mv(metadata), className, kj::addRef(*actor)); + } + + private: + WorkerService& service; + kj::StringPtr className; + kj::Own actor; + }; + + class PromisedActorChannel final: public IoChannelFactory::ActorChannel { + public: + PromisedActorChannel(kj::TaskSet& waitUntilTasks, kj::Promise> promise) + : waitUntilTasks(waitUntilTasks), + promise(promise.then([this](kj::Own result) { + channel = kj::mv(result); + }).fork()) {} + + kj::Own startRequest( + IoChannelFactory::SubrequestMetadata metadata) override { + KJ_IF_MAYBE(c, channel) { + return c->get()->startRequest(kj::mv(metadata)); + } else { + return newPromisedWorkerInterface(waitUntilTasks, + promise.addBranch().then([this, metadata = kj::mv(metadata)]() mutable { + return KJ_ASSERT_NONNULL(channel)->startRequest(kj::mv(metadata)); + })); + } + } + + private: + kj::TaskSet& waitUntilTasks; + kj::ForkedPromise promise; + kj::Maybe> channel; + }; + // --------------------------------------------------------------------------- // implements kj::TaskSet::ErrorHandler @@ -981,12 +1151,25 @@ private: } kj::Own getGlobalActor(uint channel, const ActorIdFactory::ActorId& id) override { - // TODO(launch): actors - KJ_FAIL_REQUIRE("no actor channels"); + auto& channels = KJ_REQUIRE_NONNULL(ioChannels.tryGet(), + "link() has not been called"); + + KJ_REQUIRE(channel < channels.actor.size(), "invalid actor channel number"); + auto& ns = JSG_REQUIRE_NONNULL(channels.actor[channel], Error, + "Actor namespace configuration was invalid."); + KJ_REQUIRE(ns.getConfig().is()); // should have been verified earlier + return ns.getActor(id.clone()); } kj::Own getColoLocalActor(uint channel, kj::String id) override { - KJ_FAIL_REQUIRE("no actor channels"); + auto& channels = KJ_REQUIRE_NONNULL(ioChannels.tryGet(), + "link() has not been called"); + + KJ_REQUIRE(channel < channels.actor.size(), "invalid actor channel number"); + auto& ns = JSG_REQUIRE_NONNULL(channels.actor[channel], Error, + "Actor namespace configuration was invalid."); + KJ_REQUIRE(ns.getConfig().is()); // should have been verified earlier + return ns.getActor(kj::str(id)); } // --------------------------------------------------------------------------- @@ -1027,6 +1210,8 @@ private: }; kj::Own Server::makeWorker(kj::StringPtr name, config::Worker::Reader conf) { + auto& localActorConfigs = KJ_ASSERT_NONNULL(actorConfigs.find(name)); + struct ErrorReporter: public Worker::ValidationErrorReporter { ErrorReporter(Server& server, kj::StringPtr name): server(server), name(name) {} @@ -1075,7 +1260,11 @@ kj::Own Server::makeWorker(kj::StringPtr name, config::Worker:: .hardLimit = 128ull << 20, .staleTimeout = 30 * kj::SECONDS, .dirtyKeySoftLimit = 64, - .maxKeysPerRpc = 128 + .maxKeysPerRpc = 128, + + // For now, we use `neverFlush` to implement in-memory-only actors. + // See WorkerService::getActor(). + .neverFlush = true }; } kj::Own enterStartupJs( @@ -1119,6 +1308,12 @@ kj::Own Server::makeWorker(kj::StringPtr name, config::Worker:: }; kj::Vector subrequestChannels; + struct FutureActorChannel { + config::Worker::Binding::DurableObjectNamespaceDesignator::Reader designator; + kj::String errorContext; + }; + kj::Vector actorChannels; + auto confBindings = conf.getBindings(); using Global = WorkerdApiIsolate::Global; kj::Vector globals(confBindings.size()); @@ -1272,8 +1467,54 @@ kj::Own Server::makeWorker(kj::StringPtr name, config::Worker:: continue; } - case config::Worker::Binding::DURABLE_OBJECT_NAMESPACE: - KJ_UNIMPLEMENTED("TODO(launch): durable object namespaces"); + case config::Worker::Binding::DURABLE_OBJECT_NAMESPACE: { + auto actorBinding = binding.getDurableObjectNamespace(); + const ActorConfig* actorConfig; + if (actorBinding.hasServiceName()) { + auto& svcMap = KJ_UNWRAP_OR(actorConfigs.find(actorBinding.getServiceName()), { + errorReporter.addError(kj::str( + errorContext, " refers to a service \"", actorBinding.getServiceName(), + "\", but no such service is defined.")); + continue; + }); + + actorConfig = &KJ_UNWRAP_OR(svcMap.find(actorBinding.getClassName()), { + errorReporter.addError(kj::str( + errorContext, " refers to a Durable Object namespace named \"", + actorBinding.getClassName(), "\" in service \"", actorBinding.getServiceName(), + "\", but no such Durable Object namespace is defined by that service.")); + continue; + }); + } else { + actorConfig = &KJ_UNWRAP_OR(localActorConfigs.find(actorBinding.getClassName()), { + errorReporter.addError(kj::str( + errorContext, " refers to a Durable Object namespace named \"", + actorBinding.getClassName(), "\", but no such Durable Object namespace is defined " + "by this Worker.")); + continue; + }); + } + + KJ_SWITCH_ONEOF(*actorConfig) { + KJ_CASE_ONEOF(durable, Durable) { + addGlobal(Global::DurableActorNamespace { + .actorChannel = (uint)actorChannels.size(), + .uniqueKey = durable.uniqueKey + }); + } + KJ_CASE_ONEOF(_, Ephemeral) { + addGlobal(Global::EphemeralActorNamespace { + .actorChannel = (uint)actorChannels.size(), + }); + } + } + + actorChannels.add(FutureActorChannel { + actorBinding, + kj::mv(errorContext) + }); + continue; + } case config::Worker::Binding::KV_NAMESPACE: { addGlobal(Global::KvNamespace { @@ -1337,7 +1578,8 @@ kj::Own Server::makeWorker(kj::StringPtr name, config::Worker:: } auto linkCallback = - [this, name, conf, subrequestChannels = kj::mv(subrequestChannels)]() mutable { + [this, name, conf, subrequestChannels = kj::mv(subrequestChannels), + actorChannels = kj::mv(actorChannels)](WorkerService& workerService) mutable { auto services = kj::heapArrayBuilder(subrequestChannels.size() + IoContext::SPECIAL_SUBREQUEST_CHANNEL_COUNT); @@ -1354,13 +1596,32 @@ kj::Own Server::makeWorker(kj::StringPtr name, config::Worker:: services.add(&lookupService(channel.designator, kj::mv(channel.errorContext))); } + auto actors = KJ_MAP(channel, actorChannels) -> kj::Maybe { + WorkerService* targetService = &workerService; + if (channel.designator.hasServiceName()) { + auto& svc = KJ_UNWRAP_OR(this->services.find(channel.designator.getServiceName()), { + // error was reported earlier + return nullptr; + }); + targetService = dynamic_cast(svc.get()); + if (targetService == nullptr) { + // error was reported earlier + return nullptr; + } + } + + // (If getActorNamespace() returns null, an error was reported earlier.) + return targetService->getActorNamespace(channel.designator.getClassName()); + }; + return WorkerService::LinkedIoChannels { - .subrequest = services.finish() + .subrequest = services.finish(), + .actor = kj::mv(actors) }; }; return kj::heap(globalContext->threadContext, kj::mv(worker), - kj::mv(errorReporter.namedEntrypoints), + kj::mv(errorReporter.namedEntrypoints), localActorConfigs, kj::mv(linkCallback)); } @@ -1608,6 +1869,55 @@ kj::Promise Server::run(jsg::V8System& v8System, config::Config::Reader co // --------------------------------------------------------------------------- // Configure services + // First pass: Extract actor namespace configs. + for (auto serviceConf: config.getServices()) { + kj::StringPtr name = serviceConf.getName(); + kj::HashMap serviceActorConfigs; + + if (serviceConf.isWorker()) { + auto workerConf = serviceConf.getWorker(); + bool hadDurable = false; + for (auto ns: workerConf.getDurableObjectNamespaces()) { + switch (ns.which()) { + case config::Worker::DurableObjectNamespace::UNIQUE_KEY: + hadDurable = true; + serviceActorConfigs.insert(kj::str(ns.getClassName()), + Durable { kj::str(ns.getUniqueKey()) }); + continue; + case config::Worker::DurableObjectNamespace::EPHEMERAL_LOCAL: + serviceActorConfigs.insert(kj::str(ns.getClassName()), Ephemeral {}); + continue; + } + reportConfigError(kj::str( + "Encountered unknown DurableObjectNamespace type in service \"", name, + "\", class \"", ns.getClassName(), "\". Was the config compiled with a newer version " + "of the schema?")); + } + + switch (workerConf.getDurableObjectStorage().which()) { + case config::Worker::DurableObjectStorage::NONE: + if (hadDurable) { + reportConfigError(kj::str( + "Worker service \"", name, "\" implements durable object classes but has " + "`durableObjectStorage` set to `none`.")); + } + goto validDurableObjectStorage; + case config::Worker::DurableObjectStorage::IN_MEMORY: + goto validDurableObjectStorage; + } + reportConfigError(kj::str( + "Encountered unknown durableObjectStorage type in service \"", name, + "\". Was the config compiled with a newer version of the schema?")); + validDurableObjectStorage: + ; + } + + actorConfigs.upsert(kj::str(name), kj::mv(serviceActorConfigs), [&](auto&&...) { + reportConfigError(kj::str("Config defines multiple services named \"", name, "\".")); + }); + } + + // Second pass: Build services. for (auto serviceConf: config.getServices()) { kj::StringPtr name = serviceConf.getName(); auto service = makeService(serviceConf, headerTableBuilder); @@ -1637,7 +1947,7 @@ kj::Promise Server::run(jsg::V8System& v8System, config::Config::Reader co }; }); - // Now that all services are constructed, we can tell them to cross-link to each other. + // Third pass: Cross-link services. for (auto& service: services) { service.value->link(); } diff --git a/src/workerd/server/server.h b/src/workerd/server/server.h index e8a3e37eaaf..a387843f921 100644 --- a/src/workerd/server/server.h +++ b/src/workerd/server/server.h @@ -72,6 +72,15 @@ class Server: private kj::TaskSet::ErrorHandler { class Service; kj::Own invalidConfigServiceSingleton; + struct Durable { kj::String uniqueKey; }; + struct Ephemeral {}; + using ActorConfig = kj::OneOf; + + kj::HashMap> actorConfigs; + // Information about all known actor namespaces. Maps serviceName -> className -> config. + // This needs to be populated in advance of constructing any services, in order to be able to + // correctly construct dependent services. + kj::HashMap> services; kj::Own> fatalFulfiller; diff --git a/src/workerd/server/workerd-api.c++ b/src/workerd/server/workerd-api.c++ index d55554e9ebd..4c3d4fab2ac 100644 --- a/src/workerd/server/workerd-api.c++ +++ b/src/workerd/server/workerd-api.c++ @@ -17,6 +17,9 @@ #include #include #include +#include +#include +#include namespace workerd::server { @@ -307,6 +310,110 @@ kj::Own WorkerdApiIsolate::compileModules( return modules; } +class ActorIdFactoryImpl final: public ActorIdFactory { +public: + ActorIdFactoryImpl(kj::StringPtr uniqueKey) { + KJ_ASSERT(SHA256(uniqueKey.asBytes().begin(), uniqueKey.size(), key) == key); + } + + class ActorIdImpl final: public ActorId { + public: + ActorIdImpl(const kj::byte idParam[SHA256_DIGEST_LENGTH], kj::Maybe name) + : name(kj::mv(name)) { + memcpy(id, idParam, sizeof(id)); + } + + kj::String toString() const override { + return kj::encodeHex(kj::ArrayPtr(id)); + } + kj::Maybe getName() const override { + return name; + } + bool equals(const ActorId& other) const override { + return memcmp(id, kj::downcast(other).id, sizeof(id)) == 0; + } + kj::Own clone() const override { + return kj::heap(id, name.map([](kj::StringPtr str) { return kj::str(str); })); + } + + private: + kj::byte id[SHA256_DIGEST_LENGTH]; + kj::Maybe name; + }; + + kj::Own newUniqueId(kj::Maybe jurisdiction) override { + JSG_REQUIRE(jurisdiction == nullptr, Error, + "Jurisdiction restrictions are not implemented in workerd."); + + // We want to randomly-generate the first 16 bytes, then HMAC those to produce the latter + // 16 bytes. But the HMAC will produce 32 bytes, so we're only taking a prefix of it. We'll + // allocate a single array big enough to output the HMAC as a suffix, which will then get + // truncated. + kj::byte id[BASE_LENGTH + SHA256_DIGEST_LENGTH]; + + if (isPredictableModeForTest()) { + memcpy(id, &counter, sizeof(counter)); + memset(id + sizeof(counter), 0, BASE_LENGTH - sizeof(counter)); + ++counter; + } else { + KJ_ASSERT(RAND_bytes(id, BASE_LENGTH) == 1); + } + + computeMac(id); + return kj::heap(id, nullptr); + } + + kj::Own idFromName(kj::String name) override { + kj::byte id[BASE_LENGTH + SHA256_DIGEST_LENGTH]; + + // Compute the first half of the ID by HMACing the name itself. We're using HMAC as a keyed + // hash here, not actually for authentication, but it works. + uint len = SHA256_DIGEST_LENGTH; + KJ_ASSERT(HMAC(EVP_sha256(), key, sizeof(key), name.asBytes().begin(), name.size(), id, &len) + == id); + KJ_ASSERT(len == SHA256_DIGEST_LENGTH); + + computeMac(id); + return kj::heap(id, kj::mv(name)); + } + + kj::Own idFromString(kj::String str) override { + auto decoded = kj::decodeHex(str); + JSG_REQUIRE(str.size() == SHA256_DIGEST_LENGTH * 2 && !decoded.hadErrors && + decoded.size() == SHA256_DIGEST_LENGTH, + TypeError, "Invalid Durable Object ID: must be 64 hex digits"); + + kj::byte id[BASE_LENGTH + SHA256_DIGEST_LENGTH]; + memcpy(id, decoded.begin(), BASE_LENGTH); + computeMac(id); + + // Verify that the computed mac matches the input. + JSG_REQUIRE(memcmp(id + BASE_LENGTH, decoded.begin() + BASE_LENGTH, + decoded.size() - BASE_LENGTH) == 0, + TypeError, "Durable Object ID is not valid for this namespace."); + + return kj::heap(id, nullptr); + } + +private: + kj::byte key[SHA256_DIGEST_LENGTH]; + + uint64_t counter = 0; // only used in predictable mode + + static constexpr size_t BASE_LENGTH = SHA256_DIGEST_LENGTH / 2; + void computeMac(kj::byte id[BASE_LENGTH + SHA256_DIGEST_LENGTH]) { + // Given that the first `BASE_LENGTH` bytes of `id` are filled in, compute the second half + // of the ID by HMACing the first half. The id must be in a buffer large enough to store the + // first half of the ID plus a full HMAC, even though only a prefix of the HMAC becomes part + // of the final ID. + + kj::byte* hmacOut = id + BASE_LENGTH; + uint len = SHA256_DIGEST_LENGTH; + KJ_ASSERT(HMAC(EVP_sha256(), key, sizeof(key), id, BASE_LENGTH, hmacOut, &len) == hmacOut); + KJ_ASSERT(len == SHA256_DIGEST_LENGTH); + } +}; + void WorkerdApiIsolate::compileGlobals( jsg::Lock& lockParam, kj::ArrayPtr globals, v8::Local target, @@ -378,6 +485,15 @@ void WorkerdApiIsolate::compileGlobals( value = lock.wrap(context, kj::mv(importedKey)); } + KJ_CASE_ONEOF(ns, Global::EphemeralActorNamespace) { + value = lock.wrap(context, jsg::alloc(ns.actorChannel)); + } + + KJ_CASE_ONEOF(ns, Global::DurableActorNamespace) { + value = lock.wrap(context, jsg::alloc(ns.actorChannel, + kj::heap(ns.uniqueKey))); + } + KJ_CASE_ONEOF(text, kj::String) { value = lock.wrap(context, kj::mv(text)); } @@ -422,6 +538,12 @@ WorkerdApiIsolate::Global WorkerdApiIsolate::Global::clone() const { KJ_CASE_ONEOF(key, Global::CryptoKey) { result.value = key.clone(); } + KJ_CASE_ONEOF(ns, Global::EphemeralActorNamespace) { + result.value = ns.clone(); + } + KJ_CASE_ONEOF(ns, Global::DurableActorNamespace) { + result.value = ns.clone(); + } KJ_CASE_ONEOF(text, kj::String) { result.value = kj::str(text); } diff --git a/src/workerd/server/workerd-api.h b/src/workerd/server/workerd-api.h index 36b52c59e81..2e32fecbacf 100644 --- a/src/workerd/server/workerd-api.h +++ b/src/workerd/server/workerd-api.h @@ -100,9 +100,24 @@ class WorkerdApiIsolate final: public Worker::ApiIsolate { }; } }; + struct EphemeralActorNamespace { + uint actorChannel; + + EphemeralActorNamespace clone() const { + return *this; + } + }; + struct DurableActorNamespace { + uint actorChannel; + kj::StringPtr uniqueKey; + + DurableActorNamespace clone() const { + return *this; + } + }; kj::String name; - kj::OneOf> value; + kj::OneOf> value; Global clone() const; }; From 0bdea57f1dfae32e8be67fc57b35f6d065db84a3 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Fri, 23 Sep 2022 16:32:57 -0500 Subject: [PATCH 5/5] Update README to indicate non-durable Durable Objects are supported. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ea3f881e268..5e799af5faa 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ As of this writing, some major features are missing which we intend to fix short * **Wrangler/Miniflare integration** is in progress. The [Wrangler CLI tool](https://developers.cloudflare.com/workers/wrangler/) and [Miniflare](https://miniflare.dev/) will soon support local testing using `workerd` (replacing the previous simulated environment on top of Node). Wrangler should also support generating `workerd` configuration directly from a Wrangler project. * **Multi-threading** is not implemented. `workerd` runs in a single-threaded event loop. For now, to utilize multiple cores, we suggest running multiple instances of `workerd` and balancing load across them. We will likely add some built-in functionality for this in the near future. * **Performance tuning** has not been done yet, and there is low-hanging fruit here. `workerd` performs decently as-is, but not spectacularly. Experiments suggest we can roughly double performance on a "hello world" load test with some tuning of compiler optimization flags and memory allocators. -* **Durable Objects** are not supported yet. We will add support for in-memory Durable Objects shortly, to allow for local testing of DO-based applications. Durable Objects that are actually durable, or distributed across multiple machines, are a longer-term project. Cloudflare's internal implementation of this is heavily tied to the specifics of Cloudflare's network, so a new implementation needs to be developed for public consumption. +* **Durable Objects** are currently supported only in a mode that uses in-memory storage -- i.e., not actually "durable". This is useful for local testing of DO-based apps, but not for production. Durable Objects that are actually durable, or distributed across multiple machines, are a longer-term project. Cloudflare's internal implementation of this is heavily tied to the specifics of Cloudflare's network, so a new implementation needs to be developed for public consumption. * **Cache API** emulation is not implemented yet. * **Cron trigger** emulation is not supported yet. We need to figure out how, exactly, this should work in the first place. Typically if you have a cluster of machines, you only want a cron event to run on one of the machines, so some sort of coordination or external driver is needed. * **Parameterized workers** are not implemented yet. This is a new feature specified in the config schema, which doesn't have any precedent on Cloudflare.