Skip to content

异构图神经网络HAN。Heterogeneous Graph Attention Network (HAN) with pytorch

Notifications You must be signed in to change notification settings

taishan1994/pytorch_HAN

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

7 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

pytorch_HAN

一位热心git友指出之前问题是过拟合了,修改以下utils.py相关地方得到以下结果:

1.image 2.image


Paper address:
https://github.com/Jhy1993/Representation-Learning-on-Heterogeneous-Graph
Heterogeneous Graph Attention Network (HAN) with pytorch. If you want to pursue the performance in the original paper, this may not be suitable for you, because there is still a problem: training loss decreases, but verification loss increases.

If you just want to figure out the basic principles of HAN and how to change tensorflow code to pytorch code, then this is for you. I implemented it according to the original tensorflow code structure.

If you want to pursue higher performance, please refer to:
Official tensorflow implementation:https://github.com/Jhy1993/HAN
DGL implementation:https://github.com/dmlc/dgl/tree/master/examples/pytorch/han

The result

Address of ACM data set: Preprocessed ACM can be found in: https://pan.baidu.com/s/1V2iOikRqHPtVvaANdkzROw 提取码:50k2

You can use the following command to run:

python main.py

Training result:

600 300 2125
y_train:(3025, 3), y_val:(3025, 3), y_test:(3025, 3), train_idx:(1, 600), val_idx:(1, 300), test_idx:(1, 2125)
2
model: pre_trained/acm/acm_allMP_multi_fea_.ckpt
fea_list[0].shape torch.Size([1, 1870, 3025])
biases_list[0].shape: torch.Size([1, 3025, 3025])
3
2
torch.Size([1, 1870, 3025]) torch.Size([1, 3025, 3025])
torch.Size([1, 1870, 3025]) torch.Size([1, 3025, 3025])
训练节点个数600
验证节点个数300
测试节点个数2125
epoch:001, loss:1.1004, TrainAcc:0.3517, ValLoss:1.1022, ValAcc:0.4000
epoch:002, loss:1.0762, TrainAcc:0.4250, ValLoss:1.1980, ValAcc:0.0533
epoch:003, loss:1.0007, TrainAcc:0.6300, ValLoss:1.4572, ValAcc:0.0533
epoch:004, loss:0.8876, TrainAcc:0.6583, ValLoss:2.0040, ValAcc:0.0500
epoch:005, loss:0.8145, TrainAcc:0.6350, ValLoss:2.7091, ValAcc:0.0500
epoch:006, loss:0.7897, TrainAcc:0.6267, ValLoss:3.2186, ValAcc:0.0500
epoch:007, loss:0.7804, TrainAcc:0.6150, ValLoss:3.4550, ValAcc:0.0500
epoch:008, loss:0.7527, TrainAcc:0.6150, ValLoss:3.5096, ValAcc:0.0500
epoch:009, loss:0.7404, TrainAcc:0.6117, ValLoss:3.5125, ValAcc:0.0600
epoch:010, loss:0.7329, TrainAcc:0.6633, ValLoss:3.5349, ValAcc:0.0400
epoch:011, loss:0.7169, TrainAcc:0.6983, ValLoss:3.5743, ValAcc:0.0133
epoch:012, loss:0.6934, TrainAcc:0.6917, ValLoss:3.6612, ValAcc:0.0033
epoch:013, loss:0.6711, TrainAcc:0.6750, ValLoss:3.7738, ValAcc:0.0033
epoch:014, loss:0.6645, TrainAcc:0.6733, ValLoss:3.9418, ValAcc:0.0200
epoch:015, loss:0.6652, TrainAcc:0.6833, ValLoss:4.0934, ValAcc:0.0300
epoch:016, loss:0.6515, TrainAcc:0.6883, ValLoss:4.2498, ValAcc:0.0300
epoch:017, loss:0.6238, TrainAcc:0.7050, ValLoss:4.4304, ValAcc:0.0300
epoch:018, loss:0.6082, TrainAcc:0.7317, ValLoss:4.5820, ValAcc:0.0333
epoch:019, loss:0.6030, TrainAcc:0.7517, ValLoss:4.7110, ValAcc:0.0367
epoch:020, loss:0.5933, TrainAcc:0.7850, ValLoss:4.8053, ValAcc:0.0400
epoch:021, loss:0.5824, TrainAcc:0.8267, ValLoss:4.8781, ValAcc:0.0333
epoch:022, loss:0.5655, TrainAcc:0.8017, ValLoss:4.9006, ValAcc:0.0267
epoch:023, loss:0.5333, TrainAcc:0.8083, ValLoss:4.9148, ValAcc:0.0167
epoch:024, loss:0.5175, TrainAcc:0.8050, ValLoss:4.8788, ValAcc:0.0100
epoch:025, loss:0.4994, TrainAcc:0.8117, ValLoss:4.7670, ValAcc:0.0033
epoch:026, loss:0.4888, TrainAcc:0.8333, ValLoss:4.5965, ValAcc:0.0033

This is where the problem lies.
If you know how to solve this problem, please don't hesitate to tell me.