Skip to content

Commit

Permalink
Merge branch 'write_input' of https://github.com/FarnazH/iodata into …
Browse files Browse the repository at this point in the history
…FarnazH-write_input
  • Loading branch information
tovrstra committed Mar 12, 2021
2 parents c8e8f10 + 0f628f8 commit 0ba9f7f
Show file tree
Hide file tree
Showing 15 changed files with 587 additions and 14 deletions.
68 changes: 64 additions & 4 deletions iodata/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .utils import LineIterator


__all__ = ['load_one', 'load_many', 'dump_one', 'dump_many']
__all__ = ['load_one', 'load_many', 'dump_one', 'dump_many', 'write_input']


def _find_format_modules():
Expand Down Expand Up @@ -78,6 +78,40 @@ def _select_format_module(filename: str, attrname: str, fmt: str = None) -> Modu
attrname, filename))


def _find_input_modules():
"""Return all input modules found with importlib."""
result = {}
for module_info in iter_modules(import_module('iodata.inputs').__path__):
if not module_info.ispkg:
format_module = import_module('iodata.inputs.' + module_info.name)
result[module_info.name] = format_module
return result


INPUT_MODULES = _find_input_modules()


def _select_input_module(fmt: str) -> ModuleType:
"""Find an input module with the requested attribute name.
Parameters
----------
fmt
The name of the input module to use.
Returns
-------
format_module
The module implementing the required input format.
"""
if fmt in INPUT_MODULES:
if not hasattr(INPUT_MODULES[fmt], 'write_input'):
raise ValueError(f'{fmt} input module does not have write_input!')
return INPUT_MODULES[fmt]
raise ValueError(f"Could not find input format {fmt}!")


