Knowledge graph embeddings: Training in PyG, prediction with GDS
This Jupyter notebook is hosted in the Neo4j Graph Data Science Client GitHub repository.
The notebook demonstrates how to use the `graphdatascience` and PyTorch Geometric (PyG) Python libraries to:
- Import the FB15k-237 dataset directly into GDS
- Train a TransE model with PyG
- Make predictions on the data in the database using GDS Knowledge Graph Embeddings functionality
1. Prerequisites
To run this notebook, you’ll need a Neo4j server with GDS 2.5 or later installed.
Additionally, the following Python libraries are required:
- `graphdatascience`, see the documentation for installation instructions
- `pytorch-geometric` version >= 2.5.0, see the PyG documentation for installation instructions
2. Setup
We’ll begin by importing our dependencies and establishing a GDS client connection to the database.
%pip install graphdatascience torch torch_geometric
import collections
import os
import pandas as pd
import torch
import torch.optim as optim
from torch_geometric.data import Data, download_url
from torch_geometric.nn import TransE
from tqdm import tqdm
from graphdatascience import GraphDataScience
import re

# Connection settings come from environment variables, with local defaults.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_AUTH = None
NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")
# Authenticate only when both a user name and a password are provided.
if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"):
    NEO4J_AUTH = (
        os.environ.get("NEO4J_USER"),
        os.environ.get("NEO4J_PASSWORD"),
    )
gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB)

# This notebook requires GDS 2.5.0 or later.
# gds.version() returns a string. Comparing version strings directly is
# lexicographic and wrongly rejects newer releases ("2.10.0" < "2.5.0"),
# so compare the numeric (major, minor) prefix instead.
GDS_MIN_VERSION = (2, 5)
_gds_version = gds.version()
_version_match = re.match(r"(\d+)\.(\d+)", _gds_version)
if _version_match is None or tuple(int(part) for part in _version_match.groups()) < GDS_MIN_VERSION:
    raise RuntimeError(f"This notebook requires GDS 2.5.0 or later, but the server runs GDS {_gds_version}")
3. Downloading and Storing the FB15k-237 Dataset in the Database
Download the FB15k-237 dataset and extract the required files: train.txt, valid.txt, and test.txt.
import os
import zipfile
# Fetch the FB15k-237 archive and unpack only the three split files used below.
url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip"
raw_dir = "./data_from_zip"
download_url(url, raw_dir)

raw_file_names = ["train.txt", "valid.txt", "test.txt"]
archive_path = f"{raw_dir}/{os.path.basename(url)}"
with zipfile.ZipFile(archive_path, "r") as archive:
    # The split files live under the archive's top-level "Release" directory.
    for member in (f"Release/{name}" for name in raw_file_names):
        archive.extract(member, path=raw_dir)

data_dir = f"{raw_dir}/Release"
Create a uniqueness constraint on the `id` property of `Entity` nodes to speed up data uploads.
# Uniqueness constraint on Entity.id (Neo4j backs it with an index, which speeds up
# the MERGE lookups during import). IF NOT EXISTS keeps this cell safely re-runnable.
gds.run_cypher("CREATE CONSTRAINT entity_id IF NOT EXISTS FOR (e:Entity) REQUIRE e.id IS UNIQUE")
Creating Entity Nodes: Create a node with the label `Entity`. This node should have the properties `id` and `text`. Syntax: `(:Entity {id: int, text: str})`
Creating Relationships for Training with PyG: Based on the training stage, create relationships of type `TRAIN`, `TEST`, or `VALID`. Each of these relationships should have a `rel_id` property. Example syntax: `[:TRAIN {rel_id: int}]`
Creating Relationships for Prediction with GDS: For the prediction stage, create relationships of a specific type denoted as `REL_i`. Each of these relationships should have `rel_id` and `text` properties. Example syntax: `[:REL_7 {rel_id: int, text: str}]`
# Relationship type used for the training-stage triples of each split file.
rel_types = {
    "train.txt": "TRAIN",
    "valid.txt": "VALID",
    "test.txt": "TEST",
}
# rel_id -> original relation string, e.g. "/film/film/genre".
rel_id_to_text_dict = {}
# rel_id -> [{"source": entity id, "target": entity id}, ...] collected over all splits.
rel_type_dict = collections.defaultdict(list)
# Original relation string -> rel_id (dense integers in first-seen order).
rel_dict = {}


def _read_triples(file_name_path):
    """Return the tab-separated ``[source, relation, target]`` rows of a split file.

    Blank lines are skipped, and the final triple is kept even when the file
    does not end with a newline.
    """
    with open(file_name_path, "r", encoding="utf-8") as f:
        return [line.rstrip("\r\n").split("\t") for line in f if line.strip()]


