forked from mlrun/mlrun
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
375 lines (328 loc) · 13.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
# Copyright 2018 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import concurrent.futures
import time
import traceback
import uuid
import fastapi
import fastapi.concurrency
import uvicorn
import uvicorn.protocols.utils
from fastapi.exception_handlers import http_exception_handler
import mlrun.api.schemas
import mlrun.api.utils.clients.chief
import mlrun.errors
import mlrun.utils
import mlrun.utils.version
from mlrun.api.api.api import api_router
from mlrun.api.db.session import close_session, create_session
from mlrun.api.initial_data import init_data
from mlrun.api.utils.periodic import (
cancel_all_periodic_functions,
cancel_periodic_function,
run_function_periodically,
)
from mlrun.api.utils.singletons.db import get_db, initialize_db
from mlrun.api.utils.singletons.logs_dir import initialize_logs_dir
from mlrun.api.utils.singletons.project_member import (
get_project_member,
initialize_project_member,
)
from mlrun.api.utils.singletons.scheduler import get_scheduler, initialize_scheduler
from mlrun.config import config
from mlrun.errors import err_to_str
from mlrun.k8s_utils import get_k8s_helper
from mlrun.runtimes import RuntimeKinds, get_runtime_handler
from mlrun.utils import logger
# All API routes are mounted under /api; versioned routes live under /api/v1
API_PREFIX = "/api"
BASE_VERSIONED_API_PREFIX = f"{API_PREFIX}/v1"
# The FastAPI application instance served by uvicorn (see main() at the bottom of the file)
app = fastapi.FastAPI(
    title="MLRun",
    description="Machine Learning automation and tracking",
    version=config.version,
    debug=config.httpdb.debug,
    # adding /api prefix
    openapi_url=f"{BASE_VERSIONED_API_PREFIX}/openapi.json",
    docs_url=f"{BASE_VERSIONED_API_PREFIX}/docs",
    redoc_url=f"{BASE_VERSIONED_API_PREFIX}/redoc",
    default_response_class=fastapi.responses.ORJSONResponse,
)
app.include_router(api_router, prefix=BASE_VERSIONED_API_PREFIX)
# This is for backward compatibility, that is why we still leave it here but not include it in the schema
# so new users won't use the old un-versioned api
# TODO: remove when 0.9.x versions are no longer relevant
app.include_router(api_router, prefix=API_PREFIX, include_in_schema=False)
@app.exception_handler(Exception)
async def generic_error_handler(request: fastapi.Request, exc: Exception):
    """Catch-all handler turning any unhandled exception into a 500 response.

    We have no specific knowledge of the exception or which status code fits,
    so 500 is used. The handler mainly exists to place the error message where
    clients expect it in the response body.

    :param request: the failing request
    :param exc: the unhandled exception
    :return: a 500 JSON response with the error message under "reason"
    """
    error_message = repr(exc)
    # Use the http_exception_handler imported at the top of the file, as
    # http_status_error_handler already does, instead of the fully-qualified
    # fastapi.exception_handlers path — consistency between the two handlers.
    # TODO: 0.6.6 is the last version expecting the error details to be under reason, when it's no longer a relevant
    #  version can be changed to detail=error_message
    return await http_exception_handler(
        request,
        fastapi.HTTPException(status_code=500, detail={"reason": error_message}),
    )
@app.exception_handler(mlrun.errors.MLRunHTTPStatusError)
async def http_status_error_handler(
    request: fastapi.Request, exc: mlrun.errors.MLRunHTTPStatusError
):
    """Render an MLRunHTTPStatusError as an HTTP response.

    Logs the failure (with traceback) and responds with the status code carried
    by the exception, placing the error message under "reason" in the body.

    :param request: the failing request
    :param exc: the MLRun HTTP status error raised by the handler
    :return: a JSON response with the exception's status code
    """
    message = repr(exc)
    code = exc.response.status_code
    logger.warning(
        "Request handling returned error status",
        error_message=message,
        status_code=code,
        traceback=traceback.format_exc(),
    )
    # TODO: 0.6.6 is the last version expecting the error details to be under
    #  reason, when it's no longer a relevant version this can be changed to
    #  detail=message
    http_exc = fastapi.HTTPException(status_code=code, detail={"reason": message})
    return await http_exception_handler(request, http_exc)
def get_client_address(scope):
    """Return the client address string for an ASGI scope.

    Starlette's test client stores scope["client"] as a list while uvicorn's
    helper expects a tuple, so normalize it before delegating.

    :param scope: the ASGI connection scope dict
    :return: the client address as formatted by uvicorn
    """
    client = scope.get("client")
    if isinstance(client, list):
        scope["client"] = tuple(client)
    return uvicorn.protocols.utils.get_client_addr(scope)
