import * as d3 from 'd3';

/**
 * Render a cohort-retention heatmap as an SVG grid inside a horizontally
 * scrollable <div>.
 *
 * Rows are cohorts (ascending `cohortDate`, top to bottom); columns are period
 * numbers (ascending, left to right). Each cell is filled on a YlGnBu scale by
 * `percentage` and labelled as a whole percentage. The left margin shows the
 * row label from `headerRef` and the cohort's period-0 `value`.
 *
 * @param {Array<{cohortDate: *, periodNumber: number, value: number, percentage: number}>} data
 *   One record per (cohort, period). `percentage` is formatted with d3's '%'
 *   type, so it is presumably a fraction (0.42 -> "42%"). Records with a
 *   negative `periodNumber` are not drawn.
 * @param {Object} [options]
 * @param {number} [options.gridWidth] - Ignored: the width is derived from the
 *   number of periods, `cellWidth` and the margins. Kept for API compatibility.
 * @param {number} [options.gridHeight] - Ignored: the height is derived from the
 *   number of cohorts, `cellHeight` and the margins. Kept for API compatibility.
 * @param {Object} [options.headerRef={}] - Maps a cohortDate to its row label.
 * @param {string} [options.periodPrefix='M'] - Prefix for column headers ("M0", "M1", ...).
 * @param {string} [options.retType='logo'] - 'logo' titles the start column
 *   "Start Logos"; any other value titles it "Start ARR".
 * @param {number} [options.leftMargin=130] - Horizontal space reserved for row labels, in px.
 * @param {number} [options.maxCohorts=26] - Maximum number of cohorts to show;
 *   the ones with the greatest `cohortDate` values are kept.
 * @param {number} [options.cellWidth=45] - Width of one period column, in px.
 * @param {number} [options.cellHeight=25] - Height of one cohort row, in px.
 * @returns {HTMLDivElement} The container element, with `value` initialised to null.
 */
