
Gunicorn + Flask App RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method #3176

Open
danerlt opened this issue Mar 25, 2024 · 0 comments

danerlt commented Mar 25, 2024

I have a Flask web app that serves text embeddings. Here is `main.py`:

from sentence_transformers import SentenceTransformer
from flask import Flask, request, jsonify
import torch
from pathlib import Path

app = Flask(__name__)

def infer_torch_device():
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        return "cuda"
    return "cpu"

current_path = Path(__file__).parent

model_name = "m3e-base"

model_path = "/data/models/m3e-base"

device = infer_torch_device()

m3e = SentenceTransformer(model_path, device=device)

@app.route('/embed', methods=['POST'])
def embed():
    data = request.get_json()
    query_list = data.get("input", None)
    embeddings = m3e.encode(query_list)
    data = []
    for i, emb in enumerate(embeddings):
        item = {
            "object": "embedding",
            "embedding": emb.astype(float).tolist(),
            "index": i
        }
        data.append(item)
    result = {
        'object': "list",
        "data": data,
        "model": model_name,
        "usage": {
            "prompt_tokens": 11,
            "total_tokens": 11
        }
    }
    return jsonify(result)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
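In the code above, the model (and with it any CUDA state) is created at import time. One pattern that should sidestep the fork restriction, assuming the CUDA state is being created before the workers are forked, is to build the model lazily, on first use inside each worker process. The sketch below uses only the standard library. `PerProcessLazy` is a hypothetical helper name, and `factory` stands in for the `SentenceTransformer(...)` call. It also rebuilds the value if it notices it is running under a different PID, i.e. in a forked child.

```python
import os
import threading
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")


class PerProcessLazy(Generic[T]):
    """Build a value on first use in each process (hypothetical helper).

    `factory` stands in for something like
    `lambda: SentenceTransformer(model_path, device=infer_torch_device())`.
    """

    def __init__(self, factory: Callable[[], T]) -> None:
        self._factory = factory
        self._lock = threading.Lock()
        self._pid: Optional[int] = None  # PID of the process that built _value
        self._value: Optional[T] = None

    def get(self) -> T:
        pid = os.getpid()
        if self._pid != pid:  # first call in this process, or we were forked
            with self._lock:
                if self._pid != pid:  # re-check once we hold the lock
                    self._value = self._factory()
                    self._pid = pid
        return self._value  # type: ignore[return-value]
```

In `main.py` this would look roughly like `m3e = PerProcessLazy(lambda: SentenceTransformer(model_path, device=infer_torch_device()))` at module level and `m3e.get().encode(query_list)` inside `embed()`. Moving `infer_torch_device()` into the factory is deliberate: according to the PyTorch documentation, `torch.cuda.is_available()` can itself "poison" a later fork unless the `PYTORCH_NVML_BASED_CUDA_CHECK=1` environment variable is set.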

I start Gunicorn with the following command:

gunicorn -w 4 -b 0.0.0.0:5000 main:app

When I call the endpoint, I get the error `RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method`.

When I set the `-w` parameter to 1, the app runs normally without any error.

To implement the same functionality with FastAPI + uvicorn instead of Flask, I wrote `main_fastapi.py`:

import logging
from pathlib import Path

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, PlainTextResponse
from sentence_transformers import SentenceTransformer

logger = logging.getLogger("api")

app = FastAPI()

def infer_torch_device():
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        return "cuda"
    return "cpu"

current_path = Path(__file__).parent

model_name = "m3e-base"

model_path = "/data/models/m3e-base"

device = infer_torch_device()

m3e = SentenceTransformer(model_path, device=device)

@app.get("/health")
def health():
    return PlainTextResponse("ok")


@app.post("/embed")
async def embedding(request: Request):
    data = await request.json()
    query_list = data.get("input", None)
    embeddings = m3e.encode(query_list)
    data = []
    for i, emb in enumerate(embeddings):
        item = {
            "object": "embedding",
            "embedding": emb.astype(float).tolist(),
            "index": i
        }
        data.append(item)
    result = {
        'object': "list",
        "data": data,
        "model": model_name,
        "usage": {
            "prompt_tokens": 11,
            "total_tokens": 11
        }
    }
    return JSONResponse(result)


def main():
    logger.info("Starting service")
    uvicorn.run("main_fastapi:app", host="0.0.0.0", port=5000, loop="uvloop", log_level="info")


if __name__ == "__main__":
    main()

I start uvicorn with the following command:

uvicorn main_fastapi:app --host 0.0.0.0 --port 5000 --loop uvloop --workers 4

With uvicorn, multiple workers run normally without any errors.

After some investigation, I found that Gunicorn creates worker processes with `os.fork()`. The relevant code is in `gunicorn/arbiter.py`:

    def spawn_worker(self):
        self.worker_age += 1
        worker = self.worker_class(self.worker_age, self.pid, self.LISTENERS,
                                   self.app, self.timeout / 2.0,
                                   self.cfg, self.log)
        self.cfg.pre_fork(self, worker)
        pid = os.fork()
        if pid != 0:
            worker.pid = pid
            self.WORKERS[pid] = worker
            return pid

        # Do not inherit the temporary files of other workers
        for sibling in self.WORKERS.values():
            sibling.tmp.close()

        # Process Child
        worker.pid = os.getpid()
        try:
            util._setproctitle("worker [%s]" % self.proc_name)
            self.log.info("Booting worker with pid: %s", worker.pid)
            self.cfg.post_fork(self, worker)
            worker.init_process()
            sys.exit(0)
        except SystemExit:
            raise
        except AppImportError as e:
            self.log.debug("Exception while loading the application",
                           exc_info=True)
            print("%s" % e, file=sys.stderr)
            sys.stderr.flush()
            sys.exit(self.APP_LOAD_ERROR)
        except Exception:
            self.log.exception("Exception in worker process")
            if not worker.booted:
                sys.exit(self.WORKER_BOOT_ERROR)
            sys.exit(-1)
        finally:
            self.log.info("Worker exiting (pid: %s)", worker.pid)
            try:
                worker.tmp.close()
                self.cfg.worker_exit(self, worker)
            except Exception:
                self.log.warning("Exception during worker exit:\n%s",
                                 traceback.format_exc())
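To illustrate why this matters, here is a minimal standard-library sketch of what `os.fork()` does to module-level state: the child starts with a copy of whatever the parent had already initialized, instead of re-running the imports. Anything CUDA-related created in the parent before the fork would be inherited the same way, which, as far as I understand, is the situation PyTorch's error guards against. `STATE`, `init_state` and `fork_and_report` are hypothetical names used only for illustration, and the code is POSIX-only.

```python
import os
from typing import Tuple

# Module-level state, standing in for a model created at import time.
STATE = {"initialized_in": None}


def init_state() -> None:
    """Record which process initialized the state (call before forking)."""
    STATE["initialized_in"] = os.getpid()


def fork_and_report() -> Tuple[int, int, int]:
    """Fork (like Arbiter.spawn_worker does) and report what the child sees.

    Returns (parent_pid, child_pid, pid stored in STATE as seen by the child).
    Requires init_state() to have been called first.
    """
    r, w = os.pipe()
    pid = os.fork()
    if pid == 0:  # child process: report the inherited value and exit
        os.close(r)
        os.write(w, str(STATE["initialized_in"]).encode())
        os._exit(0)  # skip the parent's cleanup and atexit handlers
    os.close(w)  # parent: close the write end so read() sees EOF
    with os.fdopen(r) as reader:
        seen = int(reader.read())
    os.waitpid(pid, 0)
    return os.getpid(), pid, seen
```

The child never called `init_state()`, yet it reports the parent's PID, because the forked process shares a copy of the parent's memory rather than starting fresh.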

Uvicorn, by contrast, creates its worker processes with the `spawn` start method. The relevant code is in `uvicorn/supervisors/multiprocess.py`:

class Multiprocess:
    def __init__(
        self,
        config: Config,
        target: Callable[[list[socket] | None], None],
        sockets: list[socket],
    ) -> None:
        self.config = config
        self.target = target
        self.sockets = sockets
        self.processes: list[SpawnProcess] = []
        self.should_exit = threading.Event()
        self.pid = os.getpid()

    def signal_handler(self, sig: int, frame: FrameType | None) -> None:
        """
        A signal handler that is registered with the parent process.
        """
        self.should_exit.set()

    def run(self) -> None:
        self.startup()
        self.should_exit.wait()
        self.shutdown()

    def startup(self) -> None:
        message = f"Started parent process [{str(self.pid)}]"
        color_message = "Started parent process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True))
        logger.info(message, extra={"color_message": color_message})

        for sig in HANDLED_SIGNALS:
            signal.signal(sig, self.signal_handler)

        for _idx in range(self.config.workers):
            process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets)
            process.start()
            self.processes.append(process)

    def shutdown(self) -> None:
        for process in self.processes:
            process.terminate()
            process.join()

        message = f"Stopping parent process [{str(self.pid)}]"
        color_message = "Stopping parent process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True))
        logger.info(message, extra={"color_message": color_message})

The `get_subprocess` function is defined in `uvicorn/_subprocess.py`:

import multiprocessing
from multiprocessing.context import SpawnProcess

multiprocessing.allow_connection_pickling()
spawn = multiprocessing.get_context("spawn")


def get_subprocess(
    config: Config,
    target: Callable[..., None],
    sockets: list[socket],
) -> SpawnProcess:
    """
    Called in the parent process, to instantiate a new child process instance.
    The child is not yet started at this point.

    * config - The Uvicorn configuration instance.
    * target - A callable that accepts a list of sockets. In practice this will
               be the `Server.run()` method.
    * sockets - A list of sockets to pass to the server. Sockets are bound once
                by the parent process, and then passed to the child processes.
    """
    # We pass across the stdin fileno, and reopen it in the child process.
    # This is required for some debugging environments.
    try:
        stdin_fileno = sys.stdin.fileno()
    # The `sys.stdin` can be `None`, see https://docs.python.org/3/library/sys.html#sys.__stdin__.
    except (AttributeError, OSError):
        stdin_fileno = None

    kwargs = {
        "config": config,
        "target": target,
        "sockets": sockets,
        "stdin_fileno": stdin_fileno,
    }

    return spawn.Process(target=subprocess_started, kwargs=kwargs)
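For comparison, a minimal sketch of choosing the start method through a `multiprocessing` context, which is the mechanism `get_subprocess` relies on. `worker_context` and `make_worker` are hypothetical helpers; as in uvicorn, the process is created but not started.

```python
import multiprocessing
from multiprocessing.context import BaseContext


def worker_context(start_method: str = "spawn") -> BaseContext:
    """Return a multiprocessing context for the given start method.

    Hypothetical helper; uvicorn does the equivalent with
    multiprocessing.get_context("spawn") at import time.
    """
    if start_method not in multiprocessing.get_all_start_methods():
        raise ValueError(f"start method {start_method!r} is not available here")
    return multiprocessing.get_context(start_method)


def make_worker(target, start_method: str = "spawn"):
    """Create (but do not start) a worker process, like get_subprocess().

    With "spawn", the child runs a fresh interpreter and re-imports the main
    module, so it does not inherit objects built in the parent at runtime.
    """
    return worker_context(start_method).Process(target=target)
```

Note that this only affects processes created through `multiprocessing`. Calling `multiprocessing.set_start_method("spawn")` inside the app would not change Gunicorn's behaviour, because `Arbiter.spawn_worker` calls `os.fork()` directly.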

Can Gunicorn support creating worker processes with the `spawn` start method?