def load_one(filename: str, fmt: str = None, **kwargs) -> IOData:
"""Load data from a file.
Expand All @@ -97,7 +131,7 @@ def load_one(filename: str, fmt: str = None, **kwargs) -> IOData:
Returns
-------
out
out: IOData
The instance of IOData with data loaded from the input files.
"""
Expand Down Expand Up @@ -128,7 +162,7 @@ def load_many(filename: str, fmt: str = None, **kwargs) -> Iterator[IOData]:
Yields
------
out
out: IOData
An instance of IOData with data for one frame loaded for the file.
"""
Expand Down Expand Up @@ -177,7 +211,7 @@ def dump_many(iodatas: Iterator[IOData], filename: str, fmt: str = None, **kwarg
----------
iodatas
An iterator over IOData instances.
filename : str
filename
The file to write the data to.
fmt
The name of the file format module to use.
Expand All @@ -188,3 +222,29 @@ def dump_many(iodatas: Iterator[IOData], filename: str, fmt: str = None, **kwarg
format_module = _select_format_module(filename, 'dump_many', fmt)
with open(filename, 'w') as f:
format_module.dump_many(f, iodatas, **kwargs)


def write_input(iodata: IOData, filename: str, fmt: str, template: str = None, **kwargs):
"""Write input file using an instance of IOData for the specified software format.
Parameters
----------
iodata
An IOData instance containing the information needed to write input.
filename
The input file name.
fmt
The name of the software for which input file is generated.
template
The template input file.
**kwargs
Keyword arguments are passed on to the input-specific write_input function.
"""
input_module = _select_input_module(fmt)
# load template as a string
if template is not None:
with open(template, 'r') as t:
template = t.read()
with open(filename, 'w') as f:
input_module.write_input(f, iodata, template=template, **kwargs)
102 changes: 92 additions & 10 deletions iodata/docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from typing import List, Dict


__all__ = ['document_load_one', 'document_load_many', 'document_dump_one', 'document_dump_many']
__all__ = ['document_load_one', 'document_load_many', 'document_dump_one', 'document_dump_many',
'document_write_input']


def _document_load(template: str, fmt: str, guaranteed: List[str], ifpresent: List[str] = None,
Expand Down Expand Up @@ -64,7 +65,7 @@ def decorator(func):
{kwdocs}
Returns
-------
data
result: dict
A dictionary with IOData attributes. The following attributes are guaranteed to be
loaded: {guaranteed}.{ifpresent}
Expand Down Expand Up @@ -105,19 +106,19 @@ def document_load_one(fmt: str, guaranteed: List[str], ifpresent: List[str] = No


LOAD_MANY_DOC_TEMPLATE = """\
Load multiple frame from a {fmt} file.
Load multiple frames from a {fmt} file.
Parameters
----------
lit
The line iterator to read the data from.
{kwdocs}
Yields
------
data
result: dict
A dictionary with IOData attributes. The following attribtues are guaranteed to be
loaded: {guaranteed}.{ifpresent}
{kwdocs}
Notes
-----
Expand Down Expand Up @@ -151,7 +152,7 @@ def document_load_many(fmt: str, guaranteed: List[str], ifpresent: List[str] = N
A decorator function.
"""
return _document_load(LOAD_ONE_DOC_TEMPLATE, fmt, guaranteed, ifpresent, kwdocs, notes)
return _document_load(LOAD_MANY_DOC_TEMPLATE, fmt, guaranteed, ifpresent, kwdocs, notes)


def _document_dump(template: str, fmt: str, required: List[str], optional: List[str] = None,
Expand All @@ -172,7 +173,7 @@ def decorator(func):
optional=optional_sentence,
kwdocs="\n".join("{}\n {}".format(name, docu.replace("\n", " "))
for name, docu in sorted(kwdocs.items())),
notes=notes,
notes=(notes or ""),
)
func.fmt = fmt
func.required = required
Expand Down Expand Up @@ -237,10 +238,12 @@ def document_dump_one(fmt: str, required: List[str], optional: List[str] = None,
----------
f
A writeable file object.
data
An IOData instance which must have the following attributes initialized:
datas
An iterator over IOData instances which must have the following attributes initialized:
{required}.{optional}
{kwdocs}
Notes
-----
{notes}
Expand Down Expand Up @@ -273,3 +276,82 @@ def document_dump_many(fmt: str, required: List[str], optional: List[str] = None
"""
return _document_dump(DUMP_MANY_DOC_TEMPLATE, fmt, required, optional, kwdocs, notes)


def _document_write(template: str, fmt: str, required: List[str], optional: List[str] = None,
kwdocs: Dict[str, str] = {}, notes: str = None):
optional = optional or []

def decorator(func):
if optional:
optional_sentence = (
" If the following attributes are present, they are also written "
"into the file: {}. If these attributes are not assigned, "
"internal default values are used."
).format(', '.join("``{}``".format(word) for word in optional))
else:
optional_sentence = ""
func.__doc__ = template.format(
fmt=fmt,
required=', '.join("``{}``".format(word) for word in required),
optional=optional_sentence,
kwdocs="\n".join("{}\n {}".format(name, docu.replace("\n", " "))
for name, docu in sorted(kwdocs.items())),
notes=(notes or ""),
)
func.fmt = fmt
func.required = required
func.optional = optional
func.kwdocs = kwdocs
func.notes = notes
return func
return decorator


WRITE_INPUT_DOC_TEMPLATE = """\
Write a {fmt} input file.
Parameters
----------
f
A writeable file object.
data
An IOData instance which must have the following attributes initialized:
{required}.{optional}
template
A template input file.
{kwdocs}
Notes
-----
{notes}
"""


def document_write_input(fmt: str, required: List[str], optional: List[str] = None,
kwdocs: Dict[str, str] = {}, notes: str = None):
"""Decorate a write_input function to generate a docstring.
Parameters
----------
fmt
The name of the file format.
required
A list of mandatory IOData attributes needed to write the file.
optional
A list of optional IOData attributes which can be include when writing the file.
kwdocs
A dictionary with documentation for keyword arguments. Each key is a
keyword argument name and the corresponding value is text explaining the
argument.
notes
Additional information to be added to the docstring.
Returns
-------
decorator
A decorator function.
"""
return _document_write(WRITE_INPUT_DOC_TEMPLATE, fmt, required, optional, kwdocs, notes)
18 changes: 18 additions & 0 deletions iodata/inputs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# IODATA is an input and output module for quantum chemistry.
# Copyright (C) 2011-2019 The IODATA Development Team
#
# This file is part of IODATA.
#
# IODATA is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# IODATA is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
# --
43 changes: 43 additions & 0 deletions iodata/inputs/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# IODATA is an input and output module for quantum chemistry.
# Copyright (C) 2011-2019 The IODATA Development Team
#
# This file is part of IODATA.
#
# IODATA is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# IODATA is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
# --
"""Utilities for writing input files."""

import attr

import numpy as np

from ..iodata import IOData
from ..utils import angstrom

__all__ = ['populate_fields']


def populate_fields(data: IOData) -> dict:
"""Generate a dictionary with fields to replace in the template."""
# load IOData dict using attr.asdict because the IOData class uses __slots__
fields = attr.asdict(data, recurse=False)
# store atomic coordinates in angstrom
fields["atcoords"] = data.atcoords / angstrom
# set general defaults
fields["title"] = data.title if data.title is not None else 'Input Generated by IOData'
fields["run_type"] = data.run_type if data.run_type is not None else 'energy'
# convert spin polarization to multiplicity
fields["spinmult"] = int(abs(np.round(data.spinpol))) + 1 if data.spinpol is not None else 1
fields["charge"] = int(data.charge) if data.charge is not None else 0
return fields
68 changes: 68 additions & 0 deletions iodata/inputs/gaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# IODATA is an input and output module for quantum chemistry.
# Copyright (C) 2011-2019 The IODATA Development Team
#
# This file is part of IODATA.
#
# IODATA is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# IODATA is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
# --
"""Gaussian Input Module."""


from typing import TextIO
from string import Template

from .common import populate_fields

from ..docstrings import document_write_input
from ..iodata import IOData
from ..periodic import num2sym


__all__ = []


default_template = """\
#n ${lot}/${obasis_name} ${run_type}
${title}
${charge} ${spinmult}
${geometry}
"""


@document_write_input("GAUSSIAN", ['atnums', 'atcoords'],
['title', 'run_type', 'lot', 'obasis_name', 'spinmult', 'charge'])
def write_input(f: TextIO, data: IOData, template: str = None):
"""Do not edit this docstring. It will be overwritten."""
# initialize a dictionary with fields to replace in the template
fields = populate_fields(data)
# set format-specific defaults
fields["lot"] = data.lot if data.lot is not None else 'hf'
fields["obasis_name"] = data.obasis_name if data.obasis_name is not None else 'sto-3g'
# convert run type to Gaussian keywords
run_types = {"energy": "sp", "energy_force": "force", "opt": "opt", "scan": "scan",
"freq": "freq"}
fields["run_type"] = run_types[fields["run_type"].lower()]
# generate geometry (in angstrom)
geometry = []
for num, coord in zip(fields["atnums"], fields["atcoords"]):
geometry.append(f"{num2sym[num]:3} {coord[0]:10.6f} {coord[1]:10.6f} {coord[2]:10.6f}")
fields["geometry"] = "\n".join(geometry)
# get template
if template is None:
template = default_template
# populate files & write input
print(Template(template).substitute(fields), file=f)

0 comments on commit 0ba9f7f

Please sign in to comment.