-
Notifications
You must be signed in to change notification settings - Fork 1
/
t-sne.py
41 lines (26 loc) · 892 Bytes
/
t-sne.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os

# Number of threads requested from the native math back-ends. The variables
# are exported here, ahead of the numpy / scikit-learn imports further down,
# because such settings are typically read when those libraries are loaded.
default_n_threads = 8
os.environ.update({
    'OPENBLAS_NUM_THREADS': str(default_n_threads),
    'MKL_NUM_THREADS': str(default_n_threads),
    'OMP_NUM_THREADS': str(default_n_threads),
})
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import argparse
import numpy as np
def main():
    """Project exported item embeddings to 2-D with t-SNE and save a scatter plot.

    The ``-k`` command-line option selects the run directory: the embeddings
    are loaded from ``saving/MIND-small/<key>/export_states/item_embeds.npy``.
    The resulting figure is written to ``t-sne.png`` in the current working
    directory.
    """
    parser = argparse.ArgumentParser()
    # The key is interpolated into the input path, so it has to be supplied;
    # without it the script would look for a directory literally named "None".
    parser.add_argument(
        '-k',
        type=str,
        required=True,
        help='run key: sub-directory of saving/MIND-small/ to load embeddings from',
    )
    args = parser.parse_args()
    key = args.k

    path = f'saving/MIND-small/{key}/export_states/item_embeds.npy'
    embeds = np.load(path)

    # Reduce the data to two dimensions with t-SNE (random_state pinned to 0).
    tsne = TSNE(n_components=2, random_state=0)
    embeds_2d = tsne.fit_transform(embeds)

    # Draw the 2-D scatter plot and write it to disk.
    plt.figure(figsize=(8, 6))
    plt.scatter(embeds_2d[:, 0], embeds_2d[:, 1])
    plt.xlabel('T-SNE feature 1')
    plt.ylabel('T-SNE feature 2')
    plt.title('T-SNE Visualization of Embeddings')
    plt.savefig('t-sne.png')
    plt.close()


if __name__ == '__main__':
    main()