Skip to content

Commit

Permalink
feat: Add trycatch boundary and allow turning off local uploads. #982
Browse files Browse the repository at this point in the history
  • Loading branch information
mturoci committed Jan 10, 2023
1 parent 7b7799f commit 2d54594
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 25 deletions.
1 change: 1 addition & 0 deletions py/Makefile
Expand Up @@ -37,6 +37,7 @@ docs: ## Build API docs
test: ## Run the test suite three times: with defaults, with H2O_WAVE_BASE_URL set, and with H2O_WAVE_WAVED_DIR set
./venv/bin/python -m tests
echo "Testing using BASE_URL" && H2O_WAVE_BASE_URL="/foo/" ./venv/bin/python -m tests
echo "Testing using LOCAL UPLOAD" && H2O_WAVE_WAVED_DIR=".." ./venv/bin/python -m tests

purge: ## Purge previous build
rm -rf build dist h2o_wave.egg-info
Expand Down
70 changes: 55 additions & 15 deletions py/h2o_wave/core.py
Expand Up @@ -662,14 +662,42 @@ def upload(self, files: List[str]) -> List[str]:
Returns:
A list of remote URLs for the uploaded files, in order.
"""
upload_files = []
for f in files:
if not os.path.isfile(f):
raise ValueError(f'{f} is not a file.')

waved_dir = os.environ.get('H2O_WAVE_WAVED_DIR', None)
data_dir = os.environ.get('H2O_WAVE_DATA_DIR', 'data')
skip_local_upload = os.environ.get('H2O_WAVE_SKIP_LOCAL_UPLOAD', None)

# If we know the path to waved and the app is running on the same machine as waved,
# we can simply copy the files instead of making an HTTP request.
if not skip_local_upload and waved_dir and _is_loopback_address(_config.hub_address):
try:
is_windows = 'Windows' in platform.system()
cp_command = 'xcopy' if is_windows else 'cp'
uploaded_files = []
for f in files:
uuid = str(uuid4())
dst = os.path.join(waved_dir, data_dir, 'f', uuid, os.path.basename(f))
os.makedirs(os.path.dirname(dst), exist_ok=True)
p = subprocess.Popen([cp_command, f, dst, '/K/O/X' if is_windows else ''], stderr=subprocess.PIPE)
_, err = p.communicate()
if err:
raise ValueError(err.decode())
uploaded_files.append(f'/_f/{uuid}/{os.path.basename(f)}')
return uploaded_files
except:
pass

uploaded_files = []
file_handles: List[BufferedReader] = []
for f in files:
file_handle = open(f, 'rb')
upload_files.append(('files', (os.path.basename(f), file_handle)))
uploaded_files.append(('files', (os.path.basename(f), file_handle)))
file_handles.append(file_handle)

res = self._http.post(f'{_config.hub_address}_f/', files=upload_files)
res = self._http.post(f'{_config.hub_address}_f/', files=uploaded_files)

for h in file_handles:
h.close()
Expand Down Expand Up @@ -872,22 +900,33 @@ async def upload(self, files: List[str]) -> List[str]:
Returns:
A list of remote URLs for the uploaded files, in order.
"""
for f in files:
if not os.path.isfile(f):
raise ValueError(f'{f} is not a file.')

waved_dir = os.environ.get('H2O_WAVE_WAVED_DIR', None)
data_dir = os.environ.get('H2O_WAVE_DATA_DIR', 'data')
skip_local_upload = os.environ.get('H2O_WAVE_SKIP_LOCAL_UPLOAD', None)

# If we know the path to waved and the app is running on the same machine as waved,
# we can simply copy the files instead of making an HTTP request.
if waved_dir and _is_loopback_address(_config.hub_address):
is_windows = 'Windows' in platform.system()
cp_command = 'xcopy' if is_windows else 'cp'
uuid = str(uuid4())
for f in files:
if not os.path.isfile(f):
raise ValueError(f'{f} is not a file.')
dst = os.path.join(waved_dir, data_dir, 'f', uuid, os.path.basename(f))
os.makedirs(os.path.dirname(dst), exist_ok=True)
p = subprocess.Popen([cp_command, f, dst, '/K/O/X' if is_windows else ''])
p.communicate()
return [f'/_f/{uuid}/{os.path.basename(f)}' for f in files]
if not skip_local_upload and waved_dir and _is_loopback_address(_config.hub_address):
try:
is_windows = 'Windows' in platform.system()
cp_command = 'xcopy' if is_windows else 'cp'
uploaded_files = []
for f in files:
uuid = str(uuid4())
dst = os.path.join(waved_dir, data_dir, 'f', uuid, os.path.basename(f))
os.makedirs(os.path.dirname(dst), exist_ok=True)
p = subprocess.Popen([cp_command, f, dst, '/K/O/X' if is_windows else ''], stderr=subprocess.PIPE)
_, err = p.communicate()
if err:
raise ValueError(err.decode())
uploaded_files.append(f'/_f/{uuid}/{os.path.basename(f)}')
return uploaded_files
except:
pass

