Skip to content

Commit

Permalink
Merge pull request #938 from c-bata/migrate-plot-parallel-coordinate-…
Browse files Browse the repository at this point in the history
…to-tslib

Move parallel coordinate plot to `@optuna/react`
  • Loading branch information
c-bata authored Aug 19, 2024
2 parents a0dd548 + 03360d8 commit 610b453
Show file tree
Hide file tree
Showing 5 changed files with 394 additions and 279 deletions.
283 changes: 4 additions & 279 deletions optuna_dashboard/ts/components/GraphParallelCoordinate.tsx
Original file line number Diff line number Diff line change
@@ -1,104 +1,24 @@
import {
Checkbox,
FormControlLabel,
FormGroup,
Grid,
Typography,
useTheme,
} from "@mui/material"
import {
GraphContainer,
PlotParallelCoordinate,
useGraphComponentState,
useMergedUnionSearchSpace,
} from "@optuna/react"
import {
Target,
useFilteredTrials,
useObjectiveAndUserAttrTargets,
useParamTargets,
} from "@optuna/react"
import * as Optuna from "@optuna/types"
import * as plotly from "plotly.js-dist-min"
import React, { FC, ReactNode, useEffect, useState } from "react"
import { SearchSpaceItem, StudyDetail } from "ts/types/optuna"
import React, { FC, useEffect } from "react"
import { StudyDetail } from "ts/types/optuna"
import { PlotType } from "../apiClient"
import { usePlot } from "../hooks/usePlot"
import { usePlotlyColorTheme } from "../state"
import { useBackendRender } from "../state"

const plotDomId = "graph-parallel-coordinate"

const useTargets = (
study: StudyDetail | null
): [Target[], SearchSpaceItem[], () => ReactNode] => {
const [targets1] = useObjectiveAndUserAttrTargets(study)
const searchSpace = useMergedUnionSearchSpace(study?.union_search_space)
const [targets2] = useParamTargets(searchSpace)
const [checked, setChecked] = useState<boolean[]>([true])

const allTargets = [...targets1, ...targets2]
useEffect(() => {
if (allTargets.length !== checked.length) {
setChecked(
allTargets.map((t) => {
if (t.kind === "user_attr") {
return false
}
if (t.kind !== "params" || study === null) {
return true
}
// By default, params that is not included in intersection search space should be disabled,
// otherwise all trials are filtered.
return (
study.intersection_search_space.find((s) => s.name === t.key) !==
undefined
)
})
)
}
}, [allTargets])

const handleOnChange = (event: React.ChangeEvent<HTMLInputElement>) => {
setChecked(
checked.map((c, i) =>
i.toString() === event.target.name ? event.target.checked : c
)
)
}

const renderCheckBoxes = (): ReactNode => (
<FormGroup>
{allTargets.map((t, i) => {
return (
<FormControlLabel
key={i}
control={
<Checkbox
checked={checked.length > i ? checked[i] : true}
onChange={handleOnChange}
name={i.toString()}
/>
}
label={t.toLabel(study?.objective_names)}
/>
)
})}
</FormGroup>
)

const targets = allTargets.filter((t, i) =>
checked.length > i ? checked[i] : true
)
return [targets, searchSpace, renderCheckBoxes]
}

export const GraphParallelCoordinate: FC<{
study: StudyDetail | null
}> = ({ study = null }) => {
if (useBackendRender()) {
return <GraphParallelCoordinateBackend study={study} />
} else {
return <GraphParallelCoordinateFrontend study={study} />
return <PlotParallelCoordinate study={study} />
}
}

Expand Down Expand Up @@ -135,198 +55,3 @@ const GraphParallelCoordinateBackend: FC<{
/>
)
}

const GraphParallelCoordinateFrontend: FC<{
study: StudyDetail | null
}> = ({ study = null }) => {
const { graphComponentState, notifyGraphDidRender } = useGraphComponentState()

const theme = useTheme()
const colorTheme = usePlotlyColorTheme(theme.palette.mode)

const [targets, searchSpace, renderCheckBoxes] = useTargets(study)

const trials = useFilteredTrials(study, targets, false)
useEffect(() => {
if (study !== null && graphComponentState !== "componentWillMount") {
plotCoordinate(study, trials, targets, searchSpace, colorTheme)?.then(
notifyGraphDidRender
)
}
}, [study, trials, targets, searchSpace, colorTheme, graphComponentState])

return (
<Grid container direction="row">
<Grid
item
xs={3}
container
direction="column"
sx={{
paddingRight: theme.spacing(2),
display: "flex",
flexDirection: "column",
}}
>
<Typography
variant="h6"
sx={{ margin: "1em 0", fontWeight: theme.typography.fontWeightBold }}
>
Parallel Coordinate
</Typography>
{renderCheckBoxes()}
</Grid>
<Grid item xs={9}>
<GraphContainer
plotDomId={plotDomId}
graphComponentState={graphComponentState}
/>
</Grid>
</Grid>
)
}

