/
test_python_server.py
45 lines (32 loc) · 1.2 KB
/
test_python_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import multiprocessing as mp
from lightning_app.components import Image, Number, PythonServer
from lightning_app.utilities.network import _configure_session, find_free_network_port
class SimpleServer(PythonServer):
    """Minimal ``PythonServer`` subclass whose model echoes its input back."""

    def __init__(self, port):
        """Create the server on ``port``; the model is not built until ``setup``."""
        super().__init__(port=port)
        self._model = None

    def setup(self):
        """Install an identity function as the model."""

        def _identity(x):
            return x

        self._model = _identity

    def predict(self, data):
        """Run the model on ``data.payload`` and wrap the result under ``"prediction"``."""
        prediction = self._model(data.payload)
        return {"prediction": prediction}
def target_fn(port):
    """Process entry point: build a ``SimpleServer`` on ``port`` and run it."""
    SimpleServer(port=port).run()
def test_python_server_component():
    """Serve ``SimpleServer`` from a child process and check a ``/predict`` round trip.

    The server is started via ``target_fn`` on a free local port, a JSON body
    ``{"payload": "test"}`` is POSTed to ``/predict``, and the response is
    expected to echo the payload back under the ``"prediction"`` key.
    """
    port = find_free_network_port()
    process = mp.Process(target=target_fn, args=(port,))
    process.start()
    try:
        session = _configure_session()
        res = session.post(f"http://127.0.0.1:{port}/predict", json={"payload": "test"})
    finally:
        # Always stop the server, even when the request raises; otherwise a
        # failing test leaves a child process running and holding the port.
        process.terminate()
        # Reap the child so it does not linger; bounded so a stuck child
        # cannot hang the test run.
        process.join(timeout=10)
    assert res.json()["prediction"] == "test"
def test_image_sample_data():
    """Check that ``Image`` sample data is a dict with a non-trivial ``"image"`` entry."""
    data = Image()._get_sample_data()
    assert isinstance(data, dict)
    assert "image" in data
    # The length of the value is what must exceed 100. The comparison belongs
    # outside ``len(...)``; comparing the value itself against an int first is
    # a misplaced parenthesis.
    # NOTE(review): the value is presumably an encoded image string - confirm.
    assert len(data["image"]) > 100
def test_number_sample_data():
    """Check that ``Number`` sample data is a dict whose ``"prediction"`` is 463."""
    sample = Number()._get_sample_data()
    assert isinstance(sample, dict)
    assert "prediction" in sample
    assert sample["prediction"] == 463