Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow client-provided ID when creating a task run #13276

Merged
merged 8 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2169,6 +2169,7 @@ async def create_task_run(
task: "TaskObject[P, R]",
flow_run_id: Optional[UUID],
dynamic_key: str,
id: Optional[UUID] = None,
name: Optional[str] = None,
extra_tags: Optional[Iterable[str]] = None,
state: Optional[prefect.states.State[R]] = None,
Expand All @@ -2192,6 +2193,8 @@ async def create_task_run(
task: The Task to run
flow_run_id: The flow run id with which to associate the task run
dynamic_key: A key unique to this particular run of a Task within the flow
id: An optional ID for the task run. If not provided, one will be generated
server-side.
name: An optional name for the task run
extra_tags: an optional list of extra tags to apply to the task run in
addition to `task.tags`
Expand All @@ -2208,6 +2211,7 @@ async def create_task_run(
state = prefect.states.Pending()

task_run_data = TaskRunCreate(
id=id,
name=name,
flow_run_id=flow_run_id,
task_key=task.task_key,
Expand All @@ -2222,10 +2226,9 @@ async def create_task_run(
state=state.to_state_create(),
task_inputs=task_inputs or {},
)
content = task_run_data.json(exclude={"id"} if id is None else None)

response = await self._client.post(
"/task_runs/", json=task_run_data.dict(json_compatible=True)
)
response = await self._client.post("/task_runs/", content=content)
return TaskRun.parse_obj(response.json())

async def read_task_run(self, task_run_id: UUID) -> TaskRun:
Expand Down Expand Up @@ -3810,6 +3813,7 @@ def create_task_run(
task: "TaskObject[P, R]",
flow_run_id: Optional[UUID],
dynamic_key: str,
id: Optional[UUID] = None,
name: Optional[str] = None,
extra_tags: Optional[Iterable[str]] = None,
state: Optional[prefect.states.State[R]] = None,
Expand All @@ -3833,6 +3837,8 @@ def create_task_run(
task: The Task to run
flow_run_id: The flow run id with which to associate the task run
dynamic_key: A key unique to this particular run of a Task within the flow
id: An optional ID for the task run. If not provided, one will be generated
server-side.
name: An optional name for the task run
extra_tags: an optional list of extra tags to apply to the task run in
addition to `task.tags`
Expand All @@ -3849,6 +3855,7 @@ def create_task_run(
state = prefect.states.Pending()

task_run_data = TaskRunCreate(
id=id,
name=name,
flow_run_id=flow_run_id,
task_key=task.task_key,
Expand All @@ -3864,9 +3871,9 @@ def create_task_run(
task_inputs=task_inputs or {},
)

response = self._client.post(
"/task_runs/", json=task_run_data.dict(json_compatible=True)
)
content = task_run_data.json(exclude={"id"} if id is None else None)

response = self._client.post("/task_runs/", content=content)
return TaskRun.parse_obj(response.json())

def read_task_run(self, task_run_id: UUID) -> TaskRun:
Expand Down
1 change: 1 addition & 0 deletions src/prefect/client/schemas/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ class FlowRunUpdate(ActionBaseModel):
class TaskRunCreate(ActionBaseModel):
"""Data used by the Prefect REST API to create a task run"""

id: Optional[UUID] = Field(None, description="The ID to assign to the task run")
# TaskRunCreate states must be provided as StateCreate objects
state: Optional[StateCreate] = Field(
default=None, description="The state of the task run to create"
Expand Down
20 changes: 16 additions & 4 deletions src/prefect/new_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Union,
cast,
)
from uuid import UUID

import pendulum
from typing_extensions import ParamSpec
Expand Down Expand Up @@ -275,7 +276,9 @@ def enter_run_context(self, client: Optional[SyncPrefectClient] = None):
self.logger = current_logger

@contextmanager
def start(self) -> Generator["TaskRunEngine", Any, Any]:
def start(
self, task_run_id: Optional[UUID] = None
) -> Generator["TaskRunEngine", Any, Any]:
"""
Enters a client context and creates a task run if needed.
"""
Expand All @@ -286,6 +289,7 @@ def start(self) -> Generator["TaskRunEngine", Any, Any]:
if not self.task_run:
self.task_run = run_sync(
self.task.create_run(
id=task_run_id,
client=client,
parameters=self.parameters,
flow_run_context=FlowRunContext.get(),
Expand Down Expand Up @@ -324,15 +328,20 @@ def is_pending(self) -> bool:

def run_task_sync(
task: Task[P, R],
task_run_id: Optional[UUID] = None,
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
return_type: Literal["state", "result"] = "result",
) -> Union[R, State, None]:
engine = TaskRunEngine[P, R](task=task, parameters=parameters, task_run=task_run)
engine = TaskRunEngine[P, R](
task=task,
parameters=parameters,
task_run=task_run,
)

# This is a context manager that keeps track of the run of the task run.
with engine.start() as run:
with engine.start(task_run_id=task_run_id) as run:
run.begin_run()

while run.is_running():
Expand Down Expand Up @@ -362,6 +371,7 @@ def run_task_sync(

async def run_task_async(
task: Task[P, Coroutine[Any, Any, R]],
task_run_id: Optional[UUID] = None,
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
Expand All @@ -375,7 +385,7 @@ async def run_task_async(
engine = TaskRunEngine[P, R](task=task, parameters=parameters, task_run=task_run)

# This is a context manager that keeps track of the run of the task run.
with engine.start() as run:
with engine.start(task_run_id=task_run_id) as run:
run.begin_run()

while run.is_running():
Expand Down Expand Up @@ -405,13 +415,15 @@ async def run_task_async(

def run_task(
task: Task[P, R],
task_run_id: Optional[UUID] = None,
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
return_type: Literal["state", "result"] = "result",
) -> Union[R, State, None]:
kwargs = dict(
task=task,
task_run_id=task_run_id,
task_run=task_run,
parameters=parameters,
wait_for=wait_for,
Expand Down
5 changes: 4 additions & 1 deletion src/prefect/server/api/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ async def create_task_run(
If no state is provided, the task run will be created in a PENDING state.
"""
# hydrate the input model into a full task run / state model
task_run = schemas.core.TaskRun(**task_run.dict())
task_run_dict = task_run.dict()
if not task_run_dict.get("id"):
task_run_dict.pop("id", None)
task_run = schemas.core.TaskRun(**task_run_dict)

if not task_run.state:
task_run.state = schemas.states.Pending()
Expand Down
4 changes: 4 additions & 0 deletions src/prefect/server/schemas/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,10 @@ def default_scheduled_start_time(cls, values):
class TaskRunCreate(ActionBaseModel):
"""Data used by the Prefect REST API to create a task run"""

id: Optional[UUID] = Field(
default=None,
description="The ID to assign to the task run. If not provided, a random UUID will be generated.",
)
# TaskRunCreate states must be provided as StateCreate objects
state: Optional[StateCreate] = Field(
default=None, description="The state of the task run to create"
Expand Down
8 changes: 5 additions & 3 deletions src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
cast,
overload,
)
from uuid import uuid4
from uuid import UUID, uuid4

from typing_extensions import Literal, ParamSpec

Expand Down Expand Up @@ -518,8 +518,9 @@ def with_options(

async def create_run(
self,
client: Optional[Union[PrefectClient, SyncPrefectClient]],
parameters: Dict[str, Any] = None,
client: Union[PrefectClient, SyncPrefectClient],
id: Optional[UUID] = None,
parameters: Optional[Dict[str, Any]] = None,
flow_run_context: Optional[FlowRunContext] = None,
parent_task_run_context: Optional[TaskRunContext] = None,
wait_for: Optional[Iterable[PrefectFuture]] = None,
Expand Down Expand Up @@ -591,6 +592,7 @@ async def create_run(
else None
),
dynamic_key=str(dynamic_key),
id=id,
state=Pending(),
task_inputs=task_inputs,
extra_tags=TagsContext.get().current_tags,
Expand Down
52 changes: 52 additions & 0 deletions tests/server/api/test_task_runs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from uuid import uuid4

import pendulum
Expand Down Expand Up @@ -145,6 +146,57 @@ async def test_raises_on_jitter_factor_validation(self, flow_run, client, sessio
== "`retry_jitter_factor` must be >= 0."
)

async def test_create_task_run_with_client_provided_id(self, flow_run, client):
    """POSTing a task run with a caller-chosen ID returns that same ID."""
    requested_id = str(uuid.uuid4())
    payload = dict(
        flow_run_id=str(flow_run.id),
        task_key="my-task-key",
        name="my-cool-task-run-name",
        dynamic_key="0",
        id=requested_id,
    )

    created = await client.post("/task_runs/", json=payload)

    # The create must succeed and echo back the ID we supplied.
    assert created.status_code == 201
    assert created.json()["id"] == requested_id

async def test_create_task_run_with_same_client_provided_id(
    self,
    flow_run,
    client,
):
    """Reusing a client-provided ID for a second task run yields a 409."""
    shared_id = str(uuid.uuid4())

    def payload_for(dynamic_key):
        # Only the dynamic key differs between the two requests; the ID is
        # deliberately the same in both.
        return {
            "flow_run_id": str(flow_run.id),
            "task_key": "my-task-key",
            "name": "my-cool-task-run-name",
            "dynamic_key": dynamic_key,
            "id": shared_id,
        }

    # First creation with this ID is accepted and echoes the ID back.
    first = await client.post("/task_runs/", json=payload_for("0"))
    assert first.status_code == 201
    assert first.json()["id"] == shared_id

    # A second task run claiming the same ID is expected to conflict.
    second = await client.post("/task_runs/", json=payload_for("1"))
    assert second.status_code == 409


class TestReadTaskRun:
async def test_read_task_run(self, flow_run, task_run, client):
Expand Down
34 changes: 32 additions & 2 deletions tests/test_new_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import time
from typing import List
from unittest.mock import MagicMock
from uuid import UUID
from uuid import UUID, uuid4

import pytest

from prefect import Task, flow, get_run_logger, task
from prefect.client.orchestration import SyncPrefectClient
from prefect.client.orchestration import PrefectClient, SyncPrefectClient
from prefect.client.schemas.objects import StateType
from prefect.context import TaskRunContext, get_run_context
from prefect.exceptions import CrashedRun, MissingResult
Expand Down Expand Up @@ -61,6 +61,36 @@ async def test_client_attr_returns_client_after_starting(self):
engine.client


class TestRunTask:
    """Checks that the engine entry points accept a caller-supplied task run ID."""

    def test_run_task_with_client_provided_uuid(
        self, sync_prefect_client: SyncPrefectClient
    ):
        """Sync path: the run can be read back under the UUID passed in."""
        expected_id = uuid4()

        @task
        def foo():
            return 42

        run_task_sync(foo, task_run_id=expected_id)

        # Look the run up by the ID we supplied and confirm it matches.
        fetched = sync_prefect_client.read_task_run(expected_id)
        assert fetched.id == expected_id

    async def test_run_task_async_with_client_provided_uuid(
        self, prefect_client: PrefectClient
    ):
        """Async path: the run can be read back under the UUID passed in."""
        expected_id = uuid4()

        @task
        async def foo():
            return 42

        await run_task_async(foo, task_run_id=expected_id)

        # Look the run up by the ID we supplied and confirm it matches.
        fetched = await prefect_client.read_task_run(expected_id)
        assert fetched.id == expected_id


class TestTaskRunsAsync:
async def test_basic(self):
@task
Expand Down