def process():
    """Write the FB15k-237 splits to Neo4j.

    For each file in ``raw_file_names`` (read from ``data_dir``):

    - MERGE ``(:Entity {id, text})`` nodes for both endpoints.
    - MERGE a relationship of the split's type (``TRAIN``/``VALID``/``TEST``)
      carrying a ``rel_id`` property.

    Afterwards, every triple is written once more as a ``REL_<rel_id>``
    relationship with ``rel_id`` and ``text`` properties. That form is used
    later for GDS predictions.

    Entity ids and relationship ids are dense integers assigned in first-seen
    order, with ``train.txt`` processed first. The module-level ``rel_dict``,
    ``rel_id_to_text_dict`` and ``rel_type_dict`` are filled in place so later
    cells can use them.
    """
    node_dict_ = {}
    for file_name in raw_file_names:
        list_of_dicts = []
        for src, rel, dst in _read_triples(data_dir + "/" + file_name):
            if src not in node_dict_:
                node_dict_[src] = len(node_dict_)
            if dst not in node_dict_:
                node_dict_[dst] = len(node_dict_)
            if rel not in rel_dict:
                rel_dict[rel] = len(rel_dict)
                rel_id_to_text_dict[rel_dict[rel]] = rel
            source = node_dict_[src]
            target = node_dict_[dst]
            edge_type = rel_dict[rel]
            rel_type_dict[edge_type].append({"source": source, "target": target})
            list_of_dicts.append(
                {
                    "source": source,
                    "source_text": src,
                    "target": target,
                    "target_text": dst,
                    "rel_id": edge_type,
                }
            )
        rel_type = rel_types[file_name]
        # Each entry becomes one relationship (plus its two endpoint nodes).
        print(f"Writing {len(list_of_dicts)} relationships of type {rel_type}")
        gds.run_cypher(
            f"""
            UNWIND $ll as l
            MERGE (n:Entity {{id:l.source, text:l.source_text}})
            MERGE (m:Entity {{id:l.target, text:l.target_text}})
            MERGE (n)-[:{rel_type} {{rel_id:l.rel_id}}]->(m)
            """,
            params={"ll": list_of_dicts},
        )
    print("Writing relationships as different relationship types")
    for rel_id, rels in tqdm(rel_type_dict.items()):
        REL_TYPE = f"REL_{rel_id}"
        gds.run_cypher(
            f"""
            UNWIND $ll AS l MATCH (n:Entity {{id:l.source}}), (m:Entity {{id:l.target}})
            MERGE (n)-[:{REL_TYPE} {{rel_id:$rel_id, text:$text}}]->(m)
            """,
            params={"ll": rels, "rel_id": rel_id, "text": rel_id_to_text_dict[rel_id]},
        )


process()
Project all data in the graph to get the mapping between the `id` property and the internal `nodeId` field from the database.
# Project every Entity node (with its `id` property) and all TRAIN/TEST/VALID
# relationships (with their `rel_id` property) into one in-memory GDS graph.
node_projection = {"Entity": {"properties": "id"}}
relationship_projection = [
    {split: {"orientation": "NATURAL", "properties": "rel_id"}}
    for split in ("TRAIN", "TEST", "VALID")
]
ttv_G, result = gds.graph.project("fb15k-graph-ttv", node_projection, relationship_projection)

# Lookups between GDS-internal node ids and the `id` property written during import.
node_properties = gds.graph.nodeProperties.stream(ttv_G, ["id"], separate_property_columns=True)
internal_ids, entity_ids = node_properties.nodeId, node_properties.id
nodeId_to_id = dict(zip(internal_ids, entity_ids))
id_to_nodeId = dict(zip(entity_ids, internal_ids))
4. Training the TransE Model with PyG
Retrieve data from the database, convert it into torch tensors, and format it into a `Data` structure suitable for training with PyG.
def create_data_from_graph(relationship_type):
    """Stream one relationship type of ``ttv_G`` and wrap it as a PyG ``Data`` object.

    Args:
        relationship_type: Relationship type in the projected graph, e.g. ``"TRAIN"``.

    Returns:
        A ``Data`` object with the following fields:

        - ``edge_index``: (source, target) pairs as entity ``id`` values, not
          GDS-internal node ids.
        - ``edge_type``: each relationship's ``rel_id``.
        - ``num_nodes``: the number of projected entities.
    """
    rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type)
    # Translate GDS-internal node ids back to the `id` values used for training.
    # Build plain Python int lists so torch.tensor gets ordinary nested lists
    # instead of pandas Series objects.
    sources = [nodeId_to_id[x] for x in rels_tmp.sourceNodeId]
    targets = [nodeId_to_id[x] for x in rels_tmp.targetNodeId]
    edge_index = torch.tensor([sources, targets], dtype=torch.long)
    edge_type = torch.tensor(rels_tmp.propertyValue.astype(int).tolist(), dtype=torch.long)
    data = Data(edge_index=edge_index, edge_type=edge_type)
    data.num_nodes = len(nodeId_to_id)
    display(data)
    return data