export const CohortGridSVG = (data, {
  gridWidth = 1100, // ignored; see JSDoc
  gridHeight = 800, // ignored; see JSDoc
  headerRef = {},
  periodPrefix = 'M',
  retType = 'logo',
  leftMargin = 130,
  maxCohorts = 26,
  cellWidth = 45,
  cellHeight = 25,
} = {}) => {
  // Same semantics as the global isNaN (coerces first), but explicit about it.
  const hasPercentage = d => !Number.isNaN(Number(d.percentage));
  const formatPercent = d3.format('.0%');
  const formatCount = d3.format(',.0d');

  // Keep the `maxCohorts` latest cohorts: sort descending, take the head.
  const allCohorts = Array.from(new d3.InternSet(data.map(d => d.cohortDate))).sort(d3.descending);
  const visibleCohorts = new d3.InternSet(allCohorts.slice(0, Math.max(0, maxCohorts)));
  const visibleData = data.filter(d => visibleCohorts.has(d.cohortDate));

  // Period-0 value of each cohort, shown in the left-hand "start" column.
  const startValues = new d3.InternMap(
    visibleData
      .filter(d => d.periodNumber === 0)
      .map(({cohortDate, value}) => [cohortDate, value])
  );
  const retentionCohorts = visibleData.filter(d => d.periodNumber >= 0);
  const cohortDates = Array.from(new d3.InternSet(retentionCohorts.map(d => d.cohortDate))).sort(d3.ascending);
  const periodNumbers = Array.from(new Set(retentionCohorts.map(d => d.periodNumber))).sort(d3.ascending);

  const margin = {top: 30, right: 10, bottom: 0, left: leftMargin};
  // Margins are added on top of the cell area so that every column/row gets a
  // full cellWidth/cellHeight and the scale ranges can never invert when there
  // are only a few periods or cohorts.
  const width = margin.left + periodNumbers.length * cellWidth + margin.right;
  const height = margin.top + cohortDates.length * cellHeight + margin.bottom;

  const x = d3.scaleBand()
    .domain(periodNumbers)
    .rangeRound([margin.left, width - margin.right]);

  const y = d3.scaleBand()
    .domain(cohortDates)
    .rangeRound([margin.top, height - margin.bottom]);

  // d3.max skips NaN/undefined; fall back to 1 when no cell has a percentage.
  const color = d3.scaleSequential(d3.interpolateYlGnBu)
    .domain([0, d3.max(retentionCohorts, d => d.percentage) ?? 1]);

  const label = d => hasPercentage(d) ? formatPercent(d.percentage) : '';

  const div = d3.create('div')
    .style('overflow-x', 'auto')
    .style('font-variant-numeric', 'tabular-nums');

  const svg = div.append('svg')
    .attr('viewBox', [0, 0, width, height])
    .attr('width', width);

  const element = div.node();
  element.value = null;

  const g = svg.append('g')
    .attr('shape-rendering', 'crispEdges')
    .style('cursor', 'default');

  // One <g> per cohort, each holding that cohort's records.
  const row = g.selectAll('.row')
    .data(d3.groups(retentionCohorts, d => d.cohortDate))
    .join('g')
    .attr('class', 'row')
    .attr('transform', ([cohortDate]) => `translate(0,${y(cohortDate)})`);

  const cell = row.selectAll('.cell')
    .data(([, records]) => records)
    .join('g')
    .attr('class', 'cell')
    .attr('transform', d => `translate(${x(d.periodNumber)},0)`);

  cell.append('rect')
    .attr('fill', d => hasPercentage(d) ? color(d.percentage) : 'white')
    .attr('width', x.bandwidth())
    .attr('height', y.bandwidth());

  // White text on dark fills (CIELAB lightness < 55), black otherwise.
  cell.append('text')
    .text(label)
    .attr('fill', d => hasPercentage(d) && d3.lab(color(d.percentage)).l < 55 ? 'white' : 'black')
    .attr('x', x.bandwidth() - 5)
    .attr('y', y.bandwidth() / 2)
    .attr('text-anchor', 'end')
    .attr('dy', '0.35em')
    .attr('font-size', '10px')
    .attr('font-family', 'var(--sans-serif)');

  const axisTop = d3.axisTop(x)
    .tickFormat(periodNumber => `${periodPrefix}${periodNumber}`);

  // Column headers: keep only the tick labels.
  svg.append('g')
    .attr('transform', `translate(0,${margin.top + 10})`)
    .call(axisTop)
    .call(axis => axis.selectAll('.domain, .tick line').remove())
    .call(axis => axis.selectAll('text').attr('font-family', 'var(--sans-serif)'));

  const rowLabel = row.append('g')
    .attr('font-size', '10px')
    .attr('font-family', 'var(--sans-serif)');

  // Cohort name, left-aligned in the margin.
  rowLabel
    .append('text')
    .text(([cohortDate]) => headerRef[cohortDate])
    .attr('x', 2)
    .attr('y', y.bandwidth() / 2)
    .attr('dy', '0.35em');

  // Starting size (period-0 value), right-aligned against the grid. Cohorts
  // without a period-0 record show nothing rather than "NaN".
  rowLabel
    .append('text')
    .text(([cohortDate]) => (startValues.has(cohortDate) ? formatCount(startValues.get(cohortDate)) : ''))
    .attr('x', margin.left - 7)
    .attr('y', y.bandwidth() / 2)
    .attr('text-anchor', 'end')
    .attr('dy', '0.35em');

  const headers = svg.append('g')
    .attr('transform', `translate(0,${margin.top})`)
    .attr('font-size', 10)
    .attr('font-family', 'var(--sans-serif)');

  // Title above the start-value column; font attributes inherit from `headers`.
  headers
    .append('text')
    .text(retType === 'logo' ? 'Start Logos' : 'Start ARR')
    .attr('x', margin.left - 7)
    .attr('y', -3)
    .attr('text-anchor', 'end')
    .attr('dy', '0.35em');

  return element;
};