Skip to content

Commit

Permalink
Allow client-provided ID when creating a task run (#13276)
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle committed May 13, 2024
1 parent bf0daef commit aabceec
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 16 deletions.
19 changes: 13 additions & 6 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2169,6 +2169,7 @@ async def create_task_run(
task: "TaskObject[P, R]",
flow_run_id: Optional[UUID],
dynamic_key: str,
id: Optional[UUID] = None,
name: Optional[str] = None,
extra_tags: Optional[Iterable[str]] = None,
state: Optional[prefect.states.State[R]] = None,
Expand All @@ -2192,6 +2193,8 @@ async def create_task_run(
task: The Task to run
flow_run_id: The flow run id with which to associate the task run
dynamic_key: A key unique to this particular run of a Task within the flow
id: An optional ID for the task run. If not provided, one will be generated
server-side.
name: An optional name for the task run
extra_tags: an optional list of extra tags to apply to the task run in
addition to `task.tags`
Expand All @@ -2208,6 +2211,7 @@ async def create_task_run(
state = prefect.states.Pending()

task_run_data = TaskRunCreate(
id=id,
name=name,
flow_run_id=flow_run_id,
task_key=task.task_key,
Expand All @@ -2222,10 +2226,9 @@ async def create_task_run(
state=state.to_state_create(),
task_inputs=task_inputs or {},
)
content = task_run_data.json(exclude={"id"} if id is None else None)

response = await self._client.post(
"/task_runs/", json=task_run_data.dict(json_compatible=True)
)
response = await self._client.post("/task_runs/", content=content)
return TaskRun.parse_obj(response.json())

async def read_task_run(self, task_run_id: UUID) -> TaskRun:
Expand Down Expand Up @@ -3810,6 +3813,7 @@ def create_task_run(
task: "TaskObject[P, R]",
flow_run_id: Optional[UUID],
dynamic_key: str,
id: Optional[UUID] = None,
name: Optional[str] = None,
extra_tags: Optional[Iterable[str]] = None,
state: Optional[prefect.states.State[R]] = None,
Expand All @@ -3833,6 +3837,8 @@ def create_task_run(
task: The Task to run
flow_run_id: The flow run id with which to associate the task run
dynamic_key: A key unique to this particular run of a Task within the flow
id: An optional ID for the task run. If not provided, one will be generated
server-side.
name: An optional name for the task run
extra_tags: an optional list of extra tags to apply to the task run in
addition to `task.tags`
Expand All @@ -3849,6 +3855,7 @@ def create_task_run(
state = prefect.states.Pending()

task_run_data = TaskRunCreate(
id=id,
name=name,
flow_run_id=flow_run_id,
task_key=task.task_key,
Expand All @@ -3864,9 +3871,9 @@ def create_task_run(
task_inputs=task_inputs or {},
)

response = self._client.post(
"/task_runs/", json=task_run_data.dict(json_compatible=True)
)
content = task_run_data.json(exclude={"id"} if id is None else None)

response = self._client.post("/task_runs/", content=content)
return TaskRun.parse_obj(response.json())

def read_task_run(self, task_run_id: UUID) -> TaskRun:
Expand Down
1 change: 1 addition & 0 deletions src/prefect/client/schemas/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ class FlowRunUpdate(ActionBaseModel):
class TaskRunCreate(ActionBaseModel):
"""Data used by the Prefect REST API to create a task run"""

id: Optional[UUID] = Field(None, description="The ID to assign to the task run")
# TaskRunCreate states must be provided as StateCreate objects
state: Optional[StateCreate] = Field(
default=None, description="The state of the task run to create"
Expand Down
20 changes: 16 additions & 4 deletions src/prefect/new_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Union,
cast,
)
from uuid import UUID

import pendulum
from typing_extensions import ParamSpec
Expand Down Expand Up @@ -275,7 +276,9 @@ def enter_run_context(self, client: Optional[SyncPrefectClient] = None):
self.logger = current_logger

@contextmanager
def start(self) -> Generator["TaskRunEngine", Any, Any]:
def start(
self, task_run_id: Optional[UUID] = None
) -> Generator["TaskRunEngine", Any, Any]:
"""
Enters a client context and creates a task run if needed.
"""
Expand All @@ -286,6 +289,7 @@ def start(self) -> Generator["TaskRunEngine", Any, Any]:
if not self.task_run:
self.task_run = run_sync(
self.task.create_run(
id=task_run_id,
client=client,
parameters=self.parameters,
flow_run_context=FlowRunContext.get(),
Expand Down Expand Up @@ -324,15 +328,20 @@ def is_pending(self) -> bool:

def run_task_sync(
task: Task[P, R],
task_run_id: Optional[UUID] = None,
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
return_type: Literal["state", "result"] = "result",
) -> Union[R, State, None]:
engine = TaskRunEngine[P, R](task=task, parameters=parameters, task_run=task_run)
engine = TaskRunEngine[P, R](
task=task,
parameters=parameters,
task_run=task_run,
)

# This is a context manager that keeps track of the run of the task run.
with engine.start() as run:
with engine.start(task_run_id=task_run_id) as run:
run.begin_run()

while run.is_running():
Expand Down Expand Up @@ -362,6 +371,7 @@ def run_task_sync(

async def run_task_async(
task: Task[P, Coroutine[Any, Any, R]],
task_run_id: Optional[UUID] = None,
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
Expand All @@ -375,7 +385,7 @@ async def run_task_async(
engine = TaskRunEngine[P, R](task=task, parameters=parameters, task_run=task_run)

# This is a context manager that keeps track of the run of the task run.
with engine.start() as run:
with engine.start(task_run_id=task_run_id) as run:
run.begin_run()

while run.is_running():
Expand Down Expand Up @@ -405,13 +415,15 @@ async def run_task_async(

def run_task(
task: Task[P, R],
task_run_id: Optional[UUID] = None,
task_run: Optional[TaskRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
return_type: Literal["state", "result"] = "result",
) -> Union[R, State, None]:
kwargs = dict(
task=task,
task_run_id=task_run_id,
task_run=task_run,
parameters=parameters,
wait_for=wait_for,
Expand Down
5 changes: 4 additions & 1 deletion src/prefect/server/api/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ async def create_task_run(
If no state is provided, the task run will be created in a PENDING state.
"""
# hydrate the input model into a full task run / state model
task_run = schemas.core.TaskRun(**task_run.dict())
task_run_dict = task_run.dict()
if not task_run_dict.get("id"):
task_run_dict.pop("id", None)
task_run = schemas.core.TaskRun(**task_run_dict)

if not task_run.state:
task_run.state = schemas.states.Pending()
Expand Down
4 changes: 4 additions & 0 deletions src/prefect/server/schemas/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,10 @@ def default_scheduled_start_time(cls, values):
class TaskRunCreate(ActionBaseModel):
"""Data used by the Prefect REST API to create a task run"""

id: Optional[UUID] = Field(
default=None,
description="The ID to assign to the task run. If not provided, a random UUID will be generated.",
)
# TaskRunCreate states must be provided as StateCreate objects
state: Optional[StateCreate] = Field(
default=None, description="The state of the task run to create"
Expand Down
8 changes: 5 additions & 3 deletions src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
cast,
overload,
)
from uuid import uuid4
from uuid import UUID, uuid4

from typing_extensions import Literal, ParamSpec

Expand Down Expand Up @@ -518,8 +518,9 @@ def with_options(

async def create_run(
self,
client: Optional[Union[PrefectClient, SyncPrefectClient]],
parameters: Dict[str, Any] = None,
client: Union[PrefectClient, SyncPrefectClient],
id: Optional[UUID] = None,
parameters: Optional[Dict[str, Any]] = None,
flow_run_context: Optional[FlowRunContext] = None,
parent_task_run_context: Optional[TaskRunContext] = None,
wait_for: Optional[Iterable[PrefectFuture]] = None,
Expand Down Expand Up @@ -591,6 +592,7 @@ async def create_run(
else None
),
dynamic_key=str(dynamic_key),
id=id,
state=Pending(),
task_inputs=task_inputs,
extra_tags=TagsContext.get().current_tags,
Expand Down
52 changes: 52 additions & 0 deletions tests/server/api/test_task_runs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from uuid import uuid4

import pendulum
Expand Down Expand Up @@ -145,6 +146,57 @@ async def test_raises_on_jitter_factor_validation(self, flow_run, client, sessio
== "`retry_jitter_factor` must be >= 0."
)

async def test_create_task_run_with_client_provided_id(self, flow_run, client):
client_provided_id = uuid.uuid4()
task_run_data = {
"flow_run_id": str(flow_run.id),
"task_key": "my-task-key",
"name": "my-cool-task-run-name",
"dynamic_key": "0",
"id": str(client_provided_id),
}
response = await client.post(
"/task_runs/",
json=task_run_data,
)
assert response.status_code == 201
assert response.json()["id"] == str(client_provided_id)

async def test_create_task_run_with_same_client_provided_id(
self,
flow_run,
client,
):
client_provided_id = uuid.uuid4()
task_run_data = {
"flow_run_id": str(flow_run.id),
"task_key": "my-task-key",
"name": "my-cool-task-run-name",
"dynamic_key": "0",
"id": str(client_provided_id),
}
response = await client.post(
"/task_runs/",
json=task_run_data,
)
assert response.status_code == 201
assert response.json()["id"] == str(client_provided_id)

task_run_data = {
"flow_run_id": str(flow_run.id),
"task_key": "my-task-key",
"name": "my-cool-task-run-name",
"dynamic_key": "1",
"id": str(client_provided_id),
}

response = await client.post(
"/task_runs/",
json=task_run_data,
)

assert response.status_code == 409


class TestReadTaskRun:
async def test_read_task_run(self, flow_run, task_run, client):
Expand Down
34 changes: 32 additions & 2 deletions tests/test_new_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import time
from typing import List
from unittest.mock import MagicMock
from uuid import UUID
from uuid import UUID, uuid4

import pytest

from prefect import Task, flow, get_run_logger, task
from prefect.client.orchestration import SyncPrefectClient
from prefect.client.orchestration import PrefectClient, SyncPrefectClient
from prefect.client.schemas.objects import StateType
from prefect.context import TaskRunContext, get_run_context
from prefect.exceptions import CrashedRun, MissingResult
Expand Down Expand Up @@ -61,6 +61,36 @@ async def test_client_attr_returns_client_after_starting(self):
engine.client


class TestRunTask:
def test_run_task_with_client_provided_uuid(
self, sync_prefect_client: SyncPrefectClient
):
@task
def foo():
return 42

task_run_id = uuid4()

run_task_sync(foo, task_run_id=task_run_id)

task_run = sync_prefect_client.read_task_run(task_run_id)
assert task_run.id == task_run_id

async def test_run_task_async_with_client_provided_uuid(
self, prefect_client: PrefectClient
):
@task
async def foo():
return 42

task_run_id = uuid4()

await run_task_async(foo, task_run_id=task_run_id)

task_run = await prefect_client.read_task_run(task_run_id)
assert task_run.id == task_run_id


class TestTaskRunsAsync:
async def test_basic(self):
@task
Expand Down

0 comments on commit aabceec

Please sign in to comment.