@app.middleware("http")
async def log_request_response(request: fastapi.Request, call_next):
    """HTTP middleware logging every request and its response.

    Assigns a per-request UUID, logs on arrival and on completion (with
    elapsed time in milliseconds), and logs failures before re-raising.
    Requests whose path contains one of the silent paths (health probes)
    are not logged to avoid noise.
    """
    request_id = str(uuid.uuid4())
    silent_logging_paths = [
        "healthz",
    ]
    path_with_query_string = uvicorn.protocols.utils.get_path_with_query_string(
        request.scope
    )

    def _should_log():
        # silent paths (e.g. healthz probes) are skipped in debug logging
        return not any(
            silent_path in path_with_query_string
            for silent_path in silent_logging_paths
        )

    start_time = time.perf_counter_ns()
    if _should_log():
        logger.debug(
            "Received request",
            method=request.method,
            client_address=get_client_address(request.scope),
            http_version=request.scope["http_version"],
            request_id=request_id,
            uri=path_with_query_string,
        )
    try:
        response = await call_next(request)
    except Exception as exc:
        # User middleware (like this one) runs after the exception handling
        # middleware; the only thing running after it is Starlette's
        # ServerErrorMiddleware, which catches any unhandled exception and
        # transforms it to a 500 response — so 500 can be assigned statically.
        logger.warning(
            "Request handling failed. Sending response",
            status_code=500,
            request_id=request_id,
            uri=path_with_query_string,
            method=request.method,
            exc=exc,
            traceback=traceback.format_exc(),
        )
        raise
    else:
        # perf_counter_ns is nanoseconds; convert to milliseconds
        elapsed_time_in_ms = (time.perf_counter_ns() - start_time) / 1000 / 1000
        if _should_log():
            logger.debug(
                "Sending response",
                status_code=response.status_code,
                request_id=request_id,
                elapsed_time=elapsed_time_in_ms,
                uri=path_with_query_string,
                method=request.method,
            )
        return response
@app.on_event("startup")
async def startup_event():
    """Initialize the API server process on startup.

    Dumps the configuration, bounds the event loop's default thread pool,
    prepares the logs directory and DB, and — depending on clusterization
    role and feature flags — starts the chief-sync loop and/or moves the
    API to online state.
    """
    logger.info(
        "configuration dump",
        dumped_config=config.dump_yaml(),
        version=mlrun.utils.version.Version().get(),
    )

    # Bound the default executor so offloaded work can't spawn an unbounded
    # number of threads
    executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=int(config.httpdb.max_workers)
    )
    asyncio.get_running_loop().set_default_executor(executor)

    initialize_logs_dir()
    initialize_db()

    worker_waits_for_chief = (
        config.httpdb.clusterization.worker.sync_with_chief.mode
        == mlrun.api.schemas.WaitForChiefToReachOnlineStateFeatureFlag.enabled
        and config.httpdb.clusterization.role
        == mlrun.api.schemas.ClusterizationRole.worker
    )
    if worker_waits_for_chief:
        _start_chief_clusterization_spec_sync_loop()

    if config.httpdb.state == mlrun.api.schemas.APIStates.online:
        await move_api_to_online()
@app.on_event("shutdown")
async def shutdown_event():
    """Gracefully tear down: project member, periodic tasks, scheduler."""
    project_member = get_project_member()
    if project_member:
        project_member.shutdown()
    cancel_all_periodic_functions()
    scheduler = get_scheduler()
    if scheduler:
        await scheduler.stop()
async def move_api_to_online():
    """Bring the API to fully-online state.

    Initializes the scheduler and project member, and (on the chief instance
    only, when inside a kubernetes cluster) starts the maintenance periodic
    functions.
    """
    logger.info("Moving api to online")
    await initialize_scheduler()
    # The scheduler must be initialized before the project member: in 1.1.0
    # the follower's project member initialization performs a full sync (see
    # code there for details) that might delete projects, which requires the
    # scheduler to be set.
    initialize_project_member()

    # Maintenance periodic functions should only run on the chief instance;
    # runs cleanup/monitoring is not needed outside a kubernetes cluster.
    is_chief = (
        config.httpdb.clusterization.role
        == mlrun.api.schemas.ClusterizationRole.chief
    )
    if is_chief and get_k8s_helper(silent=True).is_running_inside_kubernetes_cluster():
        _start_periodic_cleanup()
        _start_periodic_runs_monitoring()
def _start_periodic_cleanup():
    """Schedule periodic runtime-resources cleanup.

    A non-positive configured interval disables the periodic function.
    """
    interval = int(config.runtimes_cleanup_interval)
    if interval <= 0:
        return
    logger.info("Starting periodic runtimes cleanup", interval=interval)
    run_function_periodically(
        interval, _cleanup_runtimes.__name__, False, _cleanup_runtimes
    )
def _start_periodic_runs_monitoring():
    """Schedule periodic runs monitoring.

    A non-positive configured interval disables the periodic function.
    """
    interval = int(config.runs_monitoring_interval)
    if interval <= 0:
        return
    logger.info("Starting periodic runs monitoring", interval=interval)
    run_function_periodically(interval, _monitor_runs.__name__, False, _monitor_runs)
