-
Notifications
You must be signed in to change notification settings - Fork 96
/
simple_print.lua
124 lines (114 loc) · 3.24 KB
/
simple_print.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
--- Remove every edge touching `node_id` and reconnect around the node.
-- Each predecessor of the node (edge X -> node) is linked to each successor
-- (edge node -> Y) with a new edge X -> Y, so paths through the removed node
-- are preserved. `edges` is modified in place and also returned.
-- @param node_id id of the node being removed
-- @param edges sequence of `{source = id, target = id}` tables
-- @return the same `edges` table
local function removeNodeFromEdges(node_id, edges)
  local from_nodes = {}
  local to_nodes = {}
  -- Remove edges; `idx` only advances when the current edge is kept,
  -- because table.remove shifts the following edges down into `idx`.
  local idx = 1
  while idx <= #edges do
    local edge = edges[idx]
    if edge.source == node_id and edge.target == node_id then
      -- A self-loop has nothing to reconnect. Dropping it keeps the removed
      -- node from reappearing as an endpoint of a bypass edge.
      table.remove(edges, idx)
    elseif edge.source == node_id then
      to_nodes[#to_nodes + 1] = edge.target
      table.remove(edges, idx)
    elseif edge.target == node_id then
      from_nodes[#from_nodes + 1] = edge.source
      table.remove(edges, idx)
    else
      idx = idx + 1
    end
  end
  -- Add bypass edges; ipairs keeps their order deterministic.
  for _, f in ipairs(from_nodes) do
    for _, t in ipairs(to_nodes) do
      edges[#edges + 1] = {source = f, target = t}
    end
  end
  return edges
end
--- Decide whether a graph node should be kept in the printed graph.
-- A node is kept only if it carries `data.module` and that module's torch
-- type is not 'nn.Identity'. When `data` or `data.module` is missing, that
-- falsy value itself is returned.
local function isNodeGood(node)
  local data = node.data
  if not data then
    return data
  end
  local module = data.module
  if not module then
    return module
  end
  return torch.typename(module) ~= 'nn.Identity'
end
--- Renumber nodes to their 1-based position and rewrite edges to match.
-- Node ids are overwritten in place. Each edge endpoint is translated from
-- the old id to the new position; an endpoint whose old id belonged to no
-- node becomes nil. Returns the same two tables it was given.
local function reIndexNodes(nodes, edges)
  local newIdOf = {}
  for position = 1, #nodes do
    local n = nodes[position]
    newIdOf[n.id] = position
    n.id = position
  end
  for k = 1, #edges do
    local e = edges[k]
    e.source, e.target = newIdOf[e.source], newIdOf[e.target]
  end
  return nodes, edges
end
--- Drop nodes rejected by isNodeGood and renumber the survivors.
-- Every removed node is bypassed with removeNodeFromEdges, so edges through
-- it are rewired to connect its predecessors to its successors. Finally,
-- ids are compacted to 1..#nodes by reIndexNodes.
-- @param nodes sequence of `{id = ..., orig_node = ...}` wrappers (mutated)
-- @param edges sequence of `{source = id, target = id}` tables (mutated)
-- @return nodes, edges after cleaning and renumbering
local function cleanGraph(nodes, edges)
  local idx = 1
  while idx <= #nodes do
    local node = nodes[idx]
    if isNodeGood(node.orig_node) then
      idx = idx + 1
    else
      -- table.remove shifts the next node into `idx`, so do not advance.
      table.remove(nodes, idx)
      edges = removeNodeFromEdges(node.id, edges)
    end
  end
  return reIndexNodes(nodes, edges)
end
--- Flatten `graph.nodes` into node wrappers and parent -> child edges.
-- Each wrapper keeps the source node as `orig_node`. Each entry in a node's
-- `children` list becomes an edge from that node's id to the child's id.
-- The result is passed through cleanGraph before being returned.
local function loadGraph(graph)
  local nodeList, edgeList = {}, {}
  for _, gnode in ipairs(graph.nodes) do
    local id = gnode.id
    nodeList[#nodeList + 1] = { id = id, orig_node = gnode }
    local kids = gnode.children
    for k = 1, #kids do
      edgeList[#edgeList + 1] = { source = id, target = kids[k].id }
    end
  end
  local cleanNodes, cleanEdges = cleanGraph(nodeList, edgeList)
  return cleanNodes, cleanEdges
end
-- Module table; the public functions todot and dot are attached to it below.
local M = {}
--- Escape `"` so `s` can be embedded in a double-quoted DOT string.
-- Backslashes are left untouched on purpose: the labels rely on DOT escape
-- sequences such as `\n` for line breaks.
local function escapeDotQuotes(s)
  return (tostring(s):gsub('"', '\\"'))
end

--- Render a graph as Graphviz DOT source.
-- The graph is loaded and cleaned by loadGraph. Each surviving node becomes
-- an oval labelled "Node<id>\n<node:label()>". Each edge becomes `a -> b;`.
-- @param graph graph object with a `nodes` list (see loadGraph)
-- @param title optional graph title, placed at the top of the drawing
-- @return DOT source as a single string
function M.todot(graph, title)
  local nodes, edges = loadGraph(graph)
  local str = {}
  table.insert(str, 'digraph G {\n')
  if title then
    table.insert(str, 'labelloc="t";\nlabel="' .. escapeDotQuotes(title) .. '";\n')
  end
  table.insert(str, 'node [shape = oval]; ')
  local nodelabels = {}
  for i, node in ipairs(nodes) do
    local true_node = node.orig_node
    local text = 'Node' .. true_node.id .. '\\n' .. true_node:label()
    local l = '"' .. escapeDotQuotes(text) .. '"'
    -- Nodes are reindexed to their positions, so edge endpoints index this
    -- table directly below.
    nodelabels[i] = 'n' .. true_node.id
    table.insert(str, '\n' .. nodelabels[i] .. '[label=' .. l .. '];')
  end
  table.insert(str, '\n')
  for _, edge in ipairs(edges) do
    table.insert(str, nodelabels[edge.source] .. ' -> ' .. nodelabels[edge.target] .. ';\n')
  end
  table.insert(str, '}')
  return table.concat(str, '')
end
--- Quote a file path for inclusion in an os.execute command line.
-- POSIX shells use single quotes (embedded `'` becomes `'\''`). Windows
-- cmd.exe, detected via package.config's directory separator, uses double
-- quotes.
local function shellQuote(s)
  local dirsep = package.config and package.config:sub(1, 1) or '/'
  if dirsep == '\\' then
    return '"' .. s .. '"'
  end
  return "'" .. (s:gsub("'", [['\'']])) .. "'"
end

--- Write a graph as DOT, convert it to SVG with the `dot` executable, and
-- optionally display it.
-- When `fname` is given, `fname .. '.dot'` and `fname .. '.svg'` are written
-- and kept, and nothing is returned. When `fname` is nil, temporary files
-- are used and the SVG is shown in a qt.QSvgWidget (requires 'qtsvg'). The
-- temporary files are then deleted and the widget is returned.
-- @param g graph passed to M.todot
-- @param title optional title passed to M.todot
-- @param fname optional output base path (without extension)
function M.dot(g, title, fname)
  local gv = M.todot(g, title)
  -- os.tmpname() may create the file it names (mkstemp on POSIX). Request a
  -- single base name so the .dot/.svg pair shares it, and delete it later.
  local tmpbase
  local base = fname
  if not base then
    tmpbase = os.tmpname()
    base = tmpbase
  end
  local fngv = base .. '.dot'
  local fnsvg = base .. '.svg'
  local fgv = assert(io.open(fngv, 'w'))
  fgv:write(gv)
  fgv:close()
  os.execute('dot -Tsvg -o ' .. shellQuote(fnsvg) .. ' ' .. shellQuote(fngv))
  if not fname then
    require 'qtsvg'
    local qs = qt.QSvgWidget(fnsvg)
    qs:show()
    os.remove(fngv)
    os.remove(fnsvg)
    -- Ignore failure: on platforms where tmpname does not create the file
    -- there is nothing to delete.
    os.remove(tmpbase)
    return qs
  end
end
-- Export the module table (todot, dot) to callers of require.
return M