import json
import time
import random
import re
import os
import atexit
import logging
import csv

import undetected_chromedriver as uc
from bs4 import BeautifulSoup
from selenium.webdriver.common.by import By
from selenium.webdriver.common.action_chains import ActionChains
from selenium.common.exceptions import WebDriverException

# ----------------------------
# CONFIG
# ----------------------------
OUTPUT_FILE = "WayfairVideos.json"  # default results file; OUTPUT_PATH env var overrides it in main()
PARTNER_URLS_FILE = os.path.join(os.path.dirname(__file__), "partnerUrls.csv")  # input CSV located next to this script
HEADLESS = False  # NOTE(review): currently ignored — the headless branch in get_driver() is commented out
NAV_TIMEOUT_SEC = 45  # NOTE(review): defined but not referenced anywhere in this file
MAX_RETRIES = 3  # page-load attempts per URL in capture_video_for_url()
SAVE_PDP_HTML = True  # when True, dump each product page's HTML into PDP_HTML_DIR
ROTATE_CONTEXT_EVERY = 20  # recreate the browser after this many URLs (see main())
PDP_HTML_DIR = "pdp_html/wf"  # where saved PDP HTML files go
VPN_ENABLED = os.environ.get("WG_ENABLED") == "1"  # opt-in WireGuard rotation via env var
WG_CONFIG_DIR = os.environ.get("WG_CONFIG_DIR", "/etc/wireguard")  # WireGuard config directory
WG_COOLDOWN_SEC = float(os.environ.get("WG_COOLDOWN_SEC", "2.0"))  # pause between VPN rotations
VPN_ROTATOR = None  # populated in main() when VPN support is enabled

def load_partner_urls(csv_path: str, column_name: str) -> list[tuple[str, str]]:
    """Read (SKU, URL) pairs from a partner CSV file.

    Rows whose URL cell is empty or the literal string "NULL" (any
    case) are skipped; the SKU may be blank and is kept as "".

    Args:
        csv_path: Path to a CSV with a "SKU" column.
        column_name: Header of the column holding the partner URL.

    Returns:
        List of (sku, url) tuples, both whitespace-stripped.
    """
    pairs: list[tuple[str, str]] = []
    with open(csv_path, newline="", encoding="utf-8") as handle:
        for record in csv.DictReader(handle):
            partner_url = (record.get(column_name) or "").strip()
            if not partner_url or partner_url.upper() == "NULL":
                continue
            sku_value = (record.get("SKU") or "").strip()
            pairs.append((sku_value, partner_url))
    return pairs

