Skip to content

Commit

Permalink
feat!(altair v5, selenium, webdriver): update utils for latest [viz] …
Browse files Browse the repository at this point in the history
…extra dependencies (#94)

* refactor(geopandas io): move geopandas utils, add _gis_enabled feature flag to fix gpd import

* deps(viz): upgrade dependencies to support altair v5, update selenium

* refactor(selenium webdriver): update chrome webdriver setup for latest selenium api

* refactor(altair saving): remove alt_saver, update to altair v5 api

* WIP(altair saving): move export type selection error prior to filepath checking ops

* WIP(webdriver): add handling of chrome webdriver using contextmanager

* refactor(webdriver): use chrome webdriver_context contextmanager for png and svg altair saving

* docs(webdriver): add webdriver context manager docstring

* docs(webdriver): include webdriver contextmanager function in autodocs

* test: add tests for managing selenium webdriver
  • Loading branch information
sqr00t committed Feb 2, 2024
1 parent f71ae55 commit f19531c
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 54 deletions.
1 change: 1 addition & 0 deletions docs/source/saving.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ saving
===============
**saving** contains altair saving functions

.. autofunction:: nesta_ds_utils.viz.altair.saving.webdriver_context
.. autofunction:: nesta_ds_utils.viz.altair.saving.save
63 changes: 41 additions & 22 deletions nesta_ds_utils/viz/altair/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,54 @@
Module containing utils for styling and exporting figures using Altair.
"""

import altair_saver as alt_saver
from altair.vegalite import Chart
import altair as alt
from selenium import webdriver
from webdriver_manager.chrome import ChromeDriverManager
from selenium.webdriver.chrome.webdriver import WebDriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver import Chrome, ChromeOptions, ChromeService
import os
from typing import Union, List, Type
import warnings
from matplotlib import font_manager
from pathlib import Path
from nesta_ds_utils.loading_saving import file_ops
import yaml
from contextlib import contextmanager


def _google_chrome_driver_setup() -> WebDriver:
"""Set up the driver to save figures"""
chrome_options = Options()
chrome_options.add_argument("--headless")
driver = webdriver.Chrome(
ChromeDriverManager().install(), chrome_options=chrome_options
)
service = ChromeService(ChromeDriverManager().install())
chrome_options = ChromeOptions()
chrome_options.add_argument("--headless=new")
driver = Chrome(service=service, options=chrome_options)
return driver


@contextmanager
def webdriver_context(driver: WebDriver = None):
"""Context Manager for Selenium WebDrivers.
Optionally pass in user-instantiated Selenium Webdriver.
Defaults to setup and yield a ChromeWebDriver.
Typical usage:
with webdriver_context(webdriver or None) as driver:
# Do stuff with driver, driver.quit() is then called automatically
Args:
driver (WebDriver, optional): Webdriver to use. Defaults to 'webdriver.Chrome'.
Yields:
WebDriver: The optional user-instantiated Selenium WebDriver or a Selenium ChromeWebDriver.
"""
try:
driver = _google_chrome_driver_setup() if driver is None else driver
yield driver
finally:
driver.quit()


def _save_png(
fig: Chart, path: os.PathLike, name: str, scale_factor: int, driver: WebDriver
):
Expand All @@ -40,8 +62,7 @@ def _save_png(
scale_factor (int): Saving scale factor.
driver (WebDriver): webdriver to use for saving.
"""
alt_saver.save(
fig,
fig.save(
f"{path}/{name}.png",
method="selenium",
webdriver=driver,
Expand Down Expand Up @@ -73,8 +94,7 @@ def _save_svg(
scale_factor (int): Saving scale factor.
driver (WebDriver): webdriver to use for saving.
"""
alt_saver.save(
fig,
fig.save(
f"{path}/{name}.svg",
method="selenium",
scale_factor=scale_factor,
Expand Down Expand Up @@ -104,27 +124,26 @@ def save(
save_svg (bool, optional): Option to save figure as 'svg'. Default to False.
scale_factor (int, optional): Saving scale factor. Default to 5.
"""
path = file_ops._convert_str_to_pathlib_path(path)

if not any([save_png, save_html, save_svg]):
raise Exception(
"At least one format needs to be selected. Example: save(.., save_png=True)."
)

path = file_ops._convert_str_to_pathlib_path(path)
file_ops.make_path_if_not_exist(path)

if save_png or save_svg:
driver = _google_chrome_driver_setup() if driver is None else driver
with webdriver_context(driver):
# Export figures
if save_png:
_save_png(fig, path, name, scale_factor, driver)

file_ops.make_path_if_not_exist(path)
# Export figures
if save_png:
_save_png(fig, path, name, scale_factor, driver)
if save_svg:
_save_svg(fig, path, name, scale_factor, driver)

if save_html:
_save_html(fig, path, name, scale_factor)

if save_svg:
_save_svg(fig, path, name, scale_factor, driver)


def _find_averta() -> str:
"""Search for averta font, otherwise return 'Helvetica' and raise a warning.
Expand Down
52 changes: 20 additions & 32 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,39 @@ packages = ["nesta_ds_utils"]
[options]
python_requires = >=3.8
install_requires =
numpy==1.23.4
pandas==1.5.1
numpy>=1.23.4
pandas>=1.5.1
pyyaml<5.4.0
scipy==1.9.3
pyarrow==10.0.0
scipy>=1.9.3
pyarrow>=10.0.0
[options.extras_require]
s3 =
boto3==1.24.93
boto3>=1.24.93
gis =
geopandas==0.13.2
io_extras =
openpyxl==3.0.9
geopandas>=0.13.2
io_extras =
openpyxl>=3.0.9
viz =
altair==4.2.0
altair-saver==0.5.0
matplotlib==3.6.2
selenium==4.2.0
webdriver_manager==4.0.0
altair>=4.2.0
vl-convert-python>=1.2.0
matplotlib>=3.6.2
selenium>=4.2.0
webdriver_manager>=4.0.0
networks =
networkx==2.8.8
nlp =
nltk==3.7
test =
pytest==7.1.3
moto[s3]==4.0.7
nltk>=3.7
all =
%(s3)s
%(gis)s
%(io_extras)s
%(viz)s
%(networks)s
%(nlp)s
test =
pytest==7.1.3
moto[s3]==4.0.7
%(all)s
dev =
Sphinx==5.2.3
sphinxcontrib-applehelp==1.0.2
Expand All @@ -51,24 +53,10 @@ dev =
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
pytest==7.1.3
moto[s3]==4.0.7
pre-commit==2.20.0
pre-commit-hooks==4.3.0
black==22.10.0
%(s3)s
%(gis)s
%(io_extras)s
%(viz)s
%(networks)s
%(nlp)s
all =
%(s3)s
%(gis)s
%(io_extras)s
%(viz)s
%(networks)s
%(nlp)s
%(test)s
[options.package_data]
nesta_ds_utils.viz.themes =
Expand Down
16 changes: 16 additions & 0 deletions tests/viz/altair/test_saving.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import shutil
from pathlib import Path
from nesta_ds_utils.viz.altair import saving
from selenium.webdriver.chromium.webdriver import ChromiumDriver
import pandas as pd
import altair as alt
import pytest
Expand Down Expand Up @@ -41,3 +42,18 @@ def test_save_altair_exception():
saving.save(
fig, "test_fig", path, save_png=False, save_html=False, save_svg=False
)


def test_webdriver():
"""Test that Chrome WebDriver is created by default is a ChromiumDriver, and context manager stops the webdriver."""
driver = saving._google_chrome_driver_setup()
assert isinstance(driver, ChromiumDriver)

with saving.webdriver_context(driver) as some_driver:
# No actions needed here,
# just testing that the context manager calls .quit() on driver to terminate.
pass

# If subprocess not terminated, .poll() returns None
# https://docs.python.org/3/library/subprocess.html#subprocess.Popen.returncode
assert driver.service.process.poll() is not None

0 comments on commit f19531c

Please sign in to comment.