Skip to content

Commit

Permalink
Update atoms, old prototype back.
Browse files Browse the repository at this point in the history
  • Loading branch information
knc6 committed Apr 18, 2024
1 parent edf40f4 commit c1a632e
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 16 deletions.
119 changes: 105 additions & 14 deletions jarvis/core/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,106 @@ def density(self):
)
return den

def plot_atoms(
self=None,
colors=[],
sizes=[],
cutoff=1.9,
opacity=0.5,
bond_width=2,
filename=None,
):
"""Plot atoms using plotly."""
import plotly.graph_objects as go

fig = go.Figure()
if not colors:
colors = ["blue", "green", "red"]

unique_elements = self.uniq_species
if len(unique_elements) > len(colors):
raise ValueError("Provide more colors.")
color_map = {}
size_map = {}
for ii, i in enumerate(unique_elements):
color_map[i] = colors[ii]
size_map[i] = Specie(i).Z * 2
cart_coords = self.cart_coords
elements = self.elements
atoms_arr = []

for ii, i in enumerate(cart_coords):
atoms_arr.append(
[
i[0],
i[1],
i[2],
color_map[elements[ii]],
size_map[elements[ii]],
]
)
# atoms = [
# (0, 0, 0, 'red'), # Atom 1
# (1, 1, 1, 'blue'), # Atom 2
# (2, 0, 1, 'green') # Atom 3
# ]

# Create a scatter plot for the 3D points
trace1 = go.Scatter3d(
x=[atom[0] for atom in atoms_arr],
y=[atom[1] for atom in atoms_arr],
z=[atom[2] for atom in atoms_arr],
mode="markers",
marker=dict(
size=[atom[4] for atom in atoms_arr], # Marker size
color=[atom[3] for atom in atoms_arr], # Marker color
opacity=opacity,
),
)
fig.add_trace(trace1)

# Update plot layout
fig.update_layout(
title="3D Atom Coordinates",
scene=dict(
xaxis_title="X Coordinates",
yaxis_title="Y Coordinates",
zaxis_title="Z Coordinates",
),
margin=dict(l=0, r=0, b=0, t=0), # Tight layout
)
if bond_width is not None:
nbs = self.get_all_neighbors(r=5)
bonds = []
for i in nbs:
for j in i:
if j[2] <= cutoff:
bonds.append([j[0], j[1]])
for bond in bonds:
# print(bond)
# Extract coordinates of the first and second atom in each bond
x_coords = [atoms_arr[bond[0]][0], atoms_arr[bond[1]][0]]
y_coords = [atoms_arr[bond[0]][1], atoms_arr[bond[1]][1]]
z_coords = [atoms_arr[bond[0]][2], atoms_arr[bond[1]][2]]

fig.add_trace(
go.Scatter3d(
x=x_coords,
y=y_coords,
z=z_coords,
mode="lines",
line=dict(color="grey", width=bond_width),
marker=dict(
size=0.1
), # Small marker size to make lines prominent
)
)
# Show the plot
if filename is not None:
fig.write_image(filename)
else:
fig.show()

@property
def atomic_numbers(self):
"""Get list of atomic numbers of atoms in the atoms object."""
Expand Down Expand Up @@ -1232,7 +1332,9 @@ def hook(model, input, output):
h = get_val(model, g, lg)
return h

def get_prototype_name(self, prim=True, include_c_over_a=False, digits=3):
def get_mineral_prototype_name(
self, prim=True, include_c_over_a=False, digits=3
):
from jarvis.analysis.structure.spacegroup import Spacegroup3D

spg = Spacegroup3D(self)
Expand All @@ -1242,7 +1344,7 @@ def get_prototype_name(self, prim=True, include_c_over_a=False, digits=3):
# hall_number=str(spg._dataset['hall_number'])
wyc = "".join(list((sorted(set(spg._dataset["wyckoffs"])))))
name = (
(self.composition.prototype)
(self.composition.prototype_new)
+ "_"
+ str(number)
+ "_"
Expand Down Expand Up @@ -1272,24 +1374,13 @@ def get_minaral_name(self, model=""):
if maem < mae:
mae = maem
name = i[0]
print("name1", name, maem)
else:
# mineral = {}
for i, j in mineral_json_file.items():
for k in j:
maem = mean_absolute_error(k[1], feats)
# mineral[k[0]]=mean_absolute_error(k[1],feats)
if maem < mae:
mae = maem
name = k[0] # mineral[k[0]]
print("name2", name, maem)
# mem=[]
# for i,j in mineral_json_file.items():
# for k in j:
# print(j[1],feats)
# mem.append(j[0],np.linalg.norm(j[1],feats))
# print(mem)
# s=sorted(mem,key=1)[0]
name = k[0]
return name

def lattice_points_in_supercell(self, supercell_matrix):
Expand Down
17 changes: 17 additions & 0 deletions jarvis/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,23 @@ def prototype(self):
"""Get chemical prototypes such as A, AB etc."""
proto = ""
all_upper = string.ascii_uppercase

reduced, repeat = self.reduce()
N = 0
for specie, count in reduced.items():
if count != 1:
proto = proto + str(all_upper[N]) + str(round(count, 3))
N = N + 1
else:
proto = proto + str(all_upper[N])
N = N + 1
return proto # .replace("1", "")

@property
def prototype_new(self):
"""Get chemical prototypes such as A, AB etc."""
proto = ""
all_upper = string.ascii_uppercase
# print('reduce',self.reduce())
reduced, repeat = self.reduce()
items = sorted(list(reduced.values()), reverse=True)
Expand Down
2 changes: 1 addition & 1 deletion jarvis/tests/testfiles/core/test_atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_basic_atoms():
print(opt.from_optimade(opt_info))

polar = Si.check_polar
prot = Si.get_prototype_name()
# prot = Si.get_prototype_name()
Si.props = ["a", "a"]
vac_pad = VacuumPadding(Si)
den_2d = round(vac_pad.get_effective_2d_slab().density, 2)
Expand Down
3 changes: 2 additions & 1 deletion jarvis/tests/testfiles/core/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ def test_comp():
cc = Composition(comp)

assert (
cc.prototype,
cc.prototype_new,
cc.formula,
cc.reduced_formula,
round(cc.weight, 4),
cc.to_dict(),
) == ("A2B", "Li2O4", "LiO2", 77.8796, comp)
# ) == ("AB2", "Li2O4", "LiO2", 77.8796, comp)
assert cc.prototype == "AB2"
c = Composition.from_string("Al2O3Al5Co6O1")
td = c.to_dict()
fd = Composition.from_dict(td)
Expand Down

0 comments on commit c1a632e

Please sign in to comment.