| | import { getColor, COLORS } from "./colors.mjs" |
| | import Plotly from "plotly.js-basic-dist-min" |
| | import _ from "lodash" |
| |
|
/** Base path under which each plot's `index.json` and metric files are fetched. */
const DATA_FOLDER = "assets/data/plots";

/** Trace attributes merged into every trace when the plot type is not "bar". */
const LINE_SETTINGS = {
  width: 2.5,
  type: "scatter",
  mode: "lines",
};

/** Trace attributes merged into every trace when the plot type is "bar". */
const BAR_SETTINGS = {
  width: 0.5,
  type: "bar",
  opacity: 0.9,
  marker: { line: { width: 1.0 } },
};
| |
|
/** Maps each id in `ids` to its position in the list (0-based). */
const rankByPosition = (ids) =>
  Object.fromEntries(ids.map((id, position) => [id, position]));

/**
 * Sort rank of each metric id in the metric dropdown (lower ranks come first).
 * The two id groups are ranked independently, each starting at 0.
 */
const METRIC_ID_TO_PRIORITY = {
  ...rankByPosition([
    "agg_score",
    "hellaswag/acc_norm",
    "arc/acc_norm",
    "mmlu/acc_norm",
    "openbookqa/acc_norm",
    "commonsense_qa/acc_norm",
    "piqa/acc_norm",
    "siqa/acc_norm",
    "winogrande/acc_norm",
  ]),
  ...rankByPosition([
    "lines_ended_with_punct",
    "lines_chars",
    "short_lines",
  ]),
};
| |
|
/**
 * Human-readable label for each metric id.
 * Only ids present here are offered in the metric dropdown, and the label is
 * also used as the y-axis title of the plot.
 */
const TASK_ID_TO_NAME = {
  "agg_score": "Aggregate Score",
  "commonsense_qa/acc_norm": "Commonsense QA",
  "hellaswag/acc_norm": "HellaSwag",
  "openbookqa/acc_norm": "OpenBook QA",
  "piqa/acc_norm": "PIQA",
  "siqa/acc_norm": "Social IQA",
  "winogrande/acc_norm": "WinoGrande",
  "arc/acc_norm": "ARC",
  "mmlu/acc_norm": "MMLU",

  "lines_ended_with_punct": "Lines Ended With Punctuation",
  "lines_chars": "Lines Chars",
  "short_lines": "Short Lines",
};
| |
|
/**
 * Display name for each dataset id.
 * Used as the legend name of a trace when the trace data carries no explicit
 * label; ids missing from this table are shown verbatim.
 */
const DATASET_ID_TO_NAME = {
  "pii_removed": "Fineweb",
  "allenai_c4_en": "C4",
  "tiiuae_falcon-refinedweb_data": "RefinedWeb",
  "red-pajama-v2_jsonl-deduplicated-extract": "RedPajamaV2",
  "dolma-sample": "Dolma1.6",
  "dedup_minhash_independent_output": "Independent Dedup MinHash",
  // Each key appears exactly once: a repeated key in an object literal is
  // silently overwritten by the later entry, which hides copy/paste mistakes.
  "dedup_minhash_CC-MAIN-2013-48_output": "Full MinHash CC-MAIN-2013-48",
  "dedup_minhash_independent_output_CC-MAIN-2013-48": "Independent MinHash CC-MAIN-2013-48",
  "ind_minhash-CC-MAIN-2019-18": "Independent MinHash CC-MAIN-2019-18",
  "wet-extraction-2019-18": "WET Extraction 2019-18",
};
| |
|
/**
 * Per-plot settings used when a plot's `index.json` does not override them.
 * - `slider`: bounds and initial value of the "Rolling window" range input.
 * - `defaultMetric`: metric id selected when the plot is first rendered.
 * - `type`: "bar" selects bar traces; any other value selects line traces.
 */
const DEFAULT_SETTINGS = {
  slider: { max: 30, min: 0, default: 0 },
  defaultMetric: "agg_score",
  type: "line",
};
| |
|
/** Font stack shared by every text element of the default layout. */
const LAYOUT_FONT_FAMILY = "apple-system, Arial, sans-serif";

/** Builds a fresh font descriptor of the given size in the shared font family. */
const layoutFont = (size) => ({ size, family: LAYOUT_FONT_FAMILY });

/** Builds an axis title block (size-15 font) with optional extra attributes. */
const axisTitle = (text, extra = {}) => ({ text, font: layoutFont(15), ...extra });

/**
 * Base Plotly layout. It is deep-merged with the per-render overrides (width,
 * y-axis title) and with the optional `layout` object of the metric file.
 */