# URL classifiers: HLS manifests (.m3u8) vs direct media files (.mp4/.webm).
# Both allow a trailing query string and match case-insensitively.
MANIFEST_RE = re.compile(r"\.m3u8(\?|$)", re.I)
FILE_RE = re.compile(r"\.(mp4|webm)(\?|$)", re.I)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Optional WireGuard VPN helpers. When the wg_vpn module is missing (or
# fails to import for any reason) every helper falls back to None, and
# the call sites guard on that — the scraper still runs without VPN.
try:
    from wg_vpn import (
        WireGuardRotator,
        parse_document_status_from_performance_logs,
        should_rotate_on_status,
        is_dns_error_html,
    )
except Exception:
    WireGuardRotator = None
    parse_document_status_from_performance_logs = None
    should_rotate_on_status = None
    is_dns_error_html = None


def ensure_driver_ready(driver, max_attempts: int = 4):
    """Return a driver that responds to basic commands, replacing it if needed.

    Probes the session with a trivial navigation plus a CDP call; on any
    failure the broken driver is discarded and a fresh one is launched
    via get_driver(), up to ``max_attempts`` probes. The last driver
    obtained is returned even if the final probe was not run on it.
    """
    remaining = max_attempts
    while remaining > 0:
        remaining -= 1
        try:
            _ = driver.current_url  # raises if the session is dead
            driver.get("about:blank")
            driver.execute_cdp_cmd("Network.enable", {})
            time.sleep(0.5)
            return driver
        except Exception:
            # Probe failed: dispose of the old driver and start over.
            try:
                driver.quit()
            except Exception:
                pass
            time.sleep(1.5)
            driver = get_driver()
    return driver

def get_driver():
    """Launch a fresh undetected-chromedriver Chrome session.

    Performance logging is enabled so callers can read network/document
    status from the CDP "performance" log (see capture_video_for_url).

    NOTE(review): the HEADLESS config flag is ignored here — the
    headless branch below is deliberately commented out.
    """
    options = uc.ChromeOptions()
    options.add_argument("--no-sandbox")
    options.add_argument("--disable-blink-features=AutomationControlled")
    # options.add_argument("--disable-gpu")
    options.add_argument("--start-maximized")
    options.add_argument("--disable-extensions")
    options.add_argument("--disable-infobars")
    # Needed for driver.get_log("performance") to return CDP events.
    options.set_capability("goog:loggingPrefs", {"performance": "ALL"})
    # if HEADLESS:
    #     options.add_argument("--headless=new")
    driver = uc.Chrome(options=options, use_subprocess=False)
    # Give the browser a moment to finish starting up.
    time.sleep(3)
    return driver


def extract_json_ld_videos(html: str) -> list[dict]:
    """Parse JSON-LD <script> blocks and collect VideoObject metadata.

    Handles top-level dicts, top-level lists, and ``@graph`` containers.
    Per the JSON-LD spec ``@type`` may be either a string or a list of
    type names; both forms are recognized (the original code only
    matched the bare-string form and silently dropped list-typed nodes).

    Args:
        html: Full page HTML.

    Returns:
        List of dicts with keys ``url``, ``name``, ``description``,
        ``date_modified``. Nodes without a string contentUrl/url are
        skipped; malformed JSON blocks are ignored.
    """
    soup = BeautifulSoup(html, "html.parser")
    videos: list[dict] = []
    for script in soup.find_all("script", attrs={"type": "application/ld+json"}):
        raw = script.string or script.get_text(strip=True)
        if not raw:
            continue
        try:
            data = json.loads(raw)
        except Exception:
            # Malformed JSON-LD is common in the wild; skip quietly.
            continue

        # Normalize the payload into a flat list of candidate nodes.
        if isinstance(data, dict):
            graph = data.get("@graph")
            nodes = graph if isinstance(graph, list) else [data]
        elif isinstance(data, list):
            nodes = data
        else:
            nodes = []

        for node in nodes:
            if not isinstance(node, dict):
                continue
            # @type: "VideoObject" or e.g. ["VideoObject", "CreativeWork"].
            node_type = node.get("@type")
            type_names = node_type if isinstance(node_type, list) else [node_type]
            if "VideoObject" not in type_names:
                continue
            url = node.get("contentUrl") or node.get("url")
            if not isinstance(url, str):
                continue
            videos.append({
                "url": url,
                "name": node.get("name"),
                "description": node.get("description"),
                "date_modified": node.get("dateModified"),
            })
    return videos


def collect_video_metadata_from_html(html: str) -> dict:
    """Collect video source URLs (and posters) from <video> markup.

    Reads both the ``src`` attribute on the ``<video>`` tag itself and
    the ``src`` of every nested ``<source>`` child (the original only
    read nested ``<source>`` tags and missed the direct-src form).

    Args:
        html: Full page HTML.

    Returns:
        Mapping of source URL -> metadata dict with keys
        ``description`` / ``date_modified`` (always None here; filled
        in from JSON-LD elsewhere) and ``poster`` (the video's poster
        image URL, or None).
    """
    soup = BeautifulSoup(html, "html.parser")
    meta: dict = {}
    for video in soup.find_all("video"):
        poster = video.get("poster")
        poster = poster if isinstance(poster, str) else None
        # A <video> may carry its own src and/or nested <source> children.
        candidates = [video.get("src")]
        candidates.extend(source.get("src") for source in video.find_all("source"))
        for src in candidates:
            if not isinstance(src, str):
                continue
            meta[src] = {
                "description": None,
                "date_modified": None,
                "poster": poster,
            }
    return meta


def collect_media_urls_from_html(html: str) -> set[str]:
    """Scan raw HTML text for absolute media URLs.

    Extracts every http(s) URL and keeps only those that look like an
    HLS manifest (MANIFEST_RE) or a direct media file (FILE_RE).
    """
    candidates = re.findall(r"https?://[^\s'\"]+", html)
    return {
        candidate
        for candidate in candidates
        if MANIFEST_RE.search(candidate) or FILE_RE.search(candidate)
    }


def choose_best_video_url(urls: set[str]) -> dict:
    """Pick the preferred playable URL from a set of captured media URLs.

    HLS manifests (.m3u8) take priority over direct files (.mp4/.webm);
    within a category the longest URL wins (length used as a heuristic).
    Returns {"type": ..., "url": ...}, with type "none" and url None
    when nothing matches either category.
    """
    manifests = [candidate for candidate in urls if MANIFEST_RE.search(candidate)]
    if manifests:
        return {"type": "m3u8", "url": max(manifests, key=len)}

    direct_files = [candidate for candidate in urls if FILE_RE.search(candidate)]
    if direct_files:
        return {"type": "mp4/webm", "url": max(direct_files, key=len)}

    return {"type": "none", "url": None}

def handle_press_and_hold_captcha(
    driver,
    max_attempts: int = 3,
    hold_seconds: float = 4,
    wait_timeout_sec: float = 30.0,
) -> bool:
    """Detect and try to clear the "Press & Hold to confirm" captcha.

    The widget appears to be the PerimeterX-style challenge hosted in an
    element with id ``px-captcha``. The strategy is: detect the marker
    text in the page source, then repeatedly click-and-hold the host
    element until the marker disappears or time runs out.

    Returns:
        False when no captcha is present; True when a captcha was
        detected and then cleared.

    Raises:
        RuntimeError: the captcha survived all ``max_attempts`` rounds.

    NOTE(review): ``hold_seconds`` is currently unused — the hold
    duration below is hard-coded to 10 seconds; confirm before relying
    on the parameter.
    """
    def captcha_present(page_html: str) -> bool:
        # The marker text can appear HTML-escaped or literal in page_source.
        return "Press &amp; Hold to confirm" in page_html or "Press & Hold to confirm" in page_html

    def find_press_and_hold():
        # Any element whose visible text mentions "Press & Hold".
        # Only referenced by the commented-out fallback further down.
        return driver.find_elements(
            By.XPATH,
            "//*[normalize-space()='Press & Hold' or contains(normalize-space(), 'Press & Hold')]",
        )

    def find_press_and_hold_button():
        # role=button variant of the same lookup; also only used by the
        # commented-out fallback further down.
        return driver.find_elements(
            By.XPATH,
            "//*[@role='button' and contains(@aria-label, 'Press & Hold')]",
        )

    def try_click_and_hold_on_host():
        # Click-and-hold on the first visible, non-zero-sized #px-captcha
        # host. Returns True if a hold was performed on some host.
        hosts = driver.find_elements(By.CSS_SELECTOR, "#px-captcha")
        if not hosts:
            return False
        for host in hosts:
            try:
                if not host.is_displayed():
                    continue
                size = host.size or {}
                width = size.get("width", 0)
                height = size.get("height", 0)
                if width <= 0 or height <= 0:
                    continue
                # Aim at the horizontal middle, just inside the top edge.
                offset_x = max(5, int(width * 0.5))
                offset_y = 1
                try:
                    # Debugging aid: draw a short-lived red dot at the
                    # computed point. Failures here are non-fatal.
                    origin = host.location_once_scrolled_into_view or {}
                    target_x = origin.get("x", 0) + offset_x
                    target_y = origin.get("y", 0) + offset_y
                    driver.execute_script(
                        """
                        const [x, y] = arguments;
                        const dot = document.createElement('div');
                        dot.id = '__px_debug_dot__';
                        dot.style.position = 'absolute';
                        dot.style.left = x + 'px';
                        dot.style.top = y + 'px';
                        dot.style.width = '10px';
                        dot.style.height = '10px';
                        dot.style.borderRadius = '50%';
                        dot.style.background = 'red';
                        dot.style.zIndex = '2147483647';
                        document.body.appendChild(dot);
                        setTimeout(() => dot.remove(), 3000);
                        """,
                        target_x,
                        target_y,
                    )
                except WebDriverException:
                    pass
                action = ActionChains(driver)
                # action.move_to_element_with_offset(host, offset_x, offset_y).click_and_hold(host).perform()
                action.click_and_hold(host).perform()
                # Hold for a fixed 10s (hold_seconds is NOT used here).
                time.sleep(10)
                ActionChains(driver).release(host).perform()
                time.sleep(0.2)
                return True
            except WebDriverException:
                # This host failed; try the next candidate.
                continue
        return False

    for attempt in range(1, max_attempts + 1):
        try:
            html = driver.page_source
        except WebDriverException:
            time.sleep(1)
            continue

        if not captcha_present(html):
            logging.info("Captcha not detected; continuing.")
            return False

        logging.info("Captcha detected (attempt %s/%s).", attempt, max_attempts)
        # try:
        #     candidates = find_press_and_hold_button()
        #     if not candidates:
        #         candidates = find_press_and_hold()
        #     logging.info("Top-level 'Press & Hold' candidates: %s", len(candidates))
        #     for el in candidates:
        #         if el.is_displayed():
        #             logging.info("Top-level Press & Hold visible; attempting click-and-hold.")
        #             ActionChains(driver).click_and_hold(el).perform()
        #             time.sleep(4)
        #             break
        # except WebDriverException:
        #     logging.info("Top-level click-and-hold failed due to WebDriverException.")
        #     pass


        # Poll for up to wait_timeout_sec, attempting a click-and-hold
        # each cycle, until the marker text disappears from the page.
        end_time = time.time() + wait_timeout_sec
        while time.time() < end_time:
            try:
                html = driver.page_source
            except WebDriverException:
                time.sleep(1)
                continue
            if not captcha_present(html):
                logging.info("Captcha cleared.")
                return True
            time.sleep(1.5)
            if try_click_and_hold_on_host():
                logging.info("Click-and-hold attempted on #px-captcha host.")
                time.sleep(2)
        logging.info("Captcha still present after waiting %ss.", wait_timeout_sec)
    raise RuntimeError("Captcha still present after press-and-hold attempts")


def capture_video_for_url(driver, url: str, sku: str | None = None) -> dict:
    """Load one product page and harvest every video URL found in its HTML.

    Merges three sources: a regex scan of the raw HTML, JSON-LD
    VideoObject entries, and <video>/<source> tag metadata. Retries up
    to MAX_RETRIES times; when VPN support is enabled, rotates the
    endpoint and retries on a blocking HTTP status or a DNS-error page.

    Args:
        driver: Live Selenium/undetected-chromedriver session.
        url: Product page URL to load.
        sku: Optional SKU used to name the saved PDP HTML file.

    Returns:
        Dict with "videos" (per-URL metadata dicts) and "all_urls"
        (sorted list of every captured URL); both empty if all attempts
        fail.
    """
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            logging.info(f"Open {url} (attempt {attempt}/{MAX_RETRIES})")
            driver.get(url)
            time.sleep(5)
            # Rotate VPN when the document status (read from Chrome's
            # performance log) indicates we should — then retry.
            if VPN_ENABLED and VPN_ROTATOR is not None and parse_document_status_from_performance_logs is not None:
                try:
                    status = parse_document_status_from_performance_logs(driver.get_log("performance"), url)
                    if should_rotate_on_status is not None and should_rotate_on_status(status):
                        VPN_ROTATOR.rotate(f"HTTP {status} for {url}")
                        continue
                except Exception:
                    pass
            # Rotate VPN when the loaded page looks like a DNS failure.
            if VPN_ENABLED and VPN_ROTATOR is not None and is_dns_error_html is not None:
                try:
                    if is_dns_error_html(driver.page_source):
                        VPN_ROTATOR.rotate(f"DNS error for {url}")
                        continue
                except Exception:
                    pass

            # Clear the "Press & Hold" captcha if it appeared.
            captcha_cleared = handle_press_and_hold_captcha(driver)
            if captcha_cleared:
                time.sleep(2)

            html = driver.page_source
            if SAVE_PDP_HTML:
                # Persist the page HTML for offline debugging/re-parsing.
                os.makedirs(PDP_HTML_DIR, exist_ok=True)
                file_key = sku or re.sub(r"\W+", "_", url) or "pdp"
                html_path = os.path.join(PDP_HTML_DIR, f"{file_key}.html")
                with open(html_path, "w", encoding="utf-8") as f:
                    f.write(html)
                logging.info(f"Saved PDP HTML to {html_path}")

            # Source 1: regex scan of the raw HTML.
            captured = collect_media_urls_from_html(html)
            # Source 2: JSON-LD VideoObject entries.
            ld_videos = extract_json_ld_videos(html)
            # Source 3: <video>/<source> tags (these carry poster URLs).
            html_video_meta = collect_video_metadata_from_html(html)

            for item in ld_videos:
                url_val = item.get("url")
                if isinstance(url_val, str):
                    captured.add(url_val)
            captured |= set(html_video_meta.keys())

            # Build per-URL metadata: JSON-LD entries win; <video> tags
            # contribute posters (and fill in URLs JSON-LD didn't have).
            meta_map = {}
            for item in ld_videos:
                url_val = item.get("url")
                if not isinstance(url_val, str):
                    continue
                meta_map[url_val] = {
                    "description": item.get("description"),
                    "date_modified": item.get("date_modified"),
                    "poster": None,
                }
            for url_val, meta in html_video_meta.items():
                if url_val in meta_map:
                    if meta.get("poster"):
                        meta_map[url_val]["poster"] = meta.get("poster")
                else:
                    meta_map[url_val] = meta

            videos = []
            for url_val in sorted(captured):
                file_name = None
                try:
                    # Basename without the query string, with %XX escapes
                    # decoded by hand (best-effort; failures leave None).
                    file_name = os.path.basename(url_val.split("?", 1)[0])
                    if file_name:
                        file_name = re.sub(r"%([0-9A-Fa-f]{2})", lambda m: bytes.fromhex(m.group(1)).decode("utf-8", "ignore"), file_name)
                except Exception:
                    file_name = None
                meta = meta_map.get(url_val, {})
                videos.append({
                    "url": url_val,
                    "type": "m3u8" if MANIFEST_RE.search(url_val) else "mp4/webm" if FILE_RE.search(url_val) else None,
                    "description": meta.get("description"),
                    "date_modified": meta.get("date_modified"),
                    "poster": meta.get("poster"),
                    "file_name": file_name,
                })

            return {
                "videos": videos,
                "all_urls": sorted(captured),
            }
        except Exception as e:
            # Best-effort retry with a jittered backoff.
            logging.warning(f"Error: {e}")
            time.sleep(2 + random.uniform(0.8, 2.5))
    return {"videos": [], "all_urls": []}


def main():
    """Entry point: load partner URLs, scrape each page, write JSON results."""
    global VPN_ROTATOR
    if VPN_ENABLED and WireGuardRotator is not None:
        VPN_ROTATOR = WireGuardRotator(WG_CONFIG_DIR, cooldown_sec=WG_COOLDOWN_SEC)
        VPN_ROTATOR.ensure_up()
        # Make sure the tunnel is torn down however the process exits.
        atexit.register(VPN_ROTATOR.shutdown)

    pairs = load_partner_urls(PARTNER_URLS_FILE, "Wayfair URL")
    urls = [p[1] for p in pairs]
    skus = [p[0] for p in pairs]
    out = []
    if not urls:
        logging.warning("No URLs loaded")
        return

    driver = ensure_driver_ready(get_driver())
    try:
        for i, url in enumerate(urls):
            # Recreate the browser every ROTATE_CONTEXT_EVERY URLs
            # (presumably to reset accumulated session state).
            if i and i % ROTATE_CONTEXT_EVERY == 0:
                try:
                    driver.quit()
                except Exception:
                    pass
                time.sleep(2)
                driver = ensure_driver_ready(get_driver())
            info = capture_video_for_url(driver, url, skus[i] if i < len(skus) else None)
            row = {
                "SKU": skus[i] if i < len(skus) else None,
                "page_url": url,
                "videos": info["videos"],
                "all_captured_urls": info["all_urls"],
            }
            out.append(row)
            logging.info(f"Processed {i+1}/{len(urls)} -> {row}")
            # Randomized pacing between product pages.
            time.sleep(3 + random.uniform(2.0, 4.5))
    finally:
        try:
            driver.quit()
        except Exception:
            pass

    # OUTPUT_PATH env var overrides the default output file name.
    output_path = os.environ.get("OUTPUT_PATH") or OUTPUT_FILE
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2, ensure_ascii=False)

    print(json.dumps(out, indent=2, ensure_ascii=False))
    logging.info(f"Saved -> {output_path}")


# Script entry point.
if __name__ == "__main__":
    main()
