import React, { useEffect, useRef } from "react";
import * as d3 from "d3";
import { styled } from '@mui/material/styles';
import Tooltip, { tooltipClasses } from '@mui/material/Tooltip';
import Typography from '@mui/material/Typography';
import IconButton from '@mui/material/IconButton';
import HelpOutlineIcon from '@mui/icons-material/HelpOutline';
import useWalks from "@component/stores/walks";
import { CHART_COLORS } from "./colors";
/**
* Styled component for custom HTML tooltip.
*/
const HtmlTooltip = styled(({ className, ...props }) => (
<Tooltip {...props} classes={{ popper: className }} />
))(({ theme }) => ({
[`& .${tooltipClasses.tooltip}`]: {
backgroundColor: '#f5f5f9',
color: 'rgba(0, 0, 0, 0.87)',
maxWidth: 220,
fontSize: theme.typography.pxToRem(12),
border: '1px solid #dadde9',
},
}));
/**
* Generates the scatterplot chart for umap projection.
* @param {Object} ref - The reference to the chart container holding the svgs.
* @param {Array} walks - The array of walks data.
* @param {Array} selectedWalks - The array of selected walks to color the selection made by the user through brushing or clicking.
* @param {Function} setSelectedWalks - The function to set the selected walks to save for other charts.
*/
const generateScatterplot = (ref, walks, selectedWalks, setSelectedWalks) => {
const primary = CHART_COLORS.primary;
const secondary = CHART_COLORS.secondary;
const margin = { top: 15, right: 60, bottom: 30, left: 40 };
const width = 600 - margin.left - margin.right;
const height = 400 - margin.top - margin.bottom;
d3.select(ref.current).selectAll("svg").remove();
const minX = d3.min(walks, d => +d['umap'][0]);
const maxX = d3.max(walks, d => +d['umap'][0]);
const minY = d3.min(walks, d => +d['umap'][1]);
const maxY = d3.max(walks, d => +d['umap'][1]);
// create scales for the x and y axes
const xScale = d3
.scaleLinear()
.domain([minX-(maxX-minX)*0.1, maxX+(maxX-minX)*0.1])
.range([0, width]);
const yScale = d3
.scaleLinear()
.domain([minY-(maxY-minY)*0.1, maxY+(maxY-minY)*0.1])
.range([height, 0]);
// create the x and y axes
const xAxis = d3.axisBottom(xScale);
const yAxis = d3.axisLeft(yScale);
// create the scatterplot
const svg = d3
.select(ref.current)
.append("svg")
.attr("width", width + margin.left + margin.right)
.attr("height", height + margin.top + margin.bottom)
const g = svg
.append("g")
.attr("transform", `translate(${margin.left},${margin.top})`);
g.append("g").attr("transform", `translate(0,${height})`).call(xAxis);
g.append("g").call(yAxis);
/**
* Click event handler for scatterplot circles.
* @param {Object} event - The click event object.
* @param {Object} d - The data associated with the clicked circle.
*/
const handleClick = (event, d) => {
const newSelectedWalks = [d.walk];
setSelectedWalks(newSelectedWalks);
svg
.selectAll('.scatterplot-circle')
.attr('fill', d => (newSelectedWalks.includes(d.walk) ? primary : secondary));
d3.select(event.currentTarget)
.attr('fill', primary)
}
const brush = d3.brush()
.extent([[0, 0], [width, height]])
.on("end", function handleBrush(event) {
if (!event.selection) {
setSelectedWalks([]);
} else {
const [[x1, y1], [x2, y2]] = event.selection;
const selected = walks.filter(d => {
const dx = xScale(d['umap'][0]), dy = yScale(d['umap'][1]);
return x1 <= dx && dx <= x2 && y1 <= dy && dy <= y2;
}).map(d => d.walk);
setSelectedWalks(selected);
}
// update the fill of circles
g.selectAll('.scatterplot-circle')
.attr('fill', d => selectedWalks.includes(d.walk) ? primary : secondary);
});
g.append("g")
.attr("class", "brush")
.call(brush);
// add circles after brush to still enable clicking event
g
.selectAll("circle")
.data(walks)
.enter()
.append("circle")
.attr("cx", (d) => xScale(+d['umap'][0]))
.attr("cy", (d) => yScale(+d['umap'][1]))
.attr("r", 4)
.attr("fill", (d) => (selectedWalks.includes(d.walk) ? primary : secondary)) // Change fill color based on condition
.on("click", handleClick)
.attr("class", "scatterplot-circle");
};
/**
* UMAP Scatterplot component.
* Displays the UMAP scatterplot chart.
* @returns {JSX.Element} UMAP scatterplot component.
*/
const Scatterplot = () => {
const chartRef = useRef();
const walks = useWalks(state => state.walks);
const selectedWalks = useWalks(state => state.selectedWalks);
const setSelectedWalks = useWalks(state => state.setSelectedWalks);
useEffect(() => {
if (walks) {
generateScatterplot(chartRef, walks, selectedWalks, setSelectedWalks);
}
}, [walks, selectedWalks]);
return (
<>
<div style={{display: 'flex', justifyContent: 'center', alignItems: 'center'}}>
<Typography variant="h6" component="div" style={{flexGrow: 1, textAlign: 'center'}}>Umap</Typography>
<HtmlTooltip
title={
<>
<Typography color="primary">UMAP Plot</Typography>
{/* your explanation goes here */}
<p>
The UMap-projection acts as a window for selecting walks in terms of similarity in their activations.
This representation of the walks is decoupled from the information that regression provides, and enables cross-examination of walks in two different representation.
</p>
<br/>
<p><b>Click:</b> to select a datapoint</p>
<p><b>Brush:</b> to select multiple datapoints</p>
</>
}
>
<IconButton>
<HelpOutlineIcon/>
</IconButton>
</HtmlTooltip>
</div>
<svg
viewBox={"0 0 " + 600 + " " + 400}
ref={chartRef}
/>
</>
);
};
export default Scatterplot;