util_funcs.py
import csv
import random

import numpy as np
import torch
from tqdm import tqdm
class GeneratorData(object):
    """Container for SMILES training data used by the generator."""

    def __init__(self, training_data_path, tokens=None, start_token='<',
                 end_token='>', max_len=120, use_cuda=None, **kwargs):
        super(GeneratorData, self).__init__()
        if 'cols_to_read' not in kwargs:
            kwargs['cols_to_read'] = []
        data = read_object_property_file(training_data_path, **kwargs)
        self.start_token = start_token
        self.end_token = end_token
        # Keep only strings within max_len, wrapped in start/end tokens.
        self.file = []
        for i in range(len(data)):
            if len(data[i]) <= max_len:
                self.file.append(self.start_token + data[i] + self.end_token)
        self.file_len = len(self.file)
        self.all_characters, self.char2idx, \
            self.n_characters = tokenize(self.file, tokens)
        self.use_cuda = use_cuda
        if self.use_cuda is None:
            self.use_cuda = torch.cuda.is_available()
    def load_dictionary(self, tokens, char2idx):
        self.all_characters = tokens
        self.char2idx = char2idx
        self.n_characters = len(tokens)

    def random_chunk(self):
        index = random.randint(0, self.file_len - 1)
        return self.file[index]
    def char_tensor(self, string):
        # Encode each character as its index in the token alphabet.
        tensor = torch.zeros(len(string)).long()
        for c in range(len(string)):
            tensor[c] = self.all_characters.index(string[c])
        # Return the tensor directly; re-wrapping it with torch.tensor()
        # makes a needless copy and raises a warning in recent PyTorch.
        return tensor.cuda() if self.use_cuda else tensor
    def random_training_set(self, smiles_augmentation):
        chunk = self.random_chunk()
        if smiles_augmentation is not None:
            # Re-wrap with the configured delimiters instead of the
            # hard-coded '<' and '>'.
            chunk = (self.start_token +
                     smiles_augmentation.randomize_smiles(chunk[1:-1]) +
                     self.end_token)
        # Input and target are the same sequence shifted by one character.
        inp = self.char_tensor(chunk[:-1])
        target = self.char_tensor(chunk[1:])
        return inp, target
    def read_sdf_file(self, path, fields_to_read):
        raise NotImplementedError

    def update_data(self, path):
        # read_smi_file is expected to be supplied by the surrounding
        # project; it is not defined in this module.
        self.file, success = read_smi_file(path, unique=True)
        self.file_len = len(self.file)
        assert success
def read_object_property_file(path, delimiter=',', cols_to_read=[0, 1],
                              keep_header=False):
    with open(path, 'r') as f:
        reader = csv.reader(f, delimiter=delimiter)
        data_full = np.array(list(reader))
    # Skip the header row unless the caller asks to keep it.
    start_position = 0 if keep_header else 1
    assert len(data_full) > start_position
    data = [data_full[start_position:, col] for col in cols_to_read]
    if len(cols_to_read) == 1:
        data = data[0]
    return data
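# A minimal sketch of the expected input, assuming a comma-delimited file
# with a header row (the column names here are hypothetical):
#
#   smiles,logP
#   CCO,-0.31
#   c1ccccc1,1.69
#
# read_object_property_file(path, cols_to_read=[0, 1]) would then return
# two arrays of strings, one per column, with the header row skipped.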
def tokenize(smiles, tokens=None):
    """
    Returns the SMILES alphabet as a string of unique tokens, a
    token-to-index dictionary and the number of unique tokens.

    Parameters
    ----------
    smiles: list
        list of SMILES strings to tokenize.
    tokens: list, str (default None)
        tokens to use; inferred from the data when None.

    Returns
    -------
    tokens: str
        sorted string of unique characters (the SMILES alphabet).
    token2idx: dict
        dictionary mapping each token to its index.
    num_tokens: int
        number of unique tokens.
    """
    if tokens is None:
        tokens = list(set(''.join(smiles)))
        tokens = list(np.sort(tokens))
        tokens = ''.join(tokens)
    token2idx = dict((token, i) for i, token in enumerate(tokens))
    num_tokens = len(tokens)
    return tokens, token2idx, num_tokens
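# A quick illustrative example (toy input, not from the training data):
#
#   alphabet, token2idx, n = tokenize(['<CCO>', '<c1ccccc1>'])
#   # alphabet  -> '1<>COc'   (unique characters in ASCII sort order)
#   # token2idx -> {'1': 0, '<': 1, '>': 2, 'C': 3, 'O': 4, 'c': 5}
#   # n         -> 6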
def estimate_and_update(generator, n_to_generate, **kwargs):
    generated = []
    pbar = tqdm(range(n_to_generate))
    for i in pbar:
        pbar.set_description("Generating molecules...")
        # gen_data is assumed to be a module-level GeneratorData instance;
        # slicing [1:-1] strips the start/end tokens from each sample.
        generated.append(generator.evaluate(gen_data, predict_len=120)[1:-1])
    # np.unique sorts its output, so [1:] drops the lexicographically
    # smallest entry (typically an empty generation).
    unique_smiles = list(np.unique(generated))[1:]
    return unique_smiles
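# A minimal, hedged usage sketch: 'train.smi' is a placeholder path to a
# single-column file of SMILES strings; adjust cols_to_read, delimiter and
# keep_header to match the real data layout.
if __name__ == '__main__':
    gen_data = GeneratorData('train.smi', cols_to_read=[0],
                             keep_header=False)
    print('Loaded %d strings over %d tokens: %s'
          % (gen_data.file_len, gen_data.n_characters,
             gen_data.all_characters))
    inp, target = gen_data.random_training_set(smiles_augmentation=None)
    print(inp.shape, target.shape)  # target is inp shifted by one position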