Skip to content

Commit

Permalink
Merge pull request #575 from jmmshn/mongoclient_kwarg
Browse files Browse the repository at this point in the history
mongoclient_kwargs
  • Loading branch information
munrojm committed Mar 5, 2022
2 parents 11e7168 + 515c74f commit dcf3273
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
6 changes: 5 additions & 1 deletion src/maggma/stores/compound_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
password: str = "",
main: Optional[str] = None,
merge_at_root: bool = False,
mongoclient_kwargs: Optional[Dict] = None,
**kwargs,
):
"""
Expand All @@ -50,7 +51,9 @@ def __init__(
self._coll = None # type: Any
self.main = main or collection_names[0]
self.merge_at_root = merge_at_root
self.mongoclient_kwargs = mongoclient_kwargs or {}
self.kwargs = kwargs

super(JointStore, self).__init__(**kwargs)

@property
Expand All @@ -75,9 +78,10 @@ def connect(self, force_reset: bool = False):
port=self.port,
username=self.username,
password=self.password,
**self.mongoclient_kwargs,
)
if self.username != ""
else MongoClient(self.host, self.port)
else MongoClient(self.host, self.port, **self.mongoclient_kwargs)
)
db = conn[self.database]
self._coll = db[self.main]
Expand Down
7 changes: 5 additions & 2 deletions src/maggma/stores/gridfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
ensure_metadata: bool = False,
searchable_fields: List[str] = None,
auth_source: Optional[str] = None,
mongoclient_kwargs: Optional[Dict] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(
if auth_source is None:
auth_source = self.database
self.auth_source = auth_source
self.mongoclient_kwargs = mongoclient_kwargs or {}

if "key" not in kwargs:
kwargs["key"] = "_id"
Expand Down Expand Up @@ -131,9 +133,10 @@ def connect(self, force_reset: bool = False):
username=self.username,
password=self.password,
authSource=self.auth_source,
**self.mongoclient_kwargs,
)
if self.username != ""
else MongoClient(self.host, self.port)
else MongoClient(self.host, self.port, **self.mongoclient_kwargs)
)
if not self._coll or force_reset:
db = conn[self.database]
Expand Down Expand Up @@ -506,7 +509,7 @@ def connect(self, force_reset: bool = False):
Connect to the source data
"""
if not self._coll or force_reset: # pragma: no cover
conn = MongoClient(self.uri)
conn = MongoClient(self.uri, **self.mongoclient_kwargs)
db = conn[self.database]
self._coll = gridfs.GridFS(db, self.collection_name)
self._files_collection = db["{}.files".format(self.collection_name)]
Expand Down
7 changes: 5 additions & 2 deletions src/maggma/stores/mongolike.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(
ssh_tunnel: Optional[SSHTunnel] = None,
safe_update: bool = False,
auth_source: Optional[str] = None,
mongoclient_kwargs: Optional[Dict] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -149,6 +150,7 @@ def __init__(
if auth_source is None:
auth_source = self.database
self.auth_source = auth_source
self.mongoclient_kwargs = mongoclient_kwargs or {}

super().__init__(**kwargs)

Expand Down Expand Up @@ -178,9 +180,10 @@ def connect(self, force_reset: bool = False):
username=self.username,
password=self.password,
authSource=self.auth_source,
**self.mongoclient_kwargs,
)
if self.username != ""
else MongoClient(host, port)
else MongoClient(host, port, **self.mongoclient_kwargs)
)
db = conn[self.database]
self._coll = db[self.collection_name]
Expand Down Expand Up @@ -564,7 +567,7 @@ def connect(self, force_reset: bool = False):
Connect to the source data
"""
if self._coll is None or force_reset: # pragma: no cover
conn = MongoClient(self.uri)
conn = MongoClient(self.uri, **self.mongodb_client_kwargs)
db = conn[self.database]
self._coll = db[self.collection_name]

Expand Down

0 comments on commit dcf3273

Please sign in to comment.