import React, { useState, useRef, useEffect } from 'react';
import axios from 'axios';
import * as d3 from 'd3';
import { z } from "zod";
import { zodResponseFormat } from "openai/helpers/zod";

// NOTE(security): this key is read from a REACT_APP_* variable, which (assuming
// a Create React App style build — confirm) is inlined into the client bundle,
// so anyone loading the page can read it. Route the OpenAI call through a
// backend proxy before deploying beyond a local demo.
const API_KEY = process.env.REACT_APP_OPENAI_API_KEY;
const API_URL = 'https://api.openai.com/v1/chat/completions';
// Overridable per deployment; falls back to the previously hard-coded organization.
const ORG_ID = process.env.REACT_APP_OPENAI_ORG_ID || 'org-t8wbLMWCxhlBRBfMgSzxgV54';


// Sub-schemas reused by several ChartSpec fields (x/y axes share one shape,
// color/size legends share another). Zod schemas are immutable, so sharing
// one instance between fields is safe.
const ChartAxisSchema = z.object({
  label: z.string(),
  attribute: z.string(),
  scale: z.string(),
  tickFormat: z.string(),
  grid: z.boolean(),
});

const ChartLegendSchema = z.object({
  title: z.string(),
  position: z.string(),
});

/**
 * Zod schema for the chart specification the model is asked to return.
 *
 * The scalar header fields (type, title, description, data, width, height)
 * are required. Every visual section (margins, axes, color, size, tooltip,
 * interactivity, animations) is optional. `.passthrough()` keeps any extra
 * top-level keys the model adds.
 */
const ChartSpec = z
  .object({
    type: z.string(),
    title: z.string(),
    description: z.string(),
    data: z.string(),
    width: z.number(),
    height: z.number(),
    margins: z
      .object({
        top: z.number(),
        right: z.number(),
        bottom: z.number(),
        left: z.number(),
      })
      .optional(),
    xAxis: ChartAxisSchema.optional(),
    yAxis: ChartAxisSchema.optional(),
    color: z
      .object({
        attribute: z.string(),
        scale: z.string(),
        scheme: z.string(),
        legend: ChartLegendSchema,
      })
      .optional(),
    size: z
      .object({
        attribute: z.string(),
        scale: z.string(),
        range: z.array(z.number()),
        legend: ChartLegendSchema,
      })
      .optional(),
    tooltip: z
      .object({
        content: z.array(
          z.object({
            label: z.string(),
            attribute: z.string(),
            format: z.string().optional(),
          }),
        ),
      })
      .optional(),
    interactivity: z
      .object({
        zoom: z.boolean(),
        pan: z.boolean(),
        highlight: z.object({
          attribute: z.string(),
          active: z.string(),
          inactive: z.string(),
        }),
      })
      .optional(),
    animations: z
      .object({
        duration: z.number(),
        ease: z.string(),
      })
      .optional(),
  })
  .passthrough();