def _start_chief_clusterization_spec_sync_loop():
    """Schedule the worker's periodic sync of the chief clusterization spec.

    A non-positive configured interval disables the periodic function.
    """
    interval = int(config.httpdb.clusterization.worker.sync_with_chief.interval)
    if interval <= 0:
        return
    logger.info("Starting chief clusterization spec sync loop", interval=interval)
    run_function_periodically(
        interval,
        _synchronize_with_chief_clusterization_spec.__name__,
        False,
        _synchronize_with_chief_clusterization_spec,
    )
async def _synchronize_with_chief_clusterization_spec():
    """Periodic worker task: pull the clusterization spec from the chief.

    Cancels its own periodic registration once the worker reaches a terminal
    state; otherwise fetches the chief's spec and aligns the worker's state
    with it. Fetch failures are logged and retried on the next interval.
    """
    # Sanity: if the worker already reached a terminal state, stop future
    # invocations of this periodic function (the current invocation still
    # proceeds with one more sync attempt).
    if config.httpdb.state in mlrun.api.schemas.APIStates.terminal_states():
        cancel_periodic_function(_synchronize_with_chief_clusterization_spec.__name__)

    try:
        chief_client = mlrun.api.utils.clients.chief.Client()
        spec = chief_client.get_clusterization_spec(
            return_fastapi_response=False, raise_on_failure=True
        )
    except Exception as exc:
        # Best effort — the periodic loop retries on the next interval
        logger.debug(
            "Failed receiving clusterization spec",
            exc=err_to_str(exc),
            traceback=traceback.format_exc(),
        )
    else:
        await _align_worker_state_with_chief_state(spec)
async def _align_worker_state_with_chief_state(
    clusterization_spec: mlrun.api.schemas.ClusterizationSpec,
):
    """Update this worker's API state according to the chief's reported state.

    Non-terminal chief states are mirrored and retried on the next sync;
    terminal states are mirrored and the periodic sync is canceled (moving
    the API online first when the chief reached online).
    """
    chief_state = clusterization_spec.chief_api_state
    if not chief_state:
        logger.warning("Chief did not return any state")
        return

    if chief_state not in mlrun.api.schemas.APIStates.terminal_states():
        logger.debug(
            "Chief did not reach online state yet, will retry after sync interval",
            interval=config.httpdb.clusterization.worker.sync_with_chief.interval,
            chief_state=chief_state,
        )
        # we want the worker to be aligned with chief state
        config.httpdb.state = chief_state
        return

    if chief_state == mlrun.api.schemas.APIStates.online:
        logger.info("Chief reached online state! Switching worker state to online")
        await move_api_to_online()
        logger.info("Worker state reached online")
    else:
        logger.info(
            "Chief state is terminal, canceling worker periodic chief clusterization spec pulling",
            state=config.httpdb.state,
        )
    config.httpdb.state = chief_state
    # Reached a terminal state, so cancel the periodic function.
    # Assumption: we can't get out of a terminal api state, so no need to
    # continue pulling once one is reached.
    cancel_periodic_function(_synchronize_with_chief_clusterization_spec.__name__)
def _monitor_runs():
    """Monitor runs for every runtime kind that has a handler.

    A failure for one kind is logged and skipped so a broken runtime does not
    block monitoring of the others. The DB session is always closed.
    """
    db_session = create_session()
    try:
        for kind in RuntimeKinds.runtime_with_handlers():
            try:
                handler = get_runtime_handler(kind)
                handler.monitor_runs(get_db(), db_session)
            except Exception as exc:
                logger.warning(
                    "Failed monitoring runs. Ignoring",
                    exc=err_to_str(exc),
                    kind=kind,
                )
    finally:
        close_session(db_session)
def _cleanup_runtimes():
    """Delete stale runtime resources for every runtime kind with a handler.

    A failure for one kind is logged and skipped so the rest are still
    cleaned. The DB session is always closed.
    """
    db_session = create_session()
    try:
        for kind in RuntimeKinds.runtime_with_handlers():
            try:
                handler = get_runtime_handler(kind)
                handler.delete_resources(get_db(), db_session)
            except Exception as exc:
                logger.warning(
                    "Failed deleting resources. Ignoring",
                    exc=err_to_str(exc),
                    kind=kind,
                )
    finally:
        close_session(db_session)
def main():
    """Entry point: prepare clusterization state and launch the uvicorn server."""
    role = config.httpdb.clusterization.role
    if role == mlrun.api.schemas.ClusterizationRole.chief:
        # only the chief performs data initialization
        init_data()
    elif (
        config.httpdb.clusterization.worker.sync_with_chief.mode
        == mlrun.api.schemas.WaitForChiefToReachOnlineStateFeatureFlag.enabled
        and role == mlrun.api.schemas.ClusterizationRole.worker
    ):
        # mark the phase between instance startup and the first successful
        # pull of the chief's state
        config.httpdb.state = mlrun.api.schemas.APIStates.waiting_for_chief

    logger.info(
        "Starting API server",
        port=config.httpdb.port,
        debug=config.httpdb.debug,
    )
    uvicorn.run(
        "mlrun.api.main:app",
        host="0.0.0.0",
        port=config.httpdb.port,
        access_log=False,
        timeout_keep_alive=config.httpdb.http_connection_timeout_keep_alive,
    )


if __name__ == "__main__":
    main()