const DEFAULT_LAYOUT = {
  font: { family: LAYOUT_FONT_FAMILY },
  title: {
    text: "Plot Title",
    font: layoutFont(19),
  },
  xaxis: {
    title: axisTitle("Training tokens (billions)"),
    tickfont: layoutFont(14),
    showgrid: false,
    mirror: true,
    ticks: "outside",
    showline: true,
  },
  yaxis: {
    title: axisTitle("Agg Score", { standoff: 10 }),
    showgrid: false,
    mirror: true,
    ticks: "outside",
    showline: true,
    tickfont: layoutFont(14),
  },
  yaxis2: {
    title: axisTitle("Words Contamination", { standoff: 10 }),
    tickfont: layoutFont(14),
    showgrid: false,
    ticks: "outside",
  },
  // Legend sits inside the plot area, anchored to the bottom-right corner.
  legend: {
    orientation: "v",
    xanchor: "right",
    yanchor: "bottom",
    x: 1,
    y: 0,
    font: layoutFont(14),
    bgcolor: "rgba(0,0,0,0)",
  },
  margin: { t: 30, b: 50 },
  height: 400,
};
| |
|
/**
 * Computes a padded x-range spanning every trace.
 *
 * The smallest x value is scaled by 0.95 and the largest by 1.05.
 *
 * @param {Array<{x: Array<number>}>} traces - Traces whose `x` arrays are scanned.
 * @returns {[number, number]} `[min * 0.95, max * 1.05]`. With no x values at
 *   all this is `[Infinity, -Infinity]`, the same as `Math.min()` / `Math.max()`
 *   called without arguments.
 */
const getAutoRange = (traces) => {
  // Flatten once and fold pairwise. Spreading the values into Math.min/Math.max
  // passes one argument per value and throws a RangeError on large inputs.
  const xValues = traces.flatMap((trace) => trace.x);
  const minX = xValues.reduce((lowest, value) => Math.min(lowest, value), Infinity);
  const maxX = xValues.reduce((highest, value) => Math.max(highest, value), -Infinity);
  return [minX * 0.95, maxX * 1.05];
};
| |
|
/**
 * Returns the color for a trace, keeping it stable across re-renders.
 *
 * A trace that already has a color in `colorsMapping` keeps it. Otherwise the
 * palette is scanned upwards from `index` for a color that no other trace is
 * using, and the choice is recorded in `colorsMapping`.
 *
 * @param {string} traceName - Key identifying the trace.
 * @param {Map<string, *>} colorsMapping - Trace name -> assigned color; mutated.
 * @param {number} index - Palette index to start searching from.
 * @returns {*} The color assigned to the trace (whatever `getColor` returns).
 */
const getColorForTrace = (traceName, colorsMapping, index) => {
  const reusedColor = colorsMapping.get(traceName);
  if (reusedColor) {
    return reusedColor;
  }

  // The map is keyed by trace name, so "is this color taken?" must be answered
  // from its values; `colorsMapping.has(color)` would test the keys instead.
  const usedColors = new Set(colorsMapping.values());

  let candidateIndex = index;
  let color = getColor(candidateIndex);
  // Stop at the end of the palette: past that point the last candidate is
  // accepted even if it is already in use.
  while (usedColors.has(color) && candidateIndex < COLORS.length) {
    candidateIndex += 1;
    color = getColor(candidateIndex);
  }
  colorsMapping.set(traceName, color);
  return color;
};
| |
|
| |
|
/**
 * Builds the DOM scaffolding of one plot inside `plotElement`: the figure that
 * hosts the graph, a controls container with an optional metric dropdown and
 * an optional rolling-window slider, and an optional caption.
 *
 * @param {HTMLElement} plotElement - Container the elements are appended to.
 * @param {Object} indexMapping - Metric id -> entry from the plot's index file;
 *   only ids that are also in `TASK_ID_TO_NAME` are offered in the dropdown.
 * @param {Object} settings - Plot settings (`slider`, `defaultMetric`, `caption`).
 * @returns {{dropdown: (HTMLSelectElement|undefined),
 *            slider: (HTMLInputElement|undefined),
 *            plot: HTMLElement,
 *            caption: (HTMLElement|undefined)}}
 *   `dropdown` is undefined when fewer than two metrics are available, `slider`
 *   when `settings.slider` is null/undefined, `caption` when no caption is set.
 */