upload_files = []
file_handles: List[BufferedReader] = []
Expand Down Expand Up @@ -918,6 +957,7 @@ async def download(self, url: str, path: str) -> str:
path = os.path.abspath(path)
# If path is a directory, get basename from url
filepath = os.path.join(path, os.path.basename(url)) if os.path.isdir(path) else path

async with self._http.stream('GET', f'{_config.hub_host_address}{url}') as res:
if res.status_code != 200:
await res.aread()
Expand Down
6 changes: 0 additions & 6 deletions py/tests/test_python_server.py
Expand Up @@ -330,7 +330,6 @@ def test_cyc_buf_write(self):
i=2,
))))


def test_proxy(self):
# waved -proxy must be set
url = 'https://wave.h2o.ai'
Expand All @@ -342,7 +341,6 @@ def test_proxy(self):
assert result.code == 400
assert len(result.headers) > 0


def test_file_server(self):
f1 = 'temp_file1.txt'
with open(f1, 'w') as f:
Expand All @@ -356,14 +354,12 @@ def test_file_server(self):
os.remove(f2)
assert s1 == s2


def test_public_dir(self):
    """Download ``assets/brand/h2o.svg`` under the base URL and check it starts with ``<svg``."""
    downloaded_path = site.download(f'{base_url}assets/brand/h2o.svg', 'h2o.svg')
    content = read_file(downloaded_path)
    # The local copy is removed before the assertion, so a failed check leaves no file behind.
    os.remove(downloaded_path)
    # str.index raises ValueError when the marker is absent, which also fails the test.
    marker_position = content.index('<svg')
    assert marker_position == 0


def test_cache(self):
d1 = dict(foo='bar', qux=42)
site.cache.set('test', 'data', d1)
Expand All @@ -375,7 +371,6 @@ def test_cache(self):
assert d2['foo'] == d1['foo']
assert d2['qux'] == d1['qux']


def test_multipart_server(self):
file_handle = open('../assets/brand/wave.svg', 'rb')
p = site.uplink('test_stream', 'image/svg+xml', file_handle)
Expand All @@ -390,7 +385,6 @@ def test_upload_dir(self):
os.remove(download_path)
assert len(txt) > 0


def test_deleting_files(self):
upload_path, = site.upload([os.path.join('tests', 'test_folder', 'test.txt')])
res = httpx.get(f'http://localhost:10101{upload_path}')
Expand Down
6 changes: 2 additions & 4 deletions py/tests/test_python_server_async.py
Expand Up @@ -19,6 +19,7 @@

from .utils import read_file


# TODO: Add cleanup (site.unload) to tests that upload files.
class TestPythonServerAsync(unittest.IsolatedAsyncioTestCase):
def __init__(self, methodName: str = ...) -> None:
Expand All @@ -38,15 +39,13 @@ async def test_file_server(self):
os.remove(f2)
assert s1 == s2


async def test_public_dir(self):
    """Download ``assets/brand/h2o.svg`` under the configured base URL and check it starts with ``<svg``."""
    # The base URL comes from H2O_WAVE_BASE_URL and falls back to '/'.
    base_url = os.getenv('H2O_WAVE_BASE_URL', '/')
    downloaded_path = await self.site.download(f'{base_url}assets/brand/h2o.svg', 'h2o.svg')
    content = read_file(downloaded_path)
    # The local copy is removed before the assertion, so a failed check leaves no file behind.
    os.remove(downloaded_path)
    # str.index raises ValueError when the marker is absent, which also fails the test.
    marker_position = content.index('<svg')
    assert marker_position == 0


async def test_cache(self):
d1 = dict(foo='bar', qux=42)
await self.site.cache.set('test', 'data', d1)
Expand All @@ -58,7 +57,6 @@ async def test_cache(self):
assert d2['foo'] == d1['foo']
assert d2['qux'] == d1['qux']


async def test_multipart_server(self):
file_handle = open('../assets/brand/wave.svg', 'rb')
p = await self.site.uplink('test_stream', 'image/svg+xml', file_handle)
Expand All @@ -79,4 +77,4 @@ async def test_deleting_files(self):
assert res.status_code == 200
await self.site.unload(upload_path)
res = httpx.get(f'http://localhost:10101{upload_path}')
assert res.status_code == 404
assert res.status_code == 404

0 comments on commit 2d54594

Please sign in to comment.