Knowledge Base Embedding By Cooperative Knowledge Distillation

raphaelsty/mkb

mkb is a library dedicated to knowledge graph embeddings. The purpose of this library is to provide modular tools using PyTorch.


👾 Installation

You should be able to install and use this library with any Python version above 3.6.

pip install git+https://github.com/raphaelsty/mkb

⚡️ Quickstart:

Load or initialize your dataset as a list of triplets:

train = [
    ('🦆', 'is a', 'bird'),
    ('🦅', 'is a', 'bird'),

    ('🦆', 'lives in', '🌳'),
    ('🦉', 'lives in', '🌳'),
    ('🦅', 'lives in', '🏔'),

    ('🦉', 'hability', 'fly'),
    ('🦅', 'hability', 'fly'),

    ('🐌', 'is a', 'mollusc'),
    ('🐜', 'is a', 'insect'),
    ('🐝', 'is a', 'insect'),

    ('🐌', 'lives in', '🌳'),
    ('🐝', 'lives in', '🌳'),

    ('🐝', 'hability', 'fly'),

    ('🐻', 'is a', 'mammal'),
    ('🐶', 'is a', 'mammal'),
    ('🐨', 'is a', 'mammal'),

    ('🐻', 'lives in', '🏔'),
    ('🐶', 'lives in', '🏠'),
    ('🐱', 'lives in', '🏠'),
    ('🐨', 'lives in', '🌳'),

    ('🐬', 'lives in', '🌊'),
    ('🐳', 'lives in', '🌊'),

    ('🐋', 'is a', 'marine mammal'),
    ('🐳', 'is a', 'marine mammal'),
]

valid = [
    ('🦆', 'hability', 'fly'),
    ('🐱', 'is a', 'mammal'),
    ('🐜', 'lives in', '🌳'),
    ('🐬', 'is a', 'marine mammal'),
    ('🐋', 'lives in', '🌊'),
    ('🦉', 'is a', 'bird'),
]

Train your model to build coherent embeddings for each entity and relation of your dataset using a pipeline:

from mkb import datasets
from mkb import models
from mkb import losses
from mkb import sampling
from mkb import evaluation
from mkb import compose

import torch

_ = torch.manual_seed(42)

# Set device = 'cuda' if you own a gpu.
device = 'cpu'

dataset = datasets.Dataset(
    train      = train,
    valid      = valid,
    batch_size = 24,
)

model = models.RotatE(
    entities   = dataset.entities,
    relations  = dataset.relations,
    gamma      = 3,
    hidden_dim = 200,
)

model = model.to(device)

optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr = 0.003,
)

negative_sampling = sampling.NegativeSampling(
    size          = 24,
    train_triples = dataset.train,
    entities      = dataset.entities,
    relations     = dataset.relations,
    seed          = 42,
)

validation = evaluation.Evaluation(
    true_triples = dataset.true_triples,
    entities     = dataset.entities,
    relations    = dataset.relations,
    batch_size   = 8,
    device       = device,
)

pipeline = compose.Pipeline(
    epochs                = 100,
    eval_every            = 50,
    early_stopping_rounds = 3,
    device                = device,
)

pipeline = pipeline.learn(
    model      = model,
    dataset    = dataset,
    evaluation = validation,
    sampling   = negative_sampling,
    optimizer  = optimizer,
    loss       = losses.Adversarial(alpha=1)
)

Plot embeddings:

from sklearn import manifold
from sklearn import cluster
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

emojis_tokens = {
    '🦆': 'duck',
    '🦅': 'eagle',
    '🦉': 'owl',
    '🐌': 'snail',
    '🐜': 'ant',
    '🐝': 'bee',
    '🐻': 'bear',
    '🐶': 'dog',
    '🐨': 'koala',
    '🐱': 'cat',
    '🐬': 'dolphin',
    '🐳': 'whale',
    '🐋': 'humpback whale',
}

embeddings = pd.DataFrame(model.embeddings['entities']).T.reset_index()

embeddings = embeddings[embeddings['index'].isin(emojis_tokens)].set_index('index')

tsne = manifold.TSNE(n_components = 2, random_state = 42, n_iter=1500, perplexity=3, early_exaggeration=100)

kmeans = cluster.KMeans(n_clusters = 5, random_state=42)

X = tsne.fit_transform(embeddings)
X = pd.DataFrame(X, columns = ['dim_1', 'dim_2'])
X['cluster'] = kmeans.fit_predict(X)

%config InlineBackend.figure_format = 'retina'

fgrid = sns.lmplot(
    data = X,
    x = 'dim_1',
    y = 'dim_2',
    hue = 'cluster',
    fit_reg = False,
    legend = False,
    legend_out = False,
    height = 7,
    aspect = 1.6,
    scatter_kws={"s": 500}
)

ax = fgrid.axes[0,0]
ax.set_ylabel('')
ax.set_xlabel('')
ax.set(xticklabels=[])
ax.set(yticklabels=[])

for i, label in enumerate(embeddings.index):
    ax.text(
        X['dim_1'][i] + 1,
        X['dim_2'][i],
        emojis_tokens[label],
        horizontalalignment = 'left',
        size = 'medium',
        color = 'black',
        weight = 'semibold',
    )

plt.show()

🗂 Datasets

Datasets available:

  • datasets.CountriesS1
  • datasets.CountriesS2
  • datasets.CountriesS3
  • datasets.Fb13
  • datasets.Fb15k
  • datasets.Fb15k237
  • datasets.InferWiki16k
  • datasets.InferWiki64k
  • datasets.Kinship
  • datasets.Nations
  • datasets.Nell995
  • datasets.Umls
  • datasets.Wn11
  • datasets.Wn18
  • datasets.Wn18rr
  • datasets.Yago310

Load existing dataset:

from mkb import datasets

dataset = datasets.Wn18rr(batch_size=256)

dataset
Wn18rr dataset
    Batch size          256
    Entities            40923
    Relations           11
    Shuffle             True
    Train triples       86834
    Validation triples  3033
    Test triples        3134

Or create your own dataset:

from mkb import datasets

train = [
    ('🦆', 'is a', 'bird'),
    ('🦅', 'is a', 'bird'),
    ('🦉', 'hability', 'fly'),
    ('🦅', 'hability', 'fly')
]

valid = [
    ('🦉', 'is a', 'bird')
]

test = [
    ('🦆', 'hability', 'fly')
]

dataset = datasets.Dataset(
    train = train,
    valid = valid,
    test = test,
    batch_size = 3,
    seed = 42,
)

dataset
Dataset dataset
        Batch size  3
          Entities  5
         Relations  2
           Shuffle  True
     Train triples  4
Validation triples  1
      Test triples  1
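
The entity and relation counts reported above can be reproduced by hand. The sketch below is only an illustration: it assumes that labels are indexed in order of first appearance across the three splits, the helper `index_triples` is hypothetical (it is not part of mkb), and the emoji are replaced by plain names.

```python
def index_triples(*splits):
    """Map each entity and relation label to an integer id (first-seen order)."""
    entities, relations = {}, {}
    for split in splits:
        for head, relation, tail in split:
            for label in (head, tail):
                entities.setdefault(label, len(entities))
            relations.setdefault(relation, len(relations))
    return entities, relations


train = [
    ("duck", "is a", "bird"),
    ("eagle", "is a", "bird"),
    ("owl", "hability", "fly"),
    ("eagle", "hability", "fly"),
]
valid = [("owl", "is a", "bird")]
test = [("duck", "hability", "fly")]

entities, relations = index_triples(train, valid, test)
print(len(entities), len(relations))
```

Three birds plus the two targets `bird` and `fly` give 5 entities, and `is a` plus `hability` give 2 relations, matching the summary printed by the dataset.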

🤖 Models

Knowledge graph models build latent representations of nodes (entities) and relationships in the graph. These models implement an optimization process to represent the entities and relations in a consistent space.

Models available:

  • models.TransE
  • models.DistMult
  • models.RotatE
  • models.pRotatE
  • models.ComplEx
  • models.SentenceTransformer
  • models.Transformer

Initialize a model:

from mkb import models

model = models.RotatE(
   entities   = dataset.entities,
   relations  = dataset.relations,
   gamma      = 6,
   hidden_dim = 500
)

model
RotatE model
    Entities embeddings  dim  1000
    Relations embeddings dim  500
    Gamma                     6.0
    Number of entities        40923
    Number of relations       11
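
In the summary above, the entity dimension is twice the relation dimension. One plausible reading, assumed here rather than confirmed by the source, is that each entity is stored as a complex vector (real and imaginary halves) while each relation is stored as one rotation angle per complex coordinate, as in the RotatE paper. The sketch below scores a triplet as gamma minus the distance between the rotated head and the tail; `rotate_score` is a hypothetical helper, not mkb's API.

```python
import cmath
import math


def rotate_score(head, phases, tail, gamma):
    """RotatE-style score: gamma - sum_i |h_i * exp(i * theta_i) - t_i|."""
    distance = 0.0
    for h, theta, t in zip(head, phases, tail):
        rotated = h * cmath.exp(1j * theta)  # rotate h_i by the angle theta_i
        distance += abs(rotated - t)
    return gamma - distance


head = [1 + 0j, 0 + 1j]
phases = [math.pi / 2, math.pi]   # rotations of 90 and 180 degrees
good_tail = [0 + 1j, 0 - 1j]      # exactly the rotated head
bad_tail = [1 + 0j, 0 + 1j]       # the unrotated head

print(rotate_score(head, phases, good_tail, gamma=6.0))
print(rotate_score(head, phases, bad_tail, gamma=6.0))
```

A triplet whose tail matches the rotated head scores close to gamma, and the score drops as the tail moves away from it.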

Set the learning rate of the model:

import torch

learning_rate = 0.00005

optimizer = torch.optim.Adam(
   filter(lambda p: p.requires_grad, model.parameters()),
   lr = learning_rate,
)

🎭 Negative sampling

Knowledge graph embedding models learn to distinguish existing triplets from generated triplets. The sampling module generates triplets that do not exist in the dataset.

from mkb import sampling

negative_sampling = sampling.NegativeSampling(
    size          = 256,
    train_triples = dataset.train,
    entities      = dataset.entities,
    relations     = dataset.relations,
    seed          = 42,
)
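
To make the idea concrete, here is a minimal sketch of negative sampling in plain Python. It corrupts the head or the tail of a triplet with a random entity and rejects candidates that already exist in the dataset. The function `corrupt`, its arguments and the mode names `head-batch` and `tail-batch` are assumptions made for illustration; they do not describe the internals of sampling.NegativeSampling.

```python
import random


def corrupt(triple, entities, known, size, mode, rng):
    """Return `size` triplets built by replacing the head or tail of `triple`.

    Candidates that already exist in `known` are rejected, so every returned
    triplet is a negative with respect to the dataset.
    """
    head, relation, tail = triple
    negatives = []
    while len(negatives) < size:
        entity = rng.choice(entities)
        if mode == "head-batch":
            candidate = (entity, relation, tail)
        else:
            candidate = (head, relation, entity)
        if candidate not in known:
            negatives.append(candidate)
    return negatives


train = [
    ("duck", "is a", "bird"),
    ("eagle", "is a", "bird"),
    ("bee", "is a", "insect"),
]
entities = sorted({e for h, _, t in train for e in (h, t)})
rng = random.Random(42)

negatives = corrupt(train[0], entities, set(train), size=4, mode="tail-batch", rng=rng)
print(negatives)
```

Filtering against the known triplets matters on small graphs, where a random replacement has a real chance of producing a true fact.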

🤖 Train your model

You can train your model using a pipeline:

from mkb import compose
from mkb import losses
from mkb import evaluation

validation = evaluation.Evaluation(
    true_triples = dataset.true_triples,
    entities   = dataset.entities,
    relations  = dataset.relations,
    batch_size = 8,
    device     = device,
)

pipeline = compose.Pipeline(
    epochs                = 100,
    eval_every            = 50,
    early_stopping_rounds = 3,
    device                = device,
)

pipeline = pipeline.learn(
    model      = model,
    dataset    = dataset,
    evaluation = validation,
    sampling   = negative_sampling,
    optimizer  = optimizer,
    loss       = losses.Adversarial(alpha=1)
)

You can also train your model with a lower level of abstraction:

from mkb import losses
from mkb import evaluation

validation = evaluation.Evaluation(
    true_triples = dataset.true_triples,
    entities   = dataset.entities,
    relations  = dataset.relations,
    batch_size = 8,
    device     = device,
)

loss = losses.Adversarial(alpha=0.5)

for epoch in range(2000):

    for data in dataset:

        sample = data['sample'].to(device)
        weight = data['weight'].to(device)
        mode = data['mode']

        negative_sample = negative_sampling.generate(sample=sample, mode=mode)

        negative_sample = negative_sample.to(device)

        positive_score = model(sample)

        negative_score = model(
            sample=sample,
            negative_sample=negative_sample,
            mode=mode
        )

        error = loss(positive_score, negative_score, weight)

        error.backward()

        _ = optimizer.step()

        optimizer.zero_grad()

    validation_scores = validation.eval(dataset=dataset.valid, model=model)

    print(validation_scores)
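
The loss in the loop above receives a positive score, a set of negative scores and a weight. As a rough sketch, assuming that losses.Adversarial follows the self-adversarial negative sampling loss of the RotatE paper (the formula is not spelled out here), the computation for a single positive triplet looks as follows in plain Python. `alpha` acts as the softmax temperature; `self_adversarial_loss` is a hypothetical name, and sample weights and batch averaging are left out.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def self_adversarial_loss(positive_score, negative_scores, alpha):
    """Loss for one positive triplet and its negatives.

    Negatives are weighted by softmax(alpha * score): the higher a negative
    scores, the more it contributes to the loss.
    """
    peak = max(alpha * s for s in negative_scores)
    exps = [math.exp(alpha * s - peak) for s in negative_scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    positive_loss = -log_sigmoid(positive_score)
    negative_loss = -sum(w * log_sigmoid(-s) for w, s in zip(weights, negative_scores))
    return (positive_loss + negative_loss) / 2


easy = self_adversarial_loss(4.0, [-3.0, -5.0], alpha=0.5)
hard = self_adversarial_loss(4.0, [3.0, -5.0], alpha=0.5)
print(easy, hard)
```

With `alpha = 0` the weights are uniform and the loss reduces to a plain average over negatives; larger values focus training on the negatives the model currently finds most plausible.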

📊 Evaluation

You can evaluate the performance of your models with the evaluation module.

from mkb import evaluation

validation = evaluation.Evaluation(
    true_triples = dataset.true_triples,
    entities   = dataset.entities,
    relations  = dataset.relations,
    batch_size = 8,
    device     = device,
)

🎯 Link prediction task:

The link prediction task aims at finding the most likely head or tail for an incomplete triplet. For example, the model should retrieve the entity United States for the triplet ('Barack Obama', 'president_of', ?).

Validate the model on the validation set:

validation.eval(model = model, dataset = dataset.valid)
{'MRR': 0.5833, 'MR': 400.0, 'HITS@1': 20.25, 'HITS@3': 30.0, 'HITS@10': 40.0}

Validate the model on the test set:

validation.eval(model = model, dataset = dataset.test)
{'MRR': 0.5833, 'MR': 600.0, 'HITS@1': 21.35, 'HITS@3': 38.0, 'HITS@10': 41.0}
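
All of these metrics derive from the rank of the correct entity among the candidates. Here is a minimal sketch of how they are commonly computed from a list of ranks (1 is the best rank). `ranking_metrics` is a hypothetical helper, and whether mkb reports HITS@k as a fraction or as a percentage is not stated here; this sketch uses fractions.

```python
def ranking_metrics(ranks, ks=(1, 3, 10)):
    """Compute MRR, MR and HITS@k from 1-based ranks of the true entities."""
    n = len(ranks)
    metrics = {
        "MRR": sum(1.0 / r for r in ranks) / n,  # mean reciprocal rank
        "MR": sum(ranks) / n,                    # mean rank
    }
    for k in ks:
        # Share of queries whose true entity is ranked in the top k.
        metrics[f"HITS@{k}"] = sum(r <= k for r in ranks) / n
    return metrics


# Ranks of the true tail for four hypothetical queries.
metrics = ranking_metrics([1, 2, 4, 20])
print(metrics)
```

MRR and HITS@k are better when higher, while MR is better when lower and is sensitive to a few badly ranked queries.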

🔎 Link prediction detailed evaluation:

You can get a more detailed evaluation of the link prediction task and measure the performance of the model according to the type of relationship.

validation.detail_eval(model=model, dataset=dataset.test, threshold=1.5)
          head                               tail
          MRR   MR HITS@1 HITS@3 HITS@10     MRR   MR HITS@1 HITS@3 HITS@10
relation
1_1       0.5  2.0    0.0    1.0     1.0  0.3333  3.0    0.0    1.0     1.0
1_M       1.0  1.0    1.0    1.0     1.0  0.5000  2.0    0.0    1.0     1.0
M_1       0.0  0.0    0.0    0.0     0.0  0.0000  0.0    0.0    0.0     0.0
M_M       0.0  0.0    0.0    0.0     0.0  0.0000  0.0    0.0    0.0     0.0
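
The row labels 1_1, 1_M, M_1 and M_M refer to the cardinality of each relation. A common convention, assumed here because it is not defined above, is to compute for each relation the average number of heads per tail and of tails per head, and to call a side "many" when that average exceeds the threshold (1.5 in the call above). `relation_types` is a hypothetical helper.

```python
from collections import defaultdict


def relation_types(triples, threshold=1.5):
    """Label each relation as 1_1, 1_M, M_1 or M_M from its average fan-out."""
    tails = defaultdict(set)  # (relation, head) -> set of tails
    heads = defaultdict(set)  # (relation, tail) -> set of heads
    for h, r, t in triples:
        tails[(r, h)].add(t)
        heads[(r, t)].add(h)

    def mean_size(groups, relation):
        sizes = [len(v) for (r, _), v in groups.items() if r == relation]
        return sum(sizes) / len(sizes)

    labels = {}
    for relation in {r for _, r, _ in triples}:
        head_side = "M" if mean_size(heads, relation) > threshold else "1"
        tail_side = "M" if mean_size(tails, relation) > threshold else "1"
        labels[relation] = f"{head_side}_{tail_side}"
    return labels


triples = [
    ("duck", "is a", "bird"),
    ("eagle", "is a", "bird"),
    ("owl", "is a", "bird"),
    ("france", "capital", "paris"),
    ("spain", "capital", "madrid"),
]
labels = relation_types(triples)
print(labels)
```

In this toy graph, `is a` maps many heads to one tail, while `capital` is one to one.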

➡️ Relation prediction:

The task of relation prediction is to find the most likely relation for a given tuple (head, tail).

validation.eval_relations(model=model, dataset=dataset.test)
{'MRR_relations': 1.0, 'MR_relations': 1.0, 'HITS@1_relations': 1.0, 'HITS@3_relations': 1.0, 'HITS@10_relations': 1.0}

🦾 Triplet classification

The triplet classification task is to predict whether or not a triplet exists. It is available for every dataset in mkb except the Countries datasets.

from mkb import evaluation

evaluation.find_threshold(
    model = model,
    X = dataset.classification_valid['X'],
    y = dataset.classification_valid['y'],
    batch_size = 10,
)

Best threshold found from triplet classification valid set and associated accuracy:

{'threshold': 1.924787, 'accuracy': 0.803803}

evaluation.accuracy(
    model = model,
    X = dataset.classification_test['X'],
    y = dataset.classification_test['y'],
    threshold = 1.924787,
    batch_size = 10,
)

Accuracy of the model on the triplet classification test set:

0.793803
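
Conceptually, the threshold turns a score into a yes/no decision: a triplet is predicted to exist when its score reaches the threshold. The sketch below searches for such a threshold on toy data. It assumes that higher scores mean more plausible triplets and simply tries every observed score as a candidate; `best_threshold` and this standalone `accuracy` are hypothetical helpers and are not mkb's implementation of find_threshold.

```python
def accuracy(scores, labels, threshold):
    """Share of triplets whose thresholded score matches the 0/1 label."""
    predictions = [int(s >= threshold) for s in scores]
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


def best_threshold(scores, labels):
    """Try every observed score as a threshold and keep the most accurate one."""
    candidates = sorted(set(scores))
    best = max(candidates, key=lambda t: accuracy(scores, labels, t))
    return {"threshold": best, "accuracy": accuracy(scores, labels, best)}


# Hypothetical scores for six validation triplets (1 = true, 0 = corrupted).
scores = [2.5, 1.9, 1.2, 0.4, -0.3, -1.0]
labels = [1, 1, 0, 1, 0, 0]

result = best_threshold(scores, labels)
print(result)
```

The threshold is selected on the validation split and then reused unchanged on the test split, which keeps the test accuracy an honest estimate.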

🤩 Get embeddings

You can extract the entity and relation embeddings computed by the model with the model.embeddings property.

model.embeddings['entities']
{'hello': tensor([ 0.7645,  0.8300, -0.2343]), 'world': tensor([ 0.9186, -0.2191,  0.2018])}

model.embeddings['relations']
{'lorem': tensor([-0.4869,  0.5873,  0.8815]), 'ipsum': tensor([-0.7336,  0.8692,  0.1872])}
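
Because the embeddings come back as dictionaries keyed by label, they are easy to reuse outside of mkb. The sketch below ranks entities by cosine similarity, using plain lists in place of tensors (a tensor can be converted with `.tolist()`) and made-up vectors; `most_similar` is a hypothetical helper.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def most_similar(embeddings, query, k=2):
    """Return the k labels closest to `query`, excluding `query` itself."""
    scored = [
        (label, cosine(embeddings[query], vector))
        for label, vector in embeddings.items()
        if label != query
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


# Made-up 3-dimensional vectors, shaped like model.embeddings['entities'].
embeddings = {
    "duck": [0.9, 0.1, 0.0],
    "eagle": [0.8, 0.2, 0.1],
    "snail": [-0.1, 0.9, 0.3],
}

neighbours = most_similar(embeddings, "duck", k=2)
print(neighbours)
```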

πŸ” Transformers

MKB provides an implementation of the paper Inductive Entity Representations from Text via Link Prediction. It lets you train transformers to build embeddings of the entities of a knowledge graph under the link prediction objective. After fine-tuning the transformer on the link prediction task, you can use it to build an entity search engine, to complete knowledge graphs, or for any downstream task such as classification.

Using a transformer instead of an embedding table has several advantages, such as building contextualized latent representations of entities. In addition, the model can encode entities it has never seen during training from their textual description. Training takes much longer than for a classical TransE model, but the model converges in fewer epochs.

MKB provides two classes dedicated to fine-tuning Sentence Transformers and vanilla Transformers:

  • models.SentenceTransformer: Dedicated to Sentence Transformer models.
  • models.Transformer: Dedicated to traditional Transformer models.

Under the hood, the Transformer model is trained using entity labels. Therefore, it is important to provide relevant entity labels. We initialize an embedding matrix dedicated to relationships. The negative samples are generated following the in-batch strategy.
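
The in-batch strategy reuses the other triplets of a batch as a source of negatives instead of sampling from the full entity set. Here is a minimal sketch of the idea, assuming that only tails are swapped and that swaps producing a known triplet are skipped. MKB's actual procedure may differ, and `in_batch_negatives` is a hypothetical helper.

```python
def in_batch_negatives(batch, known):
    """Pair each triplet's head and relation with the tails of the other
    triplets in the batch, skipping combinations that are known facts."""
    negatives = {}
    for i, (head, relation, tail) in enumerate(batch):
        candidates = []
        for j, (_, _, other_tail) in enumerate(batch):
            candidate = (head, relation, other_tail)
            if i != j and candidate not in known and candidate not in candidates:
                candidates.append(candidate)
        negatives[(head, relation, tail)] = candidates
    return negatives


batch = [
    ("jaguar", "cousin", "cat"),
    ("tiger", "cousin", "cat"),
    ("dog", "cousin", "wolf"),
]

sampled = in_batch_negatives(batch, known=set(batch))
for positive, corrupted in sampled.items():
    print(positive, "->", corrupted)
```

The appeal of this strategy with a transformer is that the entities of the batch are already encoded, so negatives come at no extra encoding cost.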

Here is how to fine-tune a sentence transformer under the link prediction objective:

import json

from mkb import losses, evaluation, datasets, text, models
from transformers import AutoTokenizer, AutoModel
import torch

_ = torch.manual_seed(42)

train = [
    ("jaguar", "cousin", "cat"),
    ("tiger", "cousin", "cat"),
    ("dog", "cousin", "wolf"),
    ("dog", "angry_against", "cat"),
    ("wolf", "angry_against", "jaguar"),
]

valid = [
    ("cat", "cousin", "jaguar"),
    ("cat", "cousin", "tiger"),
    ("dog", "angry_against", "tiger"),
]

test = [
    ("wolf", "angry_against", "tiger"),
    ("wolf", "angry_against", "cat"),
]

dataset = datasets.Dataset(
    batch_size = 5,
    train = train,
    valid = valid,
    test = test,
    seed = 42,
    shuffle=True,
)

device = "cpu"

model = models.SentenceTransformer(
    model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2"),
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2"),
    entities = dataset.entities,
    relations = dataset.relations,
    gamma = 9,
    device = device,
)

model = model.to(device)

optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr = 0.000005,
)

# Link prediction evaluation for Transformers
evaluation = evaluation.TransformerEvaluation(
    entities = dataset.entities,
    relations = dataset.relations,
    true_triples = dataset.train + dataset.valid + dataset.test,
    batch_size = 2,
    device = device,
)

model = text.learn(
    model = model,
    dataset = dataset,
    evaluation = evaluation,
    optimizer = optimizer,
    loss = losses.Adversarial(alpha=0.5),
    negative_sampling_size = 5,
    epochs = 1,
    eval_every = 5,
    early_stopping_rounds = 3,
    device = device,
)

# Saving the Sentence Transformer model:
model.model.save_pretrained("model")
model.tokenizer.save_pretrained("model")

relations = {}
for id_relation, label in model.relations.items():
    relations[label] = model.relation_embedding[id_relation].cpu().detach().tolist()

with open("relations.json", "w") as f:
    json.dump(relations, f, indent=4)

After training a Sentence Transformer on the link prediction task using MKB and saving the model, we can load the trained model using the sentence_transformers library.

from sentence_transformers import SentenceTransformer
import json
import numpy as np

# Entity encoder
model = SentenceTransformer("model", device="cpu")

# Relations embeddings
with open("relations.json", "r") as f:
    relations = json.load(f)

Here is how to fine-tune a Transformer under the link prediction objective:

from mkb import losses, evaluation, datasets, text, models
from transformers import AutoTokenizer, AutoModel

model = models.Transformer(
    model = AutoModel.from_pretrained("bert-base-uncased"),
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased"),
    entities = dataset.entities,
    relations = dataset.relations,
    gamma = 9,
    device = device,
)

🧰 Development

# Download and navigate to the source code
$ git clone https://github.com/raphaelsty/mkb
$ cd mkb

# Create a virtual environment
$ python3 -m venv env
$ source env/bin/activate

# Install
$ python setup.py install

# Run tests
$ python -m pytest

💬 Citations

Knowledge Base Embedding By Cooperative Knowledge Distillation

@inproceedings{sourty-etal-2020-knowledge,
    title = "Knowledge Base Embedding By Cooperative Knowledge Distillation",
    author = {Sourty, Rapha{\"e}l  and
      Moreno, Jose G.  and
      Servant, Fran{\c{c}}ois-Paul  and
      Tamine-Lechani, Lynda},
    booktitle = "Proceedings of the 28th International Conference on Computational Linguistics",
    month = dec,
    year = "2020",
    address = "Barcelona, Spain (Online)",
    publisher = "International Committee on Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.coling-main.489",
    pages = "5579--5590",
}

πŸ‘ See also

There is a multitude of tools and libraries available on GitHub to build latent knowledge graph embeddings. These libraries are very complete and provide powerful implementations of knowledge graph embedding algorithms.

From a user's point of view, I find that most of these libraries suffer from a lack of modularity. That's why I created this tool. mkb addresses this modularity problem and is easily integrated into a machine learning pipeline.

  • DGL-KE: High performance, easy-to-use, and scalable package for learning large-scale knowledge graph embeddings.

  • OpenKE: An Open-source Framework for Knowledge Embedding implemented with PyTorch.

  • GraphVite: GraphVite is a general graph embedding engine, dedicated to high-speed and large-scale embedding learning in various applications.

  • LibKGE: LibKGE is a PyTorch-based library for efficient training, evaluation, and hyperparameter optimization of knowledge graph embeddings (KGE).

  • TorchKGE: Knowledge Graph embedding in Python and PyTorch.

  • KnowledgeGraphEmbedding: RotatE, Knowledge Graph Embedding by Relational Rotation in Complex Space

🗒 License

This project is free and open-source software licensed under the MIT license.