const createAblationPlottingElements = (
  plotElement,
  indexMapping,
  settings
) => {
  const plot = document.createElement("figure");
  const controls = document.createElement("div");
  plot.classList.add("plotly");
  controls.classList.add("plotly_controls");
  plotElement.appendChild(plot);
  plotElement.appendChild(controls);

  const metricOptions = Object.keys(indexMapping)
    .filter((metric) => metric in TASK_ID_TO_NAME)
    .sort(
      (a, b) =>
        (METRIC_ID_TO_PRIORITY[a] ?? 0) - (METRIC_ID_TO_PRIORITY[b] ?? 0)
    );

  let dropdown;
  if (metricOptions.length > 1) {
    const dropdownLabel = document.createElement("label");
    dropdownLabel.textContent = "Metric:";
    dropdown = document.createElement("select");
    // Build the options as DOM nodes instead of an HTML string so that metric
    // ids coming from the index file can never be interpreted as markup.
    for (const metric of metricOptions) {
      const option = document.createElement("option");
      option.value = metric;
      option.textContent = TASK_ID_TO_NAME[metric];
      dropdown.appendChild(option);
    }
    // Assigning a value that matches no option leaves a <select> empty, so
    // fall back to the first option when the default metric is not offered.
    dropdown.value = metricOptions.includes(settings.defaultMetric)
      ? settings.defaultMetric
      : metricOptions[0];

    const dropdownContainer = document.createElement("div");
    dropdownContainer.classList.add("plotly_input_container");
    dropdownContainer.appendChild(dropdownLabel);
    dropdownContainer.appendChild(dropdown);
    controls.appendChild(dropdownContainer);
  }

  let slider;
  // `!= null` treats a missing slider setting like an explicit `null`.
  if (settings.slider != null) {
    const sliderLabel = document.createElement("label");
    sliderLabel.textContent = "Rolling window:";
    slider = document.createElement("input");
    slider.type = "range";
    slider.min = settings.slider.min;
    slider.max = settings.slider.max;
    slider.value = settings.slider.default;

    // Live read-out of the current slider position.
    const sliderValue = document.createElement("span");
    sliderValue.textContent = slider.value;
    slider.addEventListener("input", () => {
      sliderValue.textContent = slider.value;
    });

    const sliderInputContainer = document.createElement("div");
    sliderInputContainer.classList.add("plotly_slider");
    sliderInputContainer.appendChild(slider);
    sliderInputContainer.appendChild(sliderValue);

    const sliderContainer = document.createElement("div");
    sliderContainer.classList.add("plotly_input_container");
    sliderContainer.appendChild(sliderLabel);
    sliderContainer.appendChild(sliderInputContainer);
    controls.appendChild(sliderContainer);
  }

  let caption;
  if (settings.caption) {
    caption = document.createElement("figcaption");
    caption.classList.add("plotly_caption");
    caption.textContent = settings.caption;
    plotElement.appendChild(caption);
  }

  return { dropdown, slider, plot, caption };
};
| |
|
/**
 * Smooths a numeric series with a moving average.
 *
 * Output element `k` is the mean of `data[k] .. data[k + windowSize - 1]`, so
 * the result has `data.length - windowSize + 1` elements (none when the window
 * is longer than the series).
 *
 * @param {Array<number>} data - Series to smooth.
 * @param {number} windowSize - Samples per window. Zero, negative or NaN
 *   disables smoothing and returns `data` itself.
 * @returns {Array<number>} The smoothed series, or `data` when smoothing is off.
 */
const rollingWindow = function (data, windowSize) {
  // `!(x > 0)` is also true for NaN, which a plain `<= 0` test would miss.
  if (!(windowSize > 0)) {
    return data;
  }
  const rollingData = [];

  // `windowEnd` is exclusive and must reach `data.length`; otherwise the last
  // sample is never averaged and a window of 1 would drop the final point.
  for (let windowEnd = windowSize; windowEnd <= data.length; windowEnd++) {
    const windowData = data.slice(windowEnd - windowSize, windowEnd);
    const windowSum = windowData.reduce((acc, value) => acc + value, 0);
    rollingData.push(windowSum / windowData.length);
  }

  return rollingData;
};
| |
|
/**
 * Converts the `data` section of a metric file into Plotly trace objects.
 *
 * @param {Object|undefined} data - Trace key -> `{x, y, label?, color?, yaxis?}`.
 * @param {Object} settings - Plot settings; `type === "bar"` selects bar traces.
 * @param {Map} colorsMapping - Trace key -> color, shared across renders.
 * @param {number} sliderValue - Rolling-window size applied to each `y` series.
 * @returns {Array<Object>} One trace per entry of `data`; empty when `data` is falsy.
 */
