Skip to content

Commit

Permalink
Merge pull request #260 from motiwari/fixed_loss
Browse files Browse the repository at this point in the history
Fixed Loss Mismatch Issues in BanditPAM
  • Loading branch information
Adarsh321123 committed Jul 2, 2023
2 parents ff01890 + c245147 commit 53ef393
Show file tree
Hide file tree
Showing 15 changed files with 165 additions and 164 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/run_linux_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
sudo apt install -y build-essential checkinstall libncursesw5-dev libssl-dev libsqlite3-dev tk-dev libgdbm-dev libc6-dev libbz2-dev libffi-dev zlib1g-dev
sudo apt install -y clang-format cppcheck
- name: Install Python dependencies
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest
Expand Down Expand Up @@ -58,7 +58,7 @@ jobs:
python -m pip install --no-use-pep517 --no-build-isolation -vvv -e .
env:
# The default compiler on the Github Ubuntu runners is gcc
# Would need to make a respective include change for clang
# Would need to make a respective include change for clang
CPLUS_INCLUDE_PATH: /usr/local/include/carma

- name: Downloading data files for tests
Expand Down
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ BanditPAM.egg-info/
__pycache__
banditpam.cpython*
build/
build/
data/MNIST_*.csv
data/MNIST_CSV.csv
data/data.csv
data/scrna*
data/*.py
data/cifar*
docs/html
docs/latex
python_generate_plots/__pycache__
Expand Down
5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ project(BanditPAM VERSION 1.0 LANGUAGES CXX)

# TODO(@motiwari): Should "Release" be written in all caps (RELEASE)?
set(CMAKE_BUILD_TYPE Release)
# set(CMAKE_BUILD_TYPE Debug)
set(CMAKE_CXX_STANDARD 17)

set(CMAKE_CXX_FLAGS_RELEASE "-O3")

# Note: ThreadSanitizer and AddressSanitizer can give spurious errors around the #pragma omp critical.
# They don't really matter because the threads always resolve with the same end state
# ThreadSanitizer also introduces a very large overhead, AddressSanitizer is more reasonable
set(CMAKE_CXX_FLAGS_DEBUG "-Wall -Wextra -g -fno-omit-frame-pointer -fsanitize=address")
# set(CMAKE_CXX_FLAGS_DEBUG "-Wall -Wextra -g -fno-omit-frame-pointer -fsanitize=address")

# NOTE: Need to explicitly pass -O0 to enable printing of armadillo matrices
#set(CMAKE_CXX_FLAGS_DEBUG "-O0 -Wall -Wextra -g -fno-omit-frame-pointer")
# set(CMAKE_CXX_FLAGS_DEBUG "-O0 -Wall -Wextra -g -fno-omit-frame-pointer")

add_subdirectory(src)
43 changes: 25 additions & 18 deletions scripts/cache_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,30 @@ def test_cache_stats():
hits_1000,
misses_1000,
) = get_cache_statistics(kmed=kmed, X=X, loss="L2", cache_width=1000)
time_750, width_750, writes_750, hits_750, misses_750 = \
get_cache_statistics(
kmed=kmed, X=X, loss="L2", cache_width=750
)
time_500, width_500, writes_500, hits_500, misses_500 = \
get_cache_statistics(
kmed=kmed, X=X, loss="L2", cache_width=500
)
time_250, width_250, writes_250, hits_250, misses_250 = \
get_cache_statistics(
kmed=kmed, X=X, loss="L2", cache_width=250
)
time_0, width_0, writes_0, hits_0, misses_0 = \
get_cache_statistics(
kmed=kmed, X=X, loss="L2", cache_width=0
)
(
time_750,
width_750,
writes_750,
hits_750,
misses_750,
) = get_cache_statistics(kmed=kmed, X=X, loss="L2", cache_width=750)
(
time_500,
width_500,
writes_500,
hits_500,
misses_500,
) = get_cache_statistics(kmed=kmed, X=X, loss="L2", cache_width=500)
(
time_250,
width_250,
writes_250,
hits_250,
misses_250,
) = get_cache_statistics(kmed=kmed, X=X, loss="L2", cache_width=250)
time_0, width_0, writes_0, hits_0, misses_0 = get_cache_statistics(
kmed=kmed, X=X, loss="L2", cache_width=0
)

assert (
hits_1000 > hits_750 > hits_500 > hits_250 > hits_0
Expand All @@ -92,8 +100,7 @@ def test_cache_stats():
assert width_250 == 250, "Cache width should be 250 when set to 250"
assert width_500 == 500, "Cache width should be 500 when set to 500"
assert width_750 == 750, "Cache width should be 750 when set to 750"
assert width_1000 == 1000, "Cache width should be 1000 when set to " \
"1000"
assert width_1000 == 1000, "Cache width should be 1000 when set to " "1000"

def test_parallelization():
X = np.loadtxt(os.path.join("data", "MNIST_10k.csv"))
Expand Down
41 changes: 0 additions & 41 deletions scripts/compare_banditpam_versions.py

This file was deleted.

20 changes: 13 additions & 7 deletions scripts/comparison_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
def print_results(kmed, runtime):
from banditpam import KMedoids


def print_results(kmed: KMedoids, runtime: float):
complexity_with_caching = kmed.getDistanceComputations(True) - kmed.cache_hits
print("-----Results-----")
print("Algorithm:", kmed.algorithm)
print("Final Medoids:", kmed.medoids)
Expand All @@ -7,17 +11,19 @@ def print_results(kmed, runtime):
print("Build complexity:", f"{kmed.build_distance_computations:,}")
print("Swap complexity:", f"{kmed.swap_distance_computations:,}")
print("Number of Swaps", kmed.steps)
print(
"Average Swap Sample Complexity:",
f"{kmed.swap_distance_computations / kmed.steps:,}",
)
print("Cache Writes: {:,}".format(kmed.cache_writes))
print("Cache Hits: {:,}".format(kmed.cache_hits))
print("Cache Misses: {:,}".format(kmed.cache_misses))
print(
"Total complexity (without misc):",
f"{kmed.getDistanceComputations(False):,}"
f"{kmed.getDistanceComputations(False):,}",
)
print(
"Total complexity (with misc):",
f"{kmed.getDistanceComputations(True):,}",
)
print("Runtime per swap:", runtime / kmed.steps)
print(
"Total complexity (with caching):",
f"{complexity_with_caching:,}",
)
print("Total runtime:", runtime)
8 changes: 2 additions & 6 deletions scripts/comparison_with_fasterpam.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,7 @@ def benchmark(data, f, n=1):
v = np.array(v)
min, avg = v.min(), v.mean()
ste = v.std(ddof=1) / np.sqrt(len(v)) if len(v) > 1 else 0.0
print(
"{:16s} min={:-10.2f} mean={:-10.2f} ±{:-.2f}".format(
k, min, avg, ste
)
)
print("{:16s} min={:-10.2f} mean={:-10.2f} ±{:-.2f}".format(k, min, avg, ste))


def run_fasterpam(data, seed):
Expand Down Expand Up @@ -69,7 +65,7 @@ def run_old_bandit(data, seed):
n_medoids=5,
parallelize=True,
algorithm="BanditPAM_orig",
dist_mat=diss
dist_mat=diss,
)
print(km.algorithm)
km.seed = seed
Expand Down
8 changes: 0 additions & 8 deletions scripts/scaling_with_k.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,3 @@
print(time.time() - start, "seconds")
print("Number of SWAP steps:", kmed.steps)
print(kmed.medoids)

# for k in [5, 10, 30]:
# kmed = banditpam.KMedoids(n_medoids=k, algorithm="BanditPAM_orig")
# start = time.time()
# kmed.fit(X, "L2")
# print(time.time() - start, "seconds")
# print("Number of SWAP steps:", kmed.steps)
# print(kmed.medoids)
70 changes: 26 additions & 44 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def compiler_check():
"""
try:
return (
"clang"
if "clang" in distutils.sysconfig.get_config_vars()["CC"]
else "gcc"
"clang" if "clang" in distutils.sysconfig.get_config_vars()["CC"] else "gcc"
)
except KeyError:
# The 'CC' environment variable hasn't been set.
Expand All @@ -67,8 +65,7 @@ def compiler_check():
return "gcc"

raise Exception(
"No C++ compiler was found. Please ensure you have "
"MSVC, LLVM clang, or GCC."
"No C++ compiler was found. Please ensure you have " "MSVC, LLVM clang, or GCC."
)


Expand Down Expand Up @@ -268,8 +265,7 @@ def setup_colab(delete_source=False):
repo_location = os.path.join("/", "content", "BanditPAM")
# Note the space after the git URL to separate the source and target
os.system(
"git clone https://github.com/motiwari/BanditPAM.git "
+ repo_location
"git clone https://github.com/motiwari/BanditPAM.git " + repo_location
)
os.system(
repo_location
Expand Down Expand Up @@ -416,8 +412,10 @@ def build_extensions(self):

for ext in self.extensions:
ext.define_macros = [
("VERSION_INFO",
'"{}"'.format(self.distribution.get_version()))
(
"VERSION_INFO",
'"{}"'.format(self.distribution.get_version()),
)
]
ext.extra_compile_args = opts
ext.extra_compile_args += [] # []["-arch", "x86_64"]
Expand Down Expand Up @@ -455,9 +453,7 @@ def main():
# To include carma when the BanditPAM repo hasn't been initialized
os.path.join("/", "usr", "local", "include"),
os.path.join("/", "usr", "local", "include", "carma"),
os.path.join(
"/", "usr", "local", "include", "carma", "carma_bits"
),
os.path.join("/", "usr", "local", "include", "carma", "carma_bits"),
# When building from source on M1 Macs, may need these dirs
# Currently, we should never be building from source on an M1 Mac,
# Only cross-compiling from an Intel Mac
Expand All @@ -468,22 +464,21 @@ def main():
os.path.join("/", "opt", "homebrew", "lib"),
os.path.join("/", "opt", "homebrew", "opt"),
os.path.join("/", "opt", "homebrew", "opt", "armadillo"),
os.path.join("/", "opt", "homebrew", "opt", "armadillo", "include"),
os.path.join(
"/", "opt", "homebrew", "opt", "armadillo", "include"
),
os.path.join(
"/", "opt", "homebrew", "opt", "armadillo", "include",
"armadillo_bits"
"/",
"opt",
"homebrew",
"opt",
"armadillo",
"include",
"armadillo_bits",
),
# Needed for Mac Github Runners
# for macos-10.15
os.path.join(
"/", "usr", "local", "Cellar", "libomp", "15.0.2", "include"
),
os.path.join("/", "usr", "local", "Cellar", "libomp", "15.0.2", "include"),
# for macos-latest
os.path.join(
"/", "usr", "local", "Cellar", "libomp", "15.0.7", "include"
),
os.path.join("/", "usr", "local", "Cellar", "libomp", "15.0.7", "include"),
]
elif sys.platform == "win32": # WIN32
include_dirs = [
Expand Down Expand Up @@ -526,12 +521,8 @@ def main():
] # TODO(@motiwari): Modify this based on gcc or clang
library_dirs = [
os.path.join("/", "usr", "local", "lib"),
os.path.join(
"/", "usr", "local", "Cellar", "libomp", "15.0.2", "lib"
),
os.path.join(
"/", "usr", "local", "Cellar", "libomp", "15.0.7", "lib"
),
os.path.join("/", "usr", "local", "Cellar", "libomp", "15.0.2", "lib"),
os.path.join("/", "usr", "local", "Cellar", "libomp", "15.0.7", "lib"),
]
if sys.platform == "darwin" and platform.processor() == "arm": # M1
library_dirs.append(
Expand All @@ -550,21 +541,15 @@ def main():
os.path.join("src", "algorithms", "banditpam.cpp"),
os.path.join("src", "algorithms", "banditpam_orig.cpp"),
os.path.join("src", "algorithms", "fastpam1.cpp"),
os.path.join(
"src", "python_bindings", "kmedoids_pywrapper.cpp"
),
os.path.join("src", "python_bindings", "kmedoids_pywrapper.cpp"),
os.path.join("src", "python_bindings", "medoids_python.cpp"),
os.path.join(
"src", "python_bindings", "build_medoids_python.cpp"
),
os.path.join("src", "python_bindings", "build_medoids_python.cpp"),
os.path.join("src", "python_bindings", "fit_python.cpp"),
os.path.join("src", "python_bindings", "labels_python.cpp"),
os.path.join("src", "python_bindings", "steps_python.cpp"),
os.path.join("src", "python_bindings", "loss_python.cpp"),
os.path.join("src", "python_bindings", "cache_python.cpp"),
os.path.join(
"src", "python_bindings", "swap_times_python.cpp"
),
os.path.join("src", "python_bindings", "swap_times_python.cpp"),
],
include_dirs=include_dirs,
library_dirs=library_dirs,
Expand All @@ -586,9 +571,8 @@ def main():
[
os.path.join(
os.getcwd(),
r"headers\armadillo\examples\lib_win64"
+ r"\libopenblas.dll",
)
r"headers\armadillo\examples\lib_win64" + r"\libopenblas.dll",
)
],
)
)
Expand Down Expand Up @@ -619,9 +603,7 @@ def main():
os.path.join("headers", "algorithms", "banditpam.hpp"),
os.path.join("headers", "algorithms", "fastpam1.hpp"),
os.path.join("headers", "algorithms", "pam.hpp"),
os.path.join(
"headers", "python_bindings", "kmedoids_pywrapper.hpp"
),
os.path.join("headers", "python_bindings", "kmedoids_pywrapper.hpp"),
],
)

Expand Down

0 comments on commit 53ef393

Please sign in to comment.