import { events } from '../events'
import { useEvent } from '../framework'
import { isMobile } from '../utils'

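/**
 * progress: computed per observed element in setParallax; goes from 1 to -1
 * while the element crosses the viewport. The value is not clamped, so it
 * falls outside that range while the element is off-screen.
 * 1 (or more): element is below or has just entered viewport from below
 * 0: middle of element is in the middle of viewport
 * -1 (or less): element was just scrolled by, or is above viewport
 *
 * Elements whose top starts within the first screen height use a different
 * baseline: their progress is 0 at scroll position 0 and decreases as the
 * page scrolls down.
 */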

/**
 * Linearly interpolates between two numbers.
 *
 * @param {number} f0 - value returned when `t` is 0
 * @param {number} f1 - value returned when `t` is 1
 * @param {number} t - interpolation factor
 * @returns {number} the blend of `f0` and `f1` weighted by `t`
 */
const lerp = (f0, f1, t) => {
  const startWeight = 1 - t
  return startWeight * f0 + t * f1
}

/**
 * Clears any inline transform on the element by setting it to `none`.
 *
 * @param {HTMLElement} element - element whose inline transform is reset
 */
function resetPosition(element) {
  const { style } = element
  style.transform = 'none'
}

/**
 * Returns the reference position used for an element's parallax progress.
 *
 * @param {Element} element - element to measure
 * @returns {number} 0 when the element's top edge (in document coordinates)
 *   lies within the first `window.innerHeight` pixels of the page; otherwise
 *   the rounded document Y coordinate of the element's vertical centre.
 */
function getPosition(element) {
  const rect = element.getBoundingClientRect()

  // Convert the viewport-relative top into a document-relative offset.
  const documentTop = rect.top + window.scrollY

  if (documentTop <= window.innerHeight) {
    return 0
  }

  return Math.round(documentTop + rect.height / 2)
}

/**
 * Returns the element's rendered height, rounded to the nearest integer.
 *
 * @param {Element} element - element to measure
 * @returns {number} rounded `height` from `getBoundingClientRect()`
 */
function getHeight(element) {
  return Math.round(element.getBoundingClientRect().height)
}

/**
 * Sets up scroll-driven parallax for the elements in `ref.parallax`.
 *
 * Each element is measured once (and again on resize); the measurements are
 * cached in `data-position` / `data-height`. An IntersectionObserver tracks
 * which elements are near the viewport, and on scroll a requestAnimationFrame
 * loop eases an internal scroll position towards `window.scrollY`, translating
 * every `[data-parallax-target]` descendant of the tracked elements. The loop
 * pauses once the eased position has caught up with the real scroll position.
 *
 * Per-target options read from the dataset:
 * - `data-parallax-amount`: translation strength in vmin (default 5)
 * - `data-parallax-desktop`: when present, the target is not moved if
 *   `isMobile()` returns true
 *
 * @param {Object} ref - object whose `parallax` property is iterated with
 *   `forEach` (presumably an array of elements; confirm against the caller)
 * @returns {Function|undefined} cleanup function that stops the animation and
 *   disconnects the observer, or `undefined` when `ref.parallax` is missing
 */
export default ref => {
  if (!ref.parallax) return

  // Fraction of the remaining scroll distance covered on each frame.
  const LERP_FACTOR = 0.5
  // Distance in px below which the eased position counts as having arrived.
  const SETTLE_THRESHOLD = 0.5

  const state = {
    elementsInView: new Set(),
    running: false,
    destroyed: false,
    frameId: null,
    lerpPosition: window.scrollY,
  }

  const callback = entries => {
    entries.forEach(entry => {
      if (entry.isIntersecting) {
        state.elementsInView.add(entry.target)
      } else {
        state.elementsInView.delete(entry.target)
      }
    })
  }

  // The vertical margin starts animating elements shortly before they become
  // visible and keeps them animated shortly after they leave.
  const observer = new IntersectionObserver(callback, {
    rootMargin: '200px 0px 200px 0px',
  })

  // Caches the element's layout measurements on its dataset.
  const measure = element => {
    element.dataset.position = getPosition(element)
    element.dataset.height = getHeight(element)
  }

  ref.parallax.forEach(element => {
    measure(element)
    observer.observe(element)
  })

  function setParallax(element) {
    const position = parseInt(element.dataset.position, 10)
    const height = parseInt(element.dataset.height, 10)

    // A position of 0 marks an element that starts within the first screen
    // (see getPosition); its progress is measured from the top of the page
    // rather than from the middle of the viewport.
    const screenMid = position === 0 ? 0 : window.innerHeight / 2

    // Progress depends only on the container, so compute it once per element
    // instead of once per target.
    const progress =
      (position - state.lerpPosition - screenMid) / (screenMid + height / 2)

    for (const target of element.querySelectorAll('[data-parallax-target]')) {
      const amount =
        target.dataset.parallaxDesktop !== undefined && isMobile()
          ? 0
          : parseInt(target.dataset.parallaxAmount || 5, 10)

      target.style.transform = `translateY(${progress * amount}vmin)`
    }
  }

  function raf() {
    state.frameId = null

    const targetPosition = window.scrollY
    const nextPosition = lerp(state.lerpPosition, targetPosition, LERP_FACTOR)

    // window.scrollY may be fractional, so compare by distance rather than
    // testing a rounded value for equality (which might never match and would
    // keep the loop running forever).
    const settled = Math.abs(targetPosition - nextPosition) < SETTLE_THRESHOLD

    // Snap onto the target when settled so the resting transform is exact.
    state.lerpPosition = settled ? targetPosition : nextPosition

    state.elementsInView.forEach(setParallax)

    // pause when target is reached
    if (settled) {
      state.running = false
      return
    }

    // Schedule the next frame only while still running. `running` is therefore
    // false only when no frame is pending, so a scroll event can never start a
    // second, concurrent animation loop.
    state.frameId = requestAnimationFrame(raf)
  }

  // set initial position
  ref.parallax.forEach(setParallax)

  // NOTE(review): how listeners registered through useEvent.listen get removed
  // is not visible from this file, so both handlers become no-ops after
  // cleanup instead of being unregistered here. Confirm against the framework.
  useEvent.listen(events.window.scroll, () => {
    if (state.destroyed || state.running) return

    state.running = true
    raf()
  })
  useEvent.listen(events.window.resize, () => {
    if (state.destroyed) return

    ref.parallax.forEach(element => {
      // Drop any inline transform on the element before re-measuring it.
      resetPosition(element)
      measure(element)
    })

    ref.parallax.forEach(setParallax)
  })

  return () => {
    state.destroyed = true
    state.running = false

    if (state.frameId !== null) {
      cancelAnimationFrame(state.frameId)
      state.frameId = null
    }

    state.elementsInView.clear()
    observer.disconnect()
  }
}