const createTraces = (data, settings, colorsMapping, sliderValue) => {
  if (!data) {
    return [];
  }

  const baseSettings = settings?.type === "bar" ? BAR_SETTINGS : LINE_SETTINGS;

  return Object.entries(data).map(([traceKey, traceData], position) => {
    const smoothedY = rollingWindow(traceData.y, sliderValue);
    // Smoothing can shorten the series; keep as many x values as y values.
    const matchingX = traceData.x.slice(0, smoothedY.length);
    const color =
      traceData.color ?? getColorForTrace(traceKey, colorsMapping, position);

    const traceFields = {
      x: matchingX,
      y: smoothedY,
      name: traceData.label ?? DATASET_ID_TO_NAME[traceKey] ?? traceKey,
      marker: { color },
      line: { color },
      yaxis: traceData.yaxis ?? "y1",
    };
    // The base settings are merged last, so they win on overlapping keys.
    return _.merge({}, traceFields, baseSettings);
  });
};
| |
|
/**
 * Fetches `url` and parses the response body as JSON.
 *
 * @param {string} url - Resource to fetch.
 * @returns {Promise<*>} The parsed JSON payload.
 * @throws {Error} When the server answers with a non-2xx status.
 */
const fetchJson = async (url) => {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(
      `Failed to fetch ${url}: ${response.status} ${response.statusText}`
    );
  }
  return response.json();
};

/**
 * Initialises every plot container on the page.
 *
 * Each element whose id starts with "plot-" gets its index file loaded from
 * `${DATA_FOLDER}/<name>/index.json`, its controls created, and its graph
 * rendered. The graph is re-rendered when the metric dropdown changes or,
 * debounced by 500 ms, when the rolling-window slider moves.
 *
 * Failures are logged to the console per plot, so one broken plot does not
 * prevent the others from rendering.
 */
export const init_ablation_plot = function () {
  const plotElements = document.querySelectorAll('[id^="plot-"]');
  // Plots are deliberately initialised concurrently and not awaited; every
  // callback handles its own errors, so no rejection is left unhandled.
  plotElements.forEach(async (plotElement) => {
    const plotName = plotElement.id.replace("plot-", "");
    try {
      const indexData = await fetchJson(
        `${DATA_FOLDER}/${plotName}/index.json`
      );
      const settings = _.merge({}, DEFAULT_SETTINGS, indexData.settings);
      const indexMapping = indexData.files;
      const { dropdown, slider, plot } = createAblationPlottingElements(
        plotElement,
        indexMapping,
        settings
      );
      plot.id = `graph-${plotName}`;

      // Trace colors are remembered so a dataset keeps its color across metrics.
      const colorsMapping = new Map();
      // Incremented per update so that a slow, outdated response can be ignored.
      let latestRequestId = 0;

      const updatePlot = async () => {
        latestRequestId += 1;
        const requestId = latestRequestId;
        try {
          const metricName = dropdown?.value ?? settings.defaultMetric;
          const parsedWindow = Number.parseInt(slider?.value ?? "0", 10);
          const sliderValue = Number.isNaN(parsedWindow) ? 0 : parsedWindow;

          const metricFile = indexMapping[metricName]?.file;
          if (metricFile === undefined) {
            throw new Error(`No data file listed for metric "${metricName}"`);
          }
          const metricData = await fetchJson(
            `${DATA_FOLDER}/${plotName}/${metricFile}`
          );
          if (requestId !== latestRequestId) {
            // A newer update started while this one was loading; drop it.
            return;
          }

          const traces = [
            ...(metricData.traces ?? []),
            ...createTraces(metricData.data, settings, colorsMapping, sliderValue),
          ];
          const layout = _.merge(
            {},
            DEFAULT_LAYOUT,
            {
              width: plot.parentElement.offsetWidth,
              yaxis: { title: { text: TASK_ID_TO_NAME[metricName] } },
              xaxis: {
                range: null,
              },
            },
            metricData.layout
          );
          Plotly.react(plot, traces, layout);
        } catch (error) {
          console.error(`Failed to update plot "${plotName}"`, error);
        }
      };

      if (dropdown !== undefined) {
        dropdown.addEventListener("change", () => updatePlot());
      }

      if (slider !== undefined) {
        let timeoutId;
        // Debounce: only re-render once the slider has been idle for 500 ms.
        slider.addEventListener("input", () => {
          clearTimeout(timeoutId);
          timeoutId = setTimeout(updatePlot, 500);
        });
      }

      await Plotly.newPlot(plot, []);

      // Registered once per plot. Adding this listener inside `updatePlot`
      // would stack one more listener on every re-render.
      window.addEventListener("resize", () => {
        // Narrow viewports keep the width from the last render.
        if (window.innerWidth < 768) {
          return;
        }
        Plotly.relayout(plot, {
          width: plot.parentElement.offsetWidth,
        });
      });

      await updatePlot();
    } catch (error) {
      console.error(`Failed to initialise plot "${plotName}"`, error);
    }
  });
};