Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Example] fixing TGN example #6619

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
115 changes: 110 additions & 5 deletions examples/pytorch/tgn/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,111 @@
Temporal Graph Neural Network (TGN)
===
# Temporal Graph Neural Network (TGN)

## DGL Implementation of tgn paper.

This DGL examples implements the GNN mode proposed in the paper [TemporalGraphNeuralNetwork](https://arxiv.org/abs/2006.10637.pdf)

## TGN implementor

This example was implemented by [Ericcsr](https://github.com/Ericcsr) during his SDE internship at the AWS Shanghai AI Lab.

## Graph Dataset

Jodie Wikipedia Temporal dataset. Dataset summary:

- Num Nodes: 9227
- Num Edges: 157, 474
- Num Edge Features: 172
- Edge Feature type: LIWC
- Time Span: 30 days
- Chronological Split: Train: 70% Valid: 15% Test: 15%

Jodie Reddit Temporal dataset. Dataset summary:

- Num Nodes: 11,000
- Num Edges: 672, 447
- Num Edge Features: 172
- Edge Feature type: LIWC
- Time Span: 30 days
- Chronological Split: Train: 70% Valid: 15% Test: 15%

## How to run example files

In tgn folder, run

**please use `train.py`**

```python
python train.py --dataset wikipedia
```

If you want to run in fast mode:

```python
python train.py --dataset wikipedia --fast_mode
```

If you want to run in simple mode:

```python
python train.py --dataset wikipedia --simple_mode
```

If you want to change memory updating module:

```python
python train.py --dataset wikipedia --memory_updater [rnn/gru]
```

If you want to use TGAT:

```python
python train.py --dataset wikipedia --not_use_memory --k_hop 2
```

## Performance

#### Without New Node in test set

| Models/Datasets | Wikipedia | Reddit |
| --------------- | ------------------ | ---------------- |
| TGN simple mode | AP: 98.5 AUC: 98.9 | AP: N/A AUC: N/A |
| TGN fast mode | AP: 98.2 AUC: 98.6 | AP: N/A AUC: N/A |
| TGN | AP: 98.9 AUC: 98.5 | AP: N/A AUC: N/A |

#### With New Node in test set

| Models/Datasets | Wikipedia | Reddit |
| --------------- | ------------------- | ---------------- |
| TGN simple mode | AP: 98.2 AUC: 98.6 | AP: N/A AUC: N/A |
| TGN fast mode | AP: 98.0 AUC: 98.4 | AP: N/A AUC: N/A |
| TGN | AP: 98.2 AUC: 98.1 | AP: N/A AUC: N/A |

## Training Speed / Batch
Intel E5 2cores, Tesla K80, Wikipedia Dataset

| Models/Datasets | Wikipedia | Reddit |
| --------------- | --------- | -------- |
| TGN simple mode | 0.3s | N/A |
| TGN fast mode | 0.28s | N/A |
| TGN | 1.3s | N/A |

### Details explained

**What is Simple Mode**

Simple Temporal Sampler just choose the edges that happen before the current timestamp and build the subgraph of the corresponding nodes.
And then the simple sampler uses the static graph neighborhood sampling methods.

**What is Fast Mode**

Normally temporal encoding needs each node to use incoming time frame as current time which might lead to two nodes have multiple interactions within the same batch need to maintain multiple embedding features which slow down the batching process to avoid feature duplication, fast mode enables fast batching since it uses last memory update time in the last batch as temporal encoding benchmark for each node. Also within each batch, all interaction between two nodes are predicted using the same set of embedding feature

**What is New Node test**

To test the model has the ability to predict link between unseen nodes based on neighboring information of seen nodes. This model deliberately select 10 % of node in test graph and mask them out during the training.

**Why the attention module is not exactly same as TGN original paper**

Attention module used in this model is adapted from DGL GATConv, considering edge feature and time encoding. It is more memory efficient and faster to compute then the attention module proposed in the paper, meanwhile, according to our test, the accuracy of our module compared with the one in paper is the same.


The example was temporarily removed due to the change in the `DataLoader`
interface in DGL 1.0. Please refer to the v0.9 example
[here](https://github.com/dmlc/dgl/tree/0.9.x/examples/pytorch/tgn).
130 changes: 130 additions & 0 deletions examples/pytorch/tgn/data_preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import os
import ssl
from six.moves import urllib

import pandas as pd
import numpy as np

import torch
import dgl

# === Below data preprocessing code are based on
# https://github.com/twitter-research/tgn

# Preprocess the raw data split each features

def preprocess(data_name):
u_list, i_list, ts_list, label_list = [], [], [], []
feat_l = []
idx_list = []

with open(data_name) as f:
s = next(f)
for idx, line in enumerate(f):
e = line.strip().split(',')
u = int(e[0])
i = int(e[1])

ts = float(e[2])
label = float(e[3]) # int(e[3])

feat = np.array([float(x) for x in e[4:]])

u_list.append(u)
i_list.append(i)
ts_list.append(ts)
label_list.append(label)
idx_list.append(idx)

feat_l.append(feat)
return pd.DataFrame({'u': u_list,
'i': i_list,
'ts': ts_list,
'label': label_list,
'idx': idx_list}), np.array(feat_l)

# Re index nodes for DGL convience
def reindex(df, bipartite=True):
new_df = df.copy()
if bipartite:
assert (df.u.max() - df.u.min() + 1 == len(df.u.unique()))
assert (df.i.max() - df.i.min() + 1 == len(df.i.unique()))

upper_u = df.u.max() + 1
new_i = df.i + upper_u

new_df.i = new_i
new_df.u += 1
new_df.i += 1
new_df.idx += 1
else:
new_df.u += 1
new_df.i += 1
new_df.idx += 1

return new_df

# Save edge list, features in different file for data easy process data
def run(data_name, bipartite=True):
PATH = './data/{}.csv'.format(data_name)
OUT_DF = './data/ml_{}.csv'.format(data_name)
OUT_FEAT = './data/ml_{}.npy'.format(data_name)
OUT_NODE_FEAT = './data/ml_{}_node.npy'.format(data_name)

df, feat = preprocess(PATH)
new_df = reindex(df, bipartite)

empty = np.zeros(feat.shape[1])[np.newaxis, :]
feat = np.vstack([empty, feat])

max_idx = max(new_df.u.max(), new_df.i.max())
rand_feat = np.zeros((max_idx + 1, 172))

new_df.to_csv(OUT_DF)
np.save(OUT_FEAT, feat)
np.save(OUT_NODE_FEAT, rand_feat)

# === code from twitter-research-tgn end ===

# If you have new dataset follow by same format in Jodie,
# you can directly use name to retrieve dataset

def TemporalDataset(dataset):
if not os.path.exists('./data/{}.bin'.format(dataset)):
if not os.path.exists('./data/{}.csv'.format(dataset)):
if not os.path.exists('./data'):
os.mkdir('./data')

url = 'https://snap.stanford.edu/jodie/{}.csv'.format(dataset)
print("Start Downloading File....")
context = ssl._create_unverified_context()
data = urllib.request.urlopen(url, context=context)
with open("./data/{}.csv".format(dataset), "wb") as handle:
handle.write(data.read())

print("Start Process Data ...")
run(dataset)
raw_connection = pd.read_csv('./data/ml_{}.csv'.format(dataset))
raw_feature = np.load('./data/ml_{}.npy'.format(dataset))
# -1 for re-index the node
src = raw_connection['u'].to_numpy()-1
dst = raw_connection['i'].to_numpy()-1
# Create directed graph
g = dgl.graph((src, dst))
g.edata['timestamp'] = torch.from_numpy(
raw_connection['ts'].to_numpy())
g.edata['label'] = torch.from_numpy(raw_connection['label'].to_numpy())
g.edata['feats'] = torch.from_numpy(raw_feature[1:, :]).float()
dgl.save_graphs('./data/{}.bin'.format(dataset), [g])
else:
print("Data is exist directly loaded.")
gs, _ = dgl.load_graphs('./data/{}.bin'.format(dataset))
g = gs[0]
return g

def TemporalWikipediaDataset():
# Download the dataset
return TemporalDataset('wikipedia')

def TemporalRedditDataset():
return TemporalDataset('reddit')