train_tensor_data = create_data_from_graph("TRAIN")
test_tensor_data = create_data_from_graph("TEST")
val_tensor_data = create_data_from_graph("VALID")
Drop the projected graph to save memory.
# Remove the TRAIN/TEST/VALID projection from the GDS graph catalog; the tensors
# and id lookups built from it above are plain Python/torch objects.
gds.graph.drop(ttv_G)
The training process of the TransE model follows the corresponding PyG example.
def train_model_with_pyg():
    """Train a PyG ``TransE`` model on the TRAIN triples and evaluate it.

    Uses the notebook-level ``train_tensor_data``, ``val_tensor_data`` and
    ``test_tensor_data`` objects. Validation metrics are printed every 75 epochs.
    Test metrics (mean rank, MRR, Hits@10) are printed after training. The whole
    model is saved to ``./model_{epoch_count}.pt``.

    Returns:
        The trained ``TransE`` model, on CUDA when available and otherwise on CPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def to_device(data):
        """Return ``(head_index, rel_type, tail_index)`` of ``data`` moved to ``device``."""
        # The embeddings live on `device`. Indexing them with tensors on another
        # device fails, so move the triples there, as the PyG TransE example does.
        return (
            data.edge_index[0].to(device),
            data.edge_type.to(device),
            data.edge_index[1].to(device),
        )

    model = TransE(
        num_nodes=train_tensor_data.num_nodes,
        num_relations=train_tensor_data.num_edge_types,
        hidden_channels=50,
    ).to(device)

    train_head, train_rel, train_tail = to_device(train_tensor_data)
    loader = model.loader(
        head_index=train_head,
        rel_type=train_rel,
        tail_index=train_tail,
        batch_size=1000,
        shuffle=True,
    )
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    def train():
        """Run one epoch over ``loader`` and return the per-example mean loss."""
        model.train()
        total_loss = total_examples = 0
        for head_index, rel_type, tail_index in loader:
            optimizer.zero_grad()
            loss = model.loss(head_index, rel_type, tail_index)
            loss.backward()
            optimizer.step()
            total_loss += float(loss) * head_index.numel()
            total_examples += head_index.numel()
        return total_loss / total_examples

    @torch.no_grad()
    def test(data):
        """Return ``(mean_rank, mrr, hits_at_10)`` of the model on ``data``."""
        model.eval()
        head_index, rel_type, tail_index = to_device(data)
        return model.test(
            head_index=head_index,
            rel_type=rel_type,
            tail_index=tail_index,
            batch_size=1000,
            k=10,
        )

    # Consider increasing the number of epochs
    epoch_count = 5
    # range() excludes its end, so `+ 1` runs exactly `epoch_count` epochs.
    for epoch in range(1, epoch_count + 1):
        loss = train()
        print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")
        if epoch % 75 == 0:
            # model.test returns (mean rank, MRR, Hits@k) in PyG >= 2.5.
            rank, mrr, hits = test(val_tensor_data)
            print(
                f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, "
                f"Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}"
            )

    torch.save(model, f"./model_{epoch_count}.pt")
    mean_rank, mrr, hits_at_k = test(test_tensor_data)
    print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}")
    return model


model = train_model_with_pyg()
# A previously trained model can be loaded instead of retraining. The file name
# contains the epoch count used in train_model_with_pyg (e.g. "./model_5.pt").
# torch.save stored the whole module, so recent PyTorch versions need weights_only=False:
# model = torch.load("./model_5.pt", weights_only=False)
Extract the node embeddings from the trained model and write them to the database.
# Write each trained node embedding to its Entity node as the `emb` property.
# Row i of `node_emb.weight` is the embedding of the entity whose `id` is i,
# i.e. the ids used as edge_index values during training.
# Writes are sent in batches with UNWIND instead of one query per node.
embeddings = model.node_emb.weight.detach().cpu().tolist()
EMBEDDING_WRITE_BATCH_SIZE = 1000
for start in tqdm(range(0, len(nodeId_to_id), EMBEDDING_WRITE_BATCH_SIZE)):
    stop = min(start + EMBEDDING_WRITE_BATCH_SIZE, len(nodeId_to_id))
    gds.run_cypher(
        "UNWIND $rows AS row MATCH (n:Entity {id: row.i}) SET n.emb = row.emb",
        params={"rows": [{"i": i, "emb": embeddings[i]} for i in range(start, stop)]},
    )
5. Predict Using GDS Knowledge Graph Edge Embeddings Functionality
Select a relationship type for which to make predictions.
# Relation whose missing links we want to predict.
relationship_to_predict = "/film/film/genre"
# Dense id assigned during import, and the matching REL_<id> relationship type.
rel_id_to_predict = rel_dict[relationship_to_predict]
rel_label_to_predict = "REL_" + str(rel_id_to_predict)
Project the graph with all nodes and existing relationships of the selected type.
# Project all Entity nodes (with `id` and the trained `emb`) and only the
# relationships of the type being predicted.
G_test, result = gds.graph.project(
    "graph_to_predict_",
    {"Entity": {"properties": ["id", "emb"]}},
    rel_label_to_predict,
)


def print_graph_info(G):
    """Print node count, node labels, relationship types and relationship count of ``G``."""
    graph_name = G.name()
    for description, getter in (
        ("node count", G.node_count),
        ("node labels", G.node_labels),
        ("relationship types", G.relationship_types),
        ("relationship count", G.relationship_count),
    ):
        print(f"Graph '{graph_name}' {description}: {getter()}")


print_graph_info(G_test)
Retrieve the embedding for the selected relationship from the PyG model. Then, create a GDS TransE model using the graph, node embeddings property, and the embedding for the relationship to be predicted.
# The relationship embedding lives in the model's `rel_emb` table, indexed by rel_id.
# `node_emb` holds entity embeddings, so indexing it with a rel_id picks an
# unrelated entity vector.
target_emb = model.rel_emb.weight[rel_id_to_predict].detach().cpu().tolist()
transe_model = gds.model.transe.create(G_test, "emb", {rel_label_to_predict: target_emb})

# Example source entities, identified by the dataset strings stored in `text`.
source_node_list = ["/m/07l450", "/m/0ds2l81", "/m/0jvt9"]
source_ids_df = gds.run_cypher(
    "UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId",
    params={"node_text_list": source_node_list},
)
Now we can use the model to make predictions.
# Score candidate Entity targets for the selected source nodes and request
# the top 3 predictions for the chosen relationship type.
predict_params = {
    "source_node_filter": source_ids_df.nodeId,
    "target_node_filter": "Entity",
    "relationship_type": rel_label_to_predict,
    "top_k": 3,
    "concurrency": 4,
}
result = transe_model.predict_stream(**predict_params)
print(result)
Augment the predicted result with node identifiers and their text values.
# Look up the `text` and `id` properties of every node that appears in the result.
# `.tolist()` turns the numpy id array into plain Python ints. Those are valid Cypher
# parameters for any Neo4j driver version, whereas numpy support depends on the driver.
ids_in_result = pd.unique(pd.concat([result.sourceNodeId, result.targetNodeId])).tolist()
ids_to_text = gds.run_cypher(
    "UNWIND $ids AS id MATCH (n:Entity) WHERE id(n)=id RETURN id(n) AS nodeId, n.text AS tag, n.id AS id",
    params={"ids": ids_in_result},
)
nodeId_to_text_res = dict(zip(ids_to_text.nodeId, ids_to_text.tag))
nodeId_to_id_res = dict(zip(ids_to_text.nodeId, ids_to_text.id))

# Insert the looked-up columns right after the node-id column they describe.
for position, column, node_column, lookup in (
    (1, "sourceTag", "sourceNodeId", nodeId_to_text_res),
    (2, "sourceId", "sourceNodeId", nodeId_to_id_res),
    (4, "targetTag", "targetNodeId", nodeId_to_text_res),
    (5, "targetId", "targetNodeId", nodeId_to_id_res),
):
    # __getitem__ keeps a KeyError on a missing id instead of silently producing NaN.
    result.insert(position, column, result[node_column].map(lookup.__getitem__))
print(result)
6. Using Write Mode
Write mode allows you to write results directly to the database as a new relationship type. This approach avoids the mapping from `nodeId` to `id`.
# Persist the predictions as new PREDICTED_REL_<id> relationships, storing each
# prediction's score in the `transe_score` property.
write_relationship_type = f"PREDICTED_{rel_label_to_predict}"
result_write = transe_model.predict_write(
    relationship_type=rel_label_to_predict,
    source_node_filter=source_ids_df.nodeId,
    target_node_filter="Entity",
    write_relationship_type=write_relationship_type,
    write_property="transe_score",
    top_k=3,
    concurrency=4,
)
Retrieve the written predictions from the database.
# Read the written prediction relationships back together with both endpoints.
gds.run_cypher(
    f"MATCH (n)-[r:{write_relationship_type}]->(m) "
    "RETURN n.id AS sourceId, n.text AS sourceTag, m.id AS targetId, m.text AS targetTag, r.transe_score AS score"
)
# Remove the prediction graph projection from the GDS graph catalog.
gds.graph.drop(G_test)