const plotCoordinate = (
study: StudyDetail,
trials: Optuna.Trial[],
targets: Target[],
searchSpace: SearchSpaceItem[],
colorTheme: Partial<Plotly.Template>
) => {
if (document.getElementById(plotDomId) === null) {
return
}

const layout: Partial<plotly.Layout> = {
margin: {
l: 70,
t: 50,
r: 50,
b: 100,
},
template: colorTheme,
uirevision: "true",
}
if (trials.length === 0 || targets.length === 0) {
return plotly.react(plotDomId, [], layout)
}

const maxLabelLength = 40
const breakLength = maxLabelLength / 2
const ellipsis = "…"
const truncateLabelIfTooLong = (originalLabel: string): string => {
return originalLabel.length > maxLabelLength
? originalLabel.substring(0, maxLabelLength - ellipsis.length) + ellipsis
: originalLabel
}
const breakLabelIfTooLong = (originalLabel: string): string => {
const truncated = truncateLabelIfTooLong(originalLabel)
return truncated
.split("")
.map((c, i) => {
return (i + 1) % breakLength === 0 ? c + "<br>" : c
})
.join("")
}

const calculateLogScale = (values: number[]) => {
const logValues = values.map((v) => {
return Math.log10(v)
})
const minValue = Math.min(...logValues)
const maxValue = Math.max(...logValues)
const range = [Math.floor(minValue), Math.ceil(maxValue)]
const tickvals = Array.from(
{ length: Math.ceil(maxValue) - Math.floor(minValue) + 1 },
(_, i) => i + Math.floor(minValue)
)
const ticktext = tickvals.map((x) => `${Math.pow(10, x).toPrecision(3)}`)
return { logValues, range, tickvals, ticktext }
}

const dimensions = targets.map((target) => {
if (target.kind === "objective" || target.kind === "user_attr") {
const values: number[] = trials.map(
(t) => target.getTargetValue(t) as number
)
return {
label: target.toLabel(study.objective_names),
values: values,
range: [Math.min(...values), Math.max(...values)],
}
} else {
const s = searchSpace.find(
(s) => s.name === target.key
) as SearchSpaceItem // Must be already filtered.

const values: number[] = trials.map(
(t) => target.getTargetValue(t) as number
)
if (s.distribution.type === "CategoricalDistribution") {
// categorical
const vocabArr: string[] = s.distribution.choices.map(
(c) => c?.toString() ?? "null"
)
const tickvals: number[] = vocabArr.map((v, i) => i)
return {
label: breakLabelIfTooLong(s.name),
values: values,
range: [0, s.distribution.choices.length - 1],
// @ts-ignore
tickvals: tickvals,
ticktext: vocabArr,
}
} else if (s.distribution.log) {
// numerical and log
const { logValues, range, tickvals, ticktext } =
calculateLogScale(values)
return {
label: breakLabelIfTooLong(s.name),
values: logValues,
range,
tickvals,
ticktext,
}
} else {
// numerical and linear
return {
label: breakLabelIfTooLong(s.name),
values: values,
range: [s.distribution.low, s.distribution.high],
}
}
}
})
if (dimensions.length === 0) {
console.log("Must not reach here.")
return plotly.react(plotDomId, [], layout)
}
let reversescale = false
if (
targets[0].kind === "objective" &&
(targets[0].getObjectiveId() as number) < study.directions.length &&
study.directions[targets[0].getObjectiveId() as number] === "maximize"
) {
reversescale = true
}
const plotData: Partial<plotly.PlotData>[] = [
{
type: "parcoords",
dimensions: dimensions,
labelangle: 30,
labelside: "bottom",
line: {
color: dimensions[0]["values"],
// @ts-ignore
colorscale: "Blues",
colorbar: {
title: targets[0].toLabel(study.objective_names),
},
showscale: true,
reversescale: reversescale,
},
},
]

return plotly.react(plotDomId, plotData, layout)
}
53 changes: 53 additions & 0 deletions tslib/react/src/components/PlotParallelCoordinate.stories.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import { CssBaseline, ThemeProvider } from "@mui/material"
import { Meta, StoryObj } from "@storybook/react"
import React from "react"
import { useMockStudy } from "../MockStudies"
import { darkTheme } from "../styles/darkTheme"
import { lightTheme } from "../styles/lightTheme"
import { PlotParallelCoordinate } from "./PlotParallelCoordinate"

const meta: Meta<typeof PlotParallelCoordinate> = {
component: PlotParallelCoordinate,
title: "Plot/ParallelCoordinate",
tags: ["autodocs"],
decorators: [
(Story, storyContext) => {
const { study } = useMockStudy(storyContext.parameters?.studyId)
if (!study) return <p>loading...</p>
return (
<ThemeProvider theme={storyContext.parameters?.theme}>
<CssBaseline />
<Story
args={{
study,
}}
/>
</ThemeProvider>
)
},
],
}

export default meta
type Story = StoryObj<typeof PlotParallelCoordinate>

export const LightTheme: Story = {
parameters: {
studyId: 1,
theme: lightTheme,
},
}

export const DarkTheme: Story = {
parameters: {
studyId: 1,
theme: darkTheme,
},
}

// TODO(c-bata): Add a story for multi objective study.
// export const MultiObjective: Story = {
// parameters: {
// ...
// },
// }
Loading

0 comments on commit 610b453

Please sign in to comment.