-
Notifications
You must be signed in to change notification settings - Fork 12
/
visdom.py
116 lines (98 loc) · 4.48 KB
/
visdom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import visdom
import numpy as np
class RelaPoseTmp:
    """Bundle of visdom meters for relative-pose training curves.

    Builds a VisManager for the given visdom env/host/port and registers two
    windows on it:
      * a loss window holding one line named ``legend_tag``;
      * a validation window holding ``pos_<legend_tag>`` and
        ``rot_<legend_tag>`` lines.
    When ``viswin`` is truthy it is used as a prefix of both window titles.
    """

    def __init__(self, legend_tag='', viswin=None, visenv=None, vishost='localhost', visport=8097):
        self.visman = VisManager(visenv, host=vishost, port=visport)

        # Window titles: optional "<viswin> " prefix followed by the window kind.
        prefix = '{} '.format(viswin) if viswin else ''
        loss_title = prefix + 'Loss'
        val_title = prefix + 'Val'

        # Line (legend) names inside the validation window.
        pos_name = 'pos_{}'.format(legend_tag)
        rot_name = 'rot_{}'.format(legend_tag)

        layout = {}
        layout[loss_title] = [legend_tag]
        layout[val_title] = [pos_name, rot_name]
        self.visman.set_wins(layout)

        get = self.visman.get_meter
        self.loss_meter = get(loss_title, legend_tag)
        self.pos_meter = get(val_title, pos_name)
        self.rot_meter = get(val_title, rot_name)

    def get_meters(self):
        """Return the ``(loss, position, rotation)`` meters as a tuple."""
        meters = (self.loss_meter, self.pos_meter, self.rot_meter)
        return meters

    def save_state(self):
        """Delegate to the underlying VisManager's ``save_state``."""
        manager = self.visman
        manager.save_state()
class VisLineMeter:
    '''Visdom Line Data Meter.

    Appends (X, Y) points to one named line (``legend``) inside one visdom
    window (``win``) of one visdom environment (``env``). When ``server`` is
    falsy (e.g. None), ``update`` and ``clear`` are no-ops, so callers can keep
    using the meter with plotting disabled.
    '''

    def __init__(self, server, env, win, legend):
        self.server = server  # visdom client object, or None to disable plotting
        self.env = env
        self.win = win
        self.legend = legend
        self.style_opts = self.get_style_opts_()

    def get_style_opts_(self):
        '''Return the ``opts`` dict passed to ``server.line`` on every update.

        Lines only (no markers), legend shown, plotly layout titled with the
        window name and an x-axis labelled "epochs".
        '''
        layout = {'plotly': dict(title=self.win, xaxis={'title': 'epochs'})}
        return dict(mode='lines', showlegend=True, layoutopts=layout)

    def validate_input_(self, X):
        '''Convert ``X`` to a numeric numpy array with at least one dimension.

        Accepts numpy arrays, Python and numpy scalars (including integer
        types such as ``np.int64``), torch tensors (detached from autograd and
        moved to CPU) and numeric sequences such as lists. Scalars and 0-d
        arrays/tensors are promoted to shape ``(1,)``.

        Raises:
            TypeError: if ``X`` does not convert to a numeric or boolean array.
        '''
        if isinstance(X, torch.Tensor):
            X = X.detach().cpu().numpy()
        arr = np.atleast_1d(np.asarray(X))
        # Fail here with a clear message rather than deep inside the plotting call.
        if not (np.issubdtype(arr.dtype, np.number) or arr.dtype == np.bool_):
            raise TypeError('Unsupported plot data of type {}'.format(type(X).__name__))
        return arr

    def update(self, X, Y):
        '''Append point(s) ``(X, Y)`` to this meter's line; no-op without a server.'''
        if not self.server:
            return
        self.server.line(X=self.validate_input_(X),
                         Y=self.validate_input_(Y),
                         env=self.env, win=self.win, name=self.legend,
                         opts=self.style_opts, update='append')

    def clear(self):
        '''Remove this meter's line from its window; no-op without a server.'''
        # Same guard as `update`: previously this raised AttributeError on None.
        if not self.server:
            return
        self.server.line(X=None, Y=None, env=self.env, win=self.win,
                         name=self.legend, update='remove')

    def __repr__(self):
        return 'Visdom meter(env={}, win={}, legend={})'.format(self.env, self.win, self.legend)
class VisManager:
    """Visdom manager.

    Initialize connection to the running visdom server.
    Create windows with style to plot data.
    Maintain window creation (incl. window style, data meters),
    window state saving and clearing.

    With ``env=None`` no connection is made: ``server`` is None, ``dummy`` is
    True, and ``save_state`` / ``clear_all`` do nothing.
    """

    def __init__(self, env, host='localhost', port='8097'):
        self.env = env
        self.win_pool = {}  # {win_name: {legend_name: VisLineMeter}}
        if self.env is None:
            self.server = None
            self.dummy = True
            print('Visdom is not set..')
            return
        self.dummy = False
        host = 'http://{}'.format(host)
        self.server = visdom.Visdom(server=host, port=port)
        # Explicit raise rather than `assert`, so the check survives `python -O`.
        if not self.server.check_connection():
            raise ConnectionError('Visdom server is not active on server {}:{}'.format(host, port))
        print('Visdom server connected on {}:{}'.format(host, port))

    def set_wins(self, win_dict):
        '''Create one VisLineMeter per legend.

        win_dict: {win_name : [legend_name]}. Re-registering a window name
        replaces its previous meters.
        '''
        for win_name, legend_names in win_dict.items():
            self.win_pool[win_name] = {}
            for legend_name in legend_names:
                meter = VisLineMeter(self.server, self.env, win_name, legend_name)
                self.win_pool[win_name][legend_name] = meter
                print('Initialize data meters {}'.format(str(meter)))

    def get_meter(self, win_name, legend_name):
        '''Return the meter registered for ``legend_name`` in ``win_name`` (KeyError if absent).'''
        return self.win_pool[win_name][legend_name]

    def save_state(self):
        '''Ask the visdom server to save this manager's env; no-op without a server.'''
        if self.server:
            self.server.save(envs=[self.env])

    def clear_all(self):
        '''Remove every registered line from its window; no-op without a server.'''
        # Same guard as `save_state`: with no server there is nothing to clear.
        if not self.server:
            return
        for legends in self.win_pool.values():
            for meter in legends.values():
                meter.clear()

    def print_(self):
        '''Print the repr of every registered meter.'''
        print('Visdom Manager Window Pool:\n')
        for legends in self.win_pool.values():
            for meter in legends.values():
                print(meter)