const CodeVisualizer = ({ data, columns }) => {
  const [input, setInput] = useState('');
  const [loading, setLoading] = useState(false);
  const [error, setError] = useState(null);
  const [chartSpec, setChartSpec] = useState(null);
  const chartRef = useRef(null);




const generateVisualization = async () => {
  setLoading(true);
  setError(null);
  setChartSpec(null);

  try {
    const response = await axios.post(
      API_URL,
      {
        model: "gpt-4o-2024-08-06",
        messages: [
          { role: "system", content: `You are an expert D3.js developer. Create beautiful, interactive, and informative visualizations using D3.js. The available data columns are: ${columns.join(', ')}. The data is an array of objects, each representing a member with these properties. Respond with a JSON object describing the chart specification that strictly adheres to the following schema:

          {
            type: string,
            title: string,
            description: string,
            data: string,
            width: number,
            height: number,
            margins: { top: number, right: number, bottom: number, left: number },
            xAxis: { label: string, attribute: string, scale: string, tickFormat: string, grid: boolean },
            yAxis: { label: string, attribute: string, scale: string, tickFormat: string, grid: boolean },
            color: { 
              attribute: string, 
              scale: string, 
              scheme: string, 
              legend: { title: string, position: string }
            },
            size: { 
              attribute: string, 
              scale: string, 
              range: [number, number], 
              legend: { title: string, position: string }
            },
            tooltip: { 
              content: [{ label: string, attribute: string, format: string (optional) }]
            },
            interactivity: { 
              zoom: boolean, 
              pan: boolean, 
              highlight: { attribute: string, active: string, inactive: string }
            },
            animations: { duration: number, ease: string }
          }

          Make the visualization extremely beautiful, informative, and interactive. Be creative and expressive with the design while adhering to the schema.` },
          { role: "user", content: `Create a D3.js visualization based on this request: ${input}. Use the 'data' parameter which contains the member data.` }
        ],
        response_format: { 
          type: "json_object" 
        },
        temperature: 0.7,
      },
      {
        headers: {
          'Authorization': `Bearer ${API_KEY}`,
          'Content-Type': 'application/json',
          'OpenAI-Organization': ORG_ID
        }
      }
    );

    console.log("Full API response:", JSON.stringify(response.data, null, 2));

    const content = response.data.choices[0].message.content;
    const parsedContent = JSON.parse(content);
    console.log("Parsed content:", parsedContent);

    const result = ChartSpec.safeParse(parsedContent);
    if (result.success) {
      setChartSpec(result.data);
    } else {
      console.error("Invalid chart spec:", result.error.errors);
      setError("Generated chart specification is invalid. Please try again.");
    }
  } catch (err) {
    if (err.response && err.response.data && err.response.data.error) {
      setError(`API Error: ${err.response.data.error.message}`);
    } else {
      setError(`Failed to generate visualization: ${err.message}`);
    }
    console.error('Full error:', err);
  } finally {
    setLoading(false);
  }
};



const createVisualization = (container, data, d3, spec) => {
  d3.select(container).selectAll("*").remove();

  const { width, height, margins } = spec;
  const innerWidth = width - margins.left - margins.right;
  const innerHeight = height - margins.top - margins.bottom;

  const svg = d3.select(container)
    .append("svg")
    .attr("width", width)
    .attr("height", height)
    .append("g")
    .attr("transform", `translate(${margins.left},${margins.top})`);

  const y = createScale(spec.yAxis.scale, [innerHeight, 0]);
  const color = createColorScale(spec.color.scale, spec.color.scheme);
 const size = createScale(spec.size.scale, [1, 5]);
    const x = spec.xAxis.scale === 'band' 
  ? d3.scaleBand().range([0, innerWidth]).padding(0.1)
  : createScale(spec.xAxis.scale, [0, innerWidth]);

  // Set domains for all scales
  x.domain(d3.extent(data, d => d[spec.xAxis.attribute]));
  y.domain([0, d3.max(data, d => d[spec.yAxis.attribute])]);
  color.domain(Array.from(new Set(data.map(d => d[spec.color.attribute]))));
  size.domain(d3.extent(data, d => d[spec.size.attribute]));

  const xAxis = d3.axisBottom(x).tickFormat(d3.format(spec.xAxis.tickFormat));
  const yAxis = d3.axisLeft(y).tickFormat(d3.format(spec.yAxis.tickFormat));

  svg.append("g")
    .attr("class", "x-axis")
    .attr("transform", `translate(0,${innerHeight})`)
    .call(xAxis);

  svg.append("g")
    .attr("class", "y-axis")
    .call(yAxis);

  addAxisLabel(svg, spec.xAxis.label, innerWidth / 2, innerHeight + margins.bottom - 10, "x-axis-label");
  addAxisLabel(svg, spec.yAxis.label, -innerHeight / 2, -margins.left + 20, "y-axis-label", -90);
  addChartTitle(svg, spec.title, innerWidth / 2, -margins.top / 2);

  const tooltip = createTooltip(container);

  createGenericChart(svg, data, x, y, color, size, spec, tooltip);

  if (spec.color.legend.title) {
    addColorLegend(svg, color, spec.color.legend, innerWidth);
  }
  if (spec.size.legend.title) {
    addSizeLegend(svg, size, spec.size.legend, innerWidth, color.domain().length * 20 + 40);
  }
};




const createScatterplot = (svg, data, x, y, color, size, spec, tooltip) => {
  svg.selectAll(".point")
    .data(data)
    .enter()
    .append("circle")
    .attr("class", "point")
    .attr("cx", d => x(+d[spec.xAxis.attribute]))
    .attr("cy", d => y(+d[spec.yAxis.attribute]))
    .attr("r", d => size(+d[spec.size.attribute]))
    .style("fill", d => color(d[spec.color.attribute]))
    .style("opacity", 0.7)
    .on("mouseover", (event, d) => showTooltip(event, d, tooltip, spec))
    .on("mouseout", () => hideTooltip(tooltip));
};





// Helper functions

const createScale = (scaleType, range) => {
  switch (scaleType.toLowerCase()) {
    case "linear": return d3.scaleLinear().range(range);
    case "log": return d3.scaleLog().range(range);
    case "sqrt": return d3.scaleSqrt().range(range);
    case "time": return d3.scaleTime().range(range);
    case "ordinal": return d3.scaleOrdinal().range(range);
    case "band": return d3.scaleBand().range(range).padding(0.1);
    default: return d3.scaleLinear().range(range);
  }
};

const createColorScale = (scaleType, scheme) => {
  switch (scaleType.toLowerCase()) {
    case "ordinal":
      return d3.scaleOrdinal(d3[scheme] || d3.schemeSet3);
    case "sequential":
      return d3.scaleSequential(d3[scheme] || d3.interpolateRainbow); 
    case "diverging":
      return d3.scaleDiverging(d3[scheme] || d3.interpolateRdYlBu);
    default:
      return d3.scaleOrdinal(d3.schemeSet3);
  }
};


const addAxisLabel = (svg, text, x, y, className, rotate = 0) => {
  svg.append("text")
    .attr("class", className)
    .attr("x", x)
    .attr("y", y)
    .attr("transform", `rotate(${rotate})`)
    .style("text-anchor", "middle")
    .text(text);
};

const addChartTitle = (svg, title, x, y) => {
  svg.append("text")
    .attr("class", "chart-title")
    .attr("x", x)
    .attr("y", y)
    .style("text-anchor", "middle")
    .style("font-size", "16px")
    .style("font-weight", "bold")
    .text(title);
};

const createTooltip = (container) => {
  return d3.select(container)
    .append("div")
    .attr("class", "tooltip")
    .style("opacity", 0)
    .style("position", "absolute")
    .style("background-color", "white")
    .style("border", "1px solid #ddd")
    .style("border-radius", "8px")
    .style("padding", "10px")
    .style("pointer-events", "none");
};



const createBarChart = (svg, data, x, y, color, spec, tooltip) => {
  // Group and sum the data if necessary
  const aggregatedData = d3.rollup(
    data,
    v => d3.sum(v, d => +d[spec.yAxis.attribute]),
    d => d[spec.xAxis.attribute]
  );

  const chartData = Array.from(aggregatedData, ([key, value]) => ({ key, value }));

  // Update scales
  x.domain(chartData.map(d => d.key));
  y.domain([0, d3.max(chartData, d => d.value)]);

  svg.selectAll(".bar")
    .data(chartData)
    .enter()
    .append("rect")
    .attr("class", "bar")
    .attr("x", d => x(d.key))
    .attr("width", x.bandwidth())
    .attr("y", d => y(d.value))
    .attr("height", d => y(0) - y(d.value))
    .style("fill", d => color(d.key))
    .on("mouseover", (event, d) => {
      tooltip.transition()
        .duration(200)
        .style("opacity", 0.9);
      tooltip.html(`
        <strong>${spec.xAxis.label}:</strong> ${d.key}<br>
        <strong>${spec.yAxis.label}:</strong> ${d.value}
      `)
        .style("left", (event.pageX + 10) + "px")
        .style("top", (event.pageY - 10) + "px");
    })
    .on("mouseout", () => {
      tooltip.transition()
        .duration(500)
        .style("opacity", 0);
    });

  // Update axes
  svg.select(".x-axis").call(d3.axisBottom(x));
  svg.select(".y-axis").call(d3.axisLeft(y));
};


const createHistogram = (svg, data, x, y, color, spec, tooltip) => {
  const xValue = d => +d[spec.xAxis.attribute];
  const histogram = d3.histogram()
    .value(xValue)
    .domain(x.domain())
    .thresholds(x.ticks(20));

  const bins = histogram(data);

  // Update y scale
  y.domain([0, d3.max(bins, d => d.length)]);

  svg.selectAll("rect")
    .data(bins)
    .enter()
    .append("rect")
    .attr("x", d => x(d.x0) + 1)
    .attr("width", d => Math.max(0, x(d.x1) - x(d.x0) - 1))
    .attr("y", d => y(d.length))
    .attr("height", d => y(0) - y(d.length))
    .style("fill", d => color(d.x0))
    .on("mouseover", (event, d) => {
      const tooltipContent = `
        <strong>${spec.xAxis.label}:</strong> ${d3.format(spec.xAxis.tickFormat)(d.x0)} - ${d3.format(spec.xAxis.tickFormat)(d.x1)}<br>
        <strong>Count:</strong> ${d.length}
      `;
      tooltip.transition().duration(200).style("opacity", 0.9);
      tooltip.html(tooltipContent)
        .style("left", (event.pageX + 10) + "px")
        .style("top", (event.pageY - 10) + "px");
    })
    .on("mouseout", () => tooltip.transition().duration(500).style("opacity", 0));

  // Update axes
  svg.select(".x-axis").call(d3.axisBottom(x));
  svg.select(".y-axis").call(d3.axisLeft(y));
};



const createGenericChart = (svg, data, x, y, color, size, spec, tooltip) => {
  switch (spec.type.toLowerCase()) {
    case 'histogram':
      createHistogram(svg, data, x, y, color, spec, tooltip);
      break;
    case 'bar':
      createBarChart(svg, data, x, y, color, spec, tooltip);
      break;
    case 'line':
      createLineChart(svg, data, x, y, color, spec, tooltip);
      break;
    case 'scatterplot':
    default:
      svg.selectAll(".data-point")
        .data(data)
        .enter()
        .append("circle")
        .attr("class", "data-point")
        .attr("cx", d => x(d[spec.xAxis.attribute]))
        .attr("cy", d => y(d[spec.yAxis.attribute]))
        .attr("r", d => size(d[spec.size.attribute]))
        .style("fill", d => color(d[spec.color.attribute]))
        .on("mouseover", (event, d) => showTooltip(event, d, tooltip, spec))
        .on("mouseout", () => hideTooltip(tooltip));
      break;
  }
};



const createLineChart = (svg, data, x, y, color, spec, tooltip) => {
  const nestedData = Array.from(
    d3.group(data, d => d[spec.color.attribute]),
    ([key, values]) => ({ key, values })
  );

  const line = d3.line()
    .x(d => x(+d[spec.xAxis.attribute]))
    .y(d => y(+d[spec.yAxis.attribute]));

  nestedData.forEach(group => {
    svg.append("path")
      .datum(group.values)
      .attr("class", "line")
      .attr("d", line)
      .style("fill", "none")
      .style("stroke", color(group.key))
      .style("stroke-width", 2);
  });

  svg.selectAll(".point")
    .data(data)
    .enter()
    .append("circle")
    .attr("class", "point")
    .attr("cx", d => x(+d[spec.xAxis.attribute]))
    .attr("cy", d => y(+d[spec.yAxis.attribute]))
    .attr("r", 4)
    .style("fill", d => color(d[spec.color.attribute]))
    .on("mouseover", (event, d) => showTooltip(event, d, tooltip, spec))
    .on("mouseout", () => hideTooltip(tooltip));
};

const showTooltip = (event, d, tooltip, spec) => {
  let tooltipContent = spec.tooltip.content.map(item => 
    `<strong>${item.label}:</strong> ${d3.format(item.format || "")(d[item.attribute])}`
  ).join("<br>");

  tooltip.transition()
    .duration(200)
    .style("opacity", 0.9);
  tooltip.html(tooltipContent)
    .style("left", (event.pageX + 10) + "px")
    .style("top", (event.pageY - 10) + "px");
};

const showHistogramTooltip = (event, d, tooltip, spec) => {
  const tooltipContent = `
    <strong>${spec.xAxis.label}:</strong> ${d3.format(spec.xAxis.tickFormat)(d.x0)} - ${d3.format(spec.xAxis.tickFormat)(d.x1)}<br>
    <strong>Count:</strong> ${d.length}
  `;
  tooltip.transition()
    .duration(200)
    .style("opacity", 0.9);
  tooltip.html(tooltipContent)
    .style("left", (event.pageX + 10) + "px")
    .style("top", (event.pageY - 10) + "px");
};

const showBarTooltip = (event, d, tooltip, spec) => {
  const tooltipContent = `
    <strong>${spec.xAxis.label}:</strong> ${d.key}<br>
    <strong>${spec.yAxis.label}:</strong> ${d3.format(spec.yAxis.tickFormat)(d.value)}
  `;
  tooltip.transition()
    .duration(200)
    .style("opacity", 0.9);
  tooltip.html(tooltipContent)
    .style("left", (event.pageX + 10) + "px")
    .style("top", (event.pageY - 10) + "px");
};

const hideTooltip = (tooltip) => {
  tooltip.transition()
    .duration(500)
    .style("opacity", 0);
};

const addColorLegend = (svg, color, legendSpec, xPosition) => {
  const colorLegend = svg.append("g")
    .attr("class", "color-legend")
    .attr("transform", `translate(${xPosition + 10}, 0)`);

  const colorLegendItems = colorLegend.selectAll(".color-legend-item")
    .data(color.domain())
    .enter()
    .append("g")
    .attr("class", "color-legend-item")
    .attr("transform", (d, i) => `translate(0, ${i * 20})`);

  colorLegendItems.append("rect")
    .attr("width", 10)
    .attr("height", 10)
    .style("fill", color);

  colorLegendItems.append("text")
    .attr("x", 15)
    .attr("y", 9)
    .text(d => d)
    .style("font-size", "12px");

  colorLegend.append("text")
    .attr("class", "legend-title")
    .attr("x", 0)
    .attr("y", -10)
    .text(legendSpec.title)
    .style("font-weight", "bold");
};

const addSizeLegend = (svg, size, legendSpec, xPosition, yPosition) => {
  const sizeLegend = svg.append("g")
    .attr("class", "size-legend")
    .attr("transform", `translate(${xPosition + 10}, ${yPosition})`);

  const sizeScale = d3.scaleSqrt()
    .domain(size.domain())
    .range([5, 15]);

  const sizeLegendItems = sizeLegend.selectAll(".size-legend-item")
    .data(sizeScale.ticks(3))
    .enter()
    .append("g")
    .attr("class", "size-legend-item")
    .attr("transform", (d, i) => `translate(0, ${i * 25})`);

  sizeLegendItems.append("circle")
    .attr("r", sizeScale)
    .style("fill", "none")
    .style("stroke", "black");

  sizeLegendItems.append("text")
    .attr("x", 20)
    .attr("y", d => -sizeScale(d))
    .text(d3.format(".1f"))
    .style("font-size", "12px");

  sizeLegend.append("text")
    .attr("class", "legend-title")
    .attr("x", 0)
    .attr("y", -10)
    .text(legendSpec.title)
    .style("font-weight", "bold");
};


const chartTypes = {
  scatterplot: createScatterplot,
  histogram: createHistogram,
  bar: createBarChart,
  line: createLineChart,
  // Add more chart types here as needed
};



useEffect(() => {
  let cleanup;
  if (chartSpec && chartRef.current) {
    d3.select(chartRef.current).selectAll("*").remove();
    try {
      cleanup = createVisualization(chartRef.current, data, d3, chartSpec);
      chartRef.current.style.display = 'flex';
    } catch (err) {
      setError(`Failed to render visualization: ${err.message}`);
      console.error('Render error:', err);
    }
  }

  return () => {
    if (cleanup) cleanup();
  };
}, [chartSpec, data]);




  return (
    <div className="code-visualizer" style={styles.container}>
      <h2 style={styles.title}>Visualization Studio (beta version demo)</h2>
      <div style={styles.inputSection}>
        <input
          type="text"
          value={input}
          onChange={(e) => setInput(e.target.value)}
          placeholder="Describe the visualization you want (e.g., 'Scatter plot of BMI vs age')"
          aria-label="Visualization description input"
          style={styles.input}
        />
        <button 
          onClick={generateVisualization} 
          disabled={loading || !input.trim()}
          aria-label="Generate visualization"
          style={styles.button}
        >
          {loading ? 'Generating...' : 'Generate'}
        </button>
      </div>
      {error && <p style={styles.error} role="alert">{error}</p>}
      {chartSpec && (
        <div style={{width: '100%', height: 'calc(100vh - 250px)', minHeight: '600px', overflow: 'auto'}}>
          <div ref={chartRef} style={{...styles.chartContainer, width: chartSpec.width, height: chartSpec.height}}></div>
        </div>
      )}
    </div>
  );
};


export default CodeVisualizer;


/**
 * Inline style objects for CodeVisualizer, one entry per styled element.
 * Values shared by several entries (corner radius, body font size, transition)
 * are named once inside the IIFE so that the entries stay in sync.
 */
const styles = (() => {
  const cornerRadius = '8px';
  const bodyFontSize = '16px';
  const smoothTransition = 'all 0.3s ease';

  return {
    container: {
      fontFamily: '"Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif',
      maxWidth: '1200px',
      margin: '40px auto',
      padding: '30px',
      backgroundColor: '#ffffff',
      borderRadius: '12px',
      boxShadow: '0 6px 12px rgba(0, 0, 0, 0.1)',
    },
    title: {
      color: '#2c3e50',
      textAlign: 'center',
      marginBottom: '30px',
      fontSize: '32px',
      fontWeight: '700',
    },
    inputSection: {
      display: 'flex',
      marginBottom: '30px',
      gap: '10px',
    },
    input: {
      flex: '1',
      padding: '15px 20px',
      fontSize: bodyFontSize,
      border: '2px solid #dce4ec',
      borderRadius: cornerRadius,
      transition: smoothTransition,
    },
    button: {
      padding: '15px 30px',
      fontSize: bodyFontSize,
      fontWeight: '600',
      backgroundColor: '#3498db',
      color: 'white',
      border: 'none',
      borderRadius: cornerRadius,
      cursor: 'pointer',
      transition: smoothTransition,
    },
    error: {
      color: '#e74c3c',
      textAlign: 'center',
      marginBottom: '20px',
      padding: '15px',
      backgroundColor: '#fadbd8',
      borderRadius: cornerRadius,
      fontSize: bodyFontSize,
      fontWeight: '500',
    },
    chartContainer: {
      backgroundColor: '#f8f9fa',
      borderRadius: cornerRadius,
      padding: '30px',
      boxShadow: 'inset 0 2px 4px rgba(0, 0, 0, 0.1)',
      minHeight: '500px',
      width: '100%',
      height: '600px',
    },
  };
})();

