import json
import time
import random
import logging
import re
import os
import csv
import atexit
from urllib.parse import urlparse

import pandas as pd
import mysql.connector
import undetected_chromedriver as uc
from selenium.common.exceptions import WebDriverException
import tempfile
from selenium.webdriver.common.by import By

# ----------------------------
# CONFIG
# ----------------------------
# MySQL connection settings. Only referenced by the commented-out DB path in
# main(); the live code path loads URLs from PARTNER_URLS_FILE instead.
# NOTE(review): credentials are hard-coded in source control — consider
# loading them from the environment or a secrets store.
db_config = {
    "host": "10.8.0.1",  # Database server address
    "user": "integration",       # User name
    "password": "?Q8/{lVK2N08Y<b>k",  # Password
    "database": "Salsify"  # Database name
}

# Selects SKU plus the Home Depot link column (aliased to LINK) for rows whose
# status is 'Active' or 'Liquidation' and whose Home Depot ID is not NULL.
# Only used by the commented-out DB path in main().
QUERY = """
SELECT SL.SKU, SHDUS.`Home Depot ID` as 'LINK'
FROM  Salsify.MainData SL
LEFT JOIN Salsify.Homedepot SHDUS ON SL.SKU = SHDUS.SKU
WHERE (SHDUS.SKU LIKE '%' AND  (SL.Status)='Active' AND SHDUS.`Home Depot ID` IS NOT NULL) OR  (SHDUS.SKU LIKE '%' AND (SL.Status)='Liquidation' AND SHDUS.`Home Depot ID` IS NOT NULL)
"""

# Default results file; main() lets the OUTPUT_PATH environment variable override it.
OUTPUT_FILE = "HomedepotVideosRetry.json"
# CSV with a "SKU" column and a "Homedepot URL" column, located next to this script.
PARTNER_URLS_FILE = os.path.join(os.path.dirname(__file__), "partnerUrls.csv")
# Path handed to uc.Chrome(driver_executable_path=...).
CHROMEDRIVER = "/usr/local/bin/chromedriver"

HEADLESS = False  # not read anywhere in the code shown here
NAV_TIMEOUT_MS = 45_000  # goto() timeout used by capture_video_for_url
MAX_RETRIES = 3  # page-load attempts per URL in both capture functions
ROTATE_CONTEXT_EVERY = 15  # main() starts a fresh browser after this many URLs
SAVE_PDP_HTML = True  # when True, each loaded page's HTML is written to PDP_HTML_DIR
PDP_HTML_DIR = "pdp_html/hd/"
USE_UNDETECTED_CHROME = True  # not read anywhere in the code shown here
UC_HEADLESS = False  # adds --headless=new in get_uc_driver() when True
VPN_ENABLED = True  # gates WireGuard rotation (also requires wg_vpn to import)
WG_CONFIG_DIR = os.environ.get("WG_CONFIG_DIR", "/etc/wireguard")
WG_COOLDOWN_SEC = float(os.environ.get("WG_COOLDOWN_SEC", "2.0"))
VPN_ROTATOR = None  # assigned in main() when VPN rotation is available

# capture manifests and direct files
# (a URL matches when the extension is followed by "?" or the end of the string)
MANIFEST_RE = re.compile(r"\.m3u8(\?|$)", re.I)
FILE_RE = re.compile(r"\.(mp4|webm)(\?|$)", re.I)

# optional: capture segments too (usually you DON'T need them)
CAPTURE_SEGMENTS = False
SEGMENT_RE = re.compile(r"\.(m4s|ts)(\?|$)", re.I)

# Configures the root logger at import time.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# wg_vpn is treated as optional: if importing it fails for any reason, the four
# names are bound to None and every call site checks for None before use, so
# the script runs without VPN rotation.
# NOTE(review): `except Exception` also hides genuine errors raised inside
# wg_vpn at import time, not only a missing module.
try:
    from wg_vpn import (
        WireGuardRotator,
        parse_document_status_from_performance_logs,
        should_rotate_on_status,
        is_dns_error_html,
    )
except Exception:
    WireGuardRotator = None
    parse_document_status_from_performance_logs = None
    should_rotate_on_status = None
    is_dns_error_html = None


def load_partner_urls(csv_path: str, column_name: str) -> list[tuple[str, str]]:
    """Read ``(sku, url)`` pairs from a partner-URL CSV file.

    The file must have a header row. The SKU is taken from the ``SKU`` column
    and the URL from ``column_name``; both are stripped of surrounding
    whitespace. Rows whose URL is empty or the literal ``NULL`` (any case) are
    skipped. A missing SKU is returned as an empty string.

    Args:
        csv_path: Path to the CSV file.
        column_name: Header of the column that holds the URL.

    Returns:
        The pairs in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    pairs = []
    # "utf-8-sig" drops a leading byte-order mark (common in spreadsheet
    # exports) and reads BOM-less UTF-8 unchanged. With plain "utf-8" the BOM
    # sticks to the first header ("\ufeffSKU"), so every SKU came back empty.
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            sku = (row.get("SKU") or "").strip()
            url = (row.get(column_name) or "").strip()
            if not url or url.upper() == "NULL":
                continue
            pairs.append((sku, url))
    return pairs


def extract_apollo_state(html: str) -> dict:
    """Extract the JSON object assigned to ``__APOLLO_STATE__`` in page HTML.

    Looks for ``window.__APOLLO_STATE__=`` and, if that is absent, for
    ``window['__APOLLO_STATE__'] =``; the object is expected to start at the
    first ``{`` after the marker.

    Args:
        html: Full page source.

    Returns:
        The decoded object, or ``{}`` when the marker or an opening brace is
        missing or the text after it is not a valid JSON object.
    """
    for marker in ("window.__APOLLO_STATE__=", "window['__APOLLO_STATE__'] ="):
        idx = html.find(marker)
        if idx != -1:
            break
    else:
        return {}

    start = html.find("{", idx)
    if start == -1:
        return {}

    # raw_decode parses one JSON document from the start of the string and
    # ignores whatever follows it, so there is no need to locate the matching
    # closing brace by scanning the page character by character.
    try:
        state, _end = json.JSONDecoder().raw_decode(html[start:])
    except Exception:
        return {}
    return state


def collect_media_urls_from_apollo(html: str) -> set[str]:
    """Return every media URL found among the string values of the Apollo state.

    The state is obtained with ``extract_apollo_state``. A string value counts
    as a media URL when it matches ``MANIFEST_RE`` or ``FILE_RE``. Dictionary
    keys are not inspected, only values.

    Args:
        html: Full page source.

    Returns:
        The set of matching strings; empty when no state could be extracted.
    """
    found = set()
    state = extract_apollo_state(html)
    if not state:
        return found

    # Depth-first traversal with an explicit stack instead of recursion.
    pending = [state]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            pending.extend(node.values())
        elif isinstance(node, list):
            pending.extend(node)
        elif isinstance(node, str) and (MANIFEST_RE.search(node) or FILE_RE.search(node)):
            found.add(node)
    return found

def is_dead_window_error(e: Exception) -> bool:
    """Tell whether an error message says the browser window is gone.

    Args:
        e: The error to inspect; only ``str(e)`` is used, case-insensitively.

    Returns:
        True when the message contains "no such window" or
        "web view not found".
    """
    text = str(e).lower()
    return any(marker in text for marker in ("no such window", "web view not found"))



def collect_video_descriptions_from_ld_json(html: str) -> dict:
    """Map video URLs to metadata from the page's video structured-data script.

    Reads the ``<script>`` element whose id is
    ``thd-helmet__script--videoStructureData``, parses its body as JSON and
    walks the ``@graph`` list for items with ``@type == "VideoObject"``.

    Args:
        html: Full page source.

    Returns:
        ``{contentUrl: {"description": str | None, "date_modified": str | None}}``.
        The description falls back to the item's ``name`` when ``description``
        is missing or empty. Returns ``{}`` when the script is absent, is not
        valid JSON, is not a JSON object, or has no list under ``@graph``.
    """
    pattern = r'<script[^>]*id="thd-helmet__script--videoStructureData"[^>]*>(.*?)</script>'
    match = re.search(pattern, html, re.IGNORECASE | re.DOTALL)
    if not match:
        return {}
    try:
        data = json.loads(match.group(1).strip())
    except Exception:
        return {}
    # The payload may legally be a list or a scalar; calling .get() on those
    # would raise AttributeError, so anything but an object is "no metadata".
    if not isinstance(data, dict):
        return {}
    graph = data.get("@graph", [])
    if not isinstance(graph, list):
        return {}

    out = {}
    for item in graph:
        if not isinstance(item, dict) or item.get("@type") != "VideoObject":
            continue
        url = item.get("contentUrl")
        if not isinstance(url, str):
            continue
        desc = item.get("description") or item.get("name")
        date_modified = item.get("dateModified")
        out[url] = {
            "description": desc if isinstance(desc, str) else None,
            "date_modified": date_modified if isinstance(date_modified, str) else None,
        }
    return out


def build_video_list(urls: set[str], meta_map: dict) -> list[dict]:
    """Describe each captured URL as a dictionary, ordered by URL.

    Args:
        urls: Captured media URLs.
        meta_map: Per-URL metadata with optional ``description`` and
            ``date_modified`` entries; URLs without an entry get ``None``
            for both.

    Returns:
        One dict per URL with the keys ``url``, ``type`` ("m3u8", "mp4/webm"
        or None), ``file_name`` (last path component without the query
        string, or None when empty), ``description`` and ``date_modified``.
    """
    entries = []
    for link in sorted(urls):
        if MANIFEST_RE.search(link):
            kind = "m3u8"
        elif FILE_RE.search(link):
            kind = "mp4/webm"
        else:
            kind = None

        path_part, _, _ = link.partition("?")
        base_name = os.path.basename(path_part)
        info = meta_map.get(link, {})

        entry = {
            "url": link,
            "type": kind,
            "file_name": base_name if base_name else None,
            "description": info.get("description"),
            "date_modified": info.get("date_modified"),
        }
        entries.append(entry)
    return entries



def choose_best_video_url(urls: set[str]) -> dict:
    """Pick the single most useful media URL from a set of captured URLs.

    A ``.m3u8`` manifest is preferred over a direct ``.mp4``/``.webm`` file.
    Manifests are ranked by ``score_m3u8`` (below); direct files by URL
    length. Equally ranked candidates are resolved in favour of the
    lexicographically smallest URL so the choice is reproducible.

    Args:
        urls: Captured media URLs.

    Returns:
        ``{"type": "m3u8" | "mp4/webm", "url": str}``, or
        ``{"type": "none", "url": None}`` when nothing matches.
    """
    # Iteration order of a set of strings changes between interpreter runs
    # (string hashes are randomised), so ties used to be broken arbitrarily.
    # Sorting first makes the winner among equal scores deterministic.
    ordered = sorted(urls)
    m3u8 = [u for u in ordered if MANIFEST_RE.search(u)]
    mp4 = [u for u in ordered if FILE_RE.search(u)]

    # Heuristic: master playlists sometimes contain "master" or "index" or have many query params.
    def score_m3u8(u: str) -> int:
        lu = u.lower()
        s = 0
        if lu.endswith("name/a.mp4/index.m3u8"):
            s += 40
        if "master" in lu:
            s += 5
        if "playlist" in lu:
            s += 3
        if "index" in lu:
            s += 2
        if "variant" in lu:
            s += 2
        # longer often means signed master url
        s += min(len(u) // 80, 4)
        return s

    # max() returns the first maximal element, i.e. the smallest URL on a tie.
    if m3u8:
        return {"type": "m3u8", "url": max(m3u8, key=score_m3u8)}

    if mp4:
        # If multiple mp4s, pick longest (often best quality / signed URL)
        return {"type": "mp4/webm", "url": max(mp4, key=len)}

    return {"type": "none", "url": None}


def get_uc_driver():
    """Start a new undetected-chromedriver Chrome session and return it.

    Every call uses a fresh temporary profile directory, enables the
    "performance" log (read later through ``driver.get_log``) and adds
    ``--headless=new`` when ``UC_HEADLESS`` is set. The driver binary is taken
    from ``CHROMEDRIVER``. Sleeps four seconds after start-up and prints the
    new session id.
    """
    chrome_flags = (
        "--no-sandbox",
        "--disable-blink-features=AutomationControlled",
        "--disable-gpu",
        "--start-maximized",
        "--disable-extensions",
        "--disable-infobars",
    )
    opts = uc.ChromeOptions()
    for flag in chrome_flags:
        opts.add_argument(flag)

    # NOTE(review): the temporary profile directory is never removed.
    profile_path = tempfile.mkdtemp(prefix="uc-profile-")
    opts.add_argument(f"--user-data-dir={profile_path}")

    if UC_HEADLESS:
        opts.add_argument("--headless=new")
    opts.set_capability("goog:loggingPrefs", {"performance": "ALL"})

    browser = uc.Chrome(
        options=opts,
        use_subprocess=False,
        driver_executable_path=CHROMEDRIVER,
    )
    time.sleep(4)
    print('UC Driver recreated session', browser.session_id)
    return browser


def _is_driver_dead(err: Exception) -> bool:
    """Tell whether an error message indicates the browser session is unusable.

    Args:
        err: The error to inspect; only ``str(err)`` is used, case-insensitively.

    Returns:
        True when the message contains any of the known fatal phrases.
    """
    fatal_phrases = (
        "no such window",
        "target window already closed",
        "web view not found",
        "disconnected",
        "chrome not reachable",
    )
    text = str(err).lower()
    return any(phrase in text for phrase in fatal_phrases)


def ensure_driver_ready(driver, max_attempts: int = 4):
    """Return a driver that answered a basic probe, replacing it if necessary.

    A probe reads ``current_url`` and navigates to ``about:blank``. If the
    probe raises, the driver is quit (errors ignored) and a new one is created
    with ``get_uc_driver``; this is repeated up to ``max_attempts`` times.

    Args:
        driver: The driver to check.
        max_attempts: Maximum number of probes.

    Returns:
        The probed driver, or, when every probe failed, the most recently
        created driver, which has not been probed.
    """
    remaining = max_attempts
    while remaining > 0:
        remaining -= 1
        try:
            driver.current_url
            driver.get("about:blank")
        except Exception:
            try:
                driver.quit()
            except Exception:
                pass
            time.sleep(1.5)
            driver = get_uc_driver()
        else:
            time.sleep(0.5)
            return driver
    return driver


def collect_network_urls(driver) -> set[str]:
    """Pull media URLs out of the driver's "performance" log.

    Only ``Network.requestWillBeSent`` and ``Network.responseReceived`` events
    are inspected. A URL is kept when it matches ``MANIFEST_RE`` or
    ``FILE_RE``, or ``SEGMENT_RE`` when ``CAPTURE_SEGMENTS`` is enabled.
    A log that cannot be read counts as empty, and entries that cannot be
    parsed are skipped.

    Args:
        driver: A Selenium-style driver exposing ``get_log``.

    Returns:
        The set of matching URLs.
    """
    try:
        entries = driver.get_log("performance")
    except Exception:
        entries = []

    # Event name -> key under "params" that holds the object carrying "url".
    url_holder = {
        "Network.requestWillBeSent": "request",
        "Network.responseReceived": "response",
    }

    hits = set()
    for entry in entries:
        try:
            event = json.loads(entry.get("message", "{}")).get("message", {})
            holder = url_holder.get(event.get("method"))
            if holder is None:
                continue
            candidate = event.get("params", {}).get(holder, {}).get("url")
            if not candidate:
                continue
            wanted = (
                MANIFEST_RE.search(candidate)
                or FILE_RE.search(candidate)
                or (CAPTURE_SEGMENTS and SEGMENT_RE.search(candidate))
            )
            if wanted:
                hits.add(candidate)
        except Exception:
            continue
    return hits


def _selenium_click(driver, element) -> bool:
    """Scroll an element to the centre of the viewport and click it via JavaScript.

    Args:
        driver: Driver used to execute the scripts.
        element: The element to click.

    Returns:
        True when both scripts ran without raising, False otherwise.
    """
    scroll_js = "arguments[0].scrollIntoView({block: 'center'});"
    click_js = "arguments[0].click();"
    try:
        driver.execute_script(scroll_js, element)
        time.sleep(0.2)
        driver.execute_script(click_js, element)
    except Exception:
        return False
    return True


def try_trigger_video_selenium(driver) -> bool:
    """Best-effort attempt to make the page start loading its video streams.

    Runs a fixed sequence of steps and ignores the failure of any single one:
    scroll down twice, open the gallery modal, switch to its "Videos" tab,
    click every thumbnail button in the gallery, then click the first element
    matching one of several "play"-like selectors.

    Args:
        driver: Selenium-style driver with the product page already loaded.

    Returns:
        True as soon as one of the play-like elements was clicked; otherwise
        True if at least one gallery thumbnail click succeeded, else False.

    Raises:
        Whatever ``driver.execute_script`` raises for the two initial scroll
        calls, which are not guarded.
    """
    triggered = False

    # Scroll down in two steps with a short pause after each.
    driver.execute_script("window.scrollBy(0, 1400);")
    time.sleep(0.7)
    driver.execute_script("window.scrollBy(0, 1400);")
    time.sleep(0.7)

    # Open gallery modal if present
    try:
        btns = driver.find_elements(
            By.CSS_SELECTOR,
            "div.GalleryThumbnailWrapper-sc-2lpjbq-0.cbkjBo.gallery-thumbnail.shape--rounded.excess > button",
        )
        if not btns:
            # Fallback: a media-tile trigger button inside a tile whose
            # data-media-type contains neither "video" nor "spin". The
            # translate() calls lower-case just the letters needed to make
            # that containment test case-insensitive.
            btns = driver.find_elements(
                By.XPATH,
                "//div[@data-media-type and "
                "not(contains(translate(@data-media-type,'VIDEO','video'),'video')) and "
                "not(contains(translate(@data-media-type,'SPIN','spin'),'spin'))]"
                "//button[contains(@class,'MediaTilestyles__MediaTileTrigger')]",
            )
        if btns:
            _selenium_click(driver, btns[0])
            time.sleep(1.2)
    except Exception:
        pass

    # Switch to Videos tab in modal
    try:
        tabs = driver.find_elements(
            By.XPATH,
            "//div[contains(@class,'tablist')]//button[.//p[normalize-space()='Videos'] or normalize-space()='Videos']",
        )
        if tabs:
            _selenium_click(driver, tabs[0])
            time.sleep(0.8)
    except Exception:
        pass

    # Click each video thumbnail button in the modal
    try:
        thumb_buttons = driver.find_elements(By.CSS_SELECTOR, "div.gallery-thumbnail-group button")
        for btn in thumb_buttons:
            if _selenium_click(driver, btn):
                time.sleep(1.2)
                triggered = True
    except Exception:
        pass

    # Play-like candidates, tried in this order; only the first match of each
    # selector is clicked.
    selectors = [
        (By.CSS_SELECTOR, "button[aria-label='Play']"),
        (By.CSS_SELECTOR, "button[title='Play']"),
        (By.CSS_SELECTOR, "[data-testid*='play']"),
        (By.XPATH, "//button[contains(normalize-space(.), 'Play')]"),
        (By.CSS_SELECTOR, "video"),
        (By.XPATH, "//*[contains(translate(@alt,'VIDEO','video'),'video')]"),
    ]

    for by, sel in selectors:
        try:
            elems = driver.find_elements(by, sel)
            if elems:
                if _selenium_click(driver, elems[0]):
                    time.sleep(1.2)
                    return True
        except Exception:
            continue

    # Last resort: any element whose class or aria-label mentions "video".
    try:
        elems = driver.find_elements(
            By.XPATH,
            "//*[contains(translate(@class,'VIDEO','video'),'video') or "
            "contains(translate(@aria-label,'VIDEO','video'),'video')]",
        )
        if elems:
            if _selenium_click(driver, elems[0]):
                time.sleep(1.2)
                return True
    except Exception:
        pass

    return triggered

def try_trigger_video(page) -> bool:
    """
    Best-effort click to cause the player to request the stream.

    Same sequence as ``try_trigger_video_selenium`` but written against a
    page object exposing ``mouse.wheel``, ``locator`` and
    ``wait_for_timeout``: scroll twice, open the gallery modal, switch to the
    "Videos" tab, click every gallery thumbnail, then click the first
    play-like element. Failures of individual steps are ignored.

    Returns True as soon as a play-like element was clicked; otherwise True
    if the thumbnail loop clicked at least one button, else False. The two
    initial scroll calls are not guarded, so their errors propagate.

    NOTE(review): the earlier docstring said "Lowes UI varies", while this
    file's output and HTML directory names refer to Home Depot — confirm
    which site these selectors are meant for.
    """
    triggered = False

    # Scroll to load media section
    page.mouse.wheel(0, 1400)
    page.wait_for_timeout(700)
    page.mouse.wheel(0, 1400)
    page.wait_for_timeout(700)


    # Open gallery modal if present
    try:
        open_modal_btn = page.locator(
            "div.GalleryThumbnailWrapper-sc-2lpjbq-0.cbkjBo.gallery-thumbnail.shape--rounded.excess > button, "
            "div[data-media-type]:not([data-media-type*='Video' i]):not([data-media-type*='Spin' i]) "
            "button.MediaTilestyles__MediaTileTrigger-sc-vhmy2w-2.jsjQI"
        ).first
        # Diagnostic output: element counts for the selectors involved above.
        logging.info(page.locator("button.MediaTilestyles__MediaTileTrigger-sc-vhmy2w-2.jsjQI").count())
        logging.info(page.locator("div[data-media-type]").count())
        logging.info(page.locator("div[data-media-type*='Video' i]").count())
        logging.info(page.locator("div[data-media-type*='Spin' i]").count())


        if open_modal_btn.count() > 0:
            open_modal_btn.scroll_into_view_if_needed(timeout=2000)
            page.wait_for_timeout(250)
            open_modal_btn.click(timeout=2500)
            page.wait_for_timeout(1200)
    except Exception:
        pass

    # Switch to Videos tab in modal
    try:
        videos_tab = page.locator("div.tablist button:has-text('Videos')").first
        if videos_tab.count() > 0:
            videos_tab.scroll_into_view_if_needed(timeout=2000)
            page.wait_for_timeout(250)
            videos_tab.click(timeout=2500)
            page.wait_for_timeout(800)
    except Exception:
        pass

    # Click each video thumbnail button in the modal
    # (the first failing click ends the loop, because the whole loop sits in
    # one try block)
    try:
        thumb_buttons = page.locator(
            "div.gallery-thumbnail-group button"
        )
        for i in range(thumb_buttons.count()):
            btn = thumb_buttons.nth(i)
            btn.scroll_into_view_if_needed(timeout=2000)
            page.wait_for_timeout(200)
            btn.click(timeout=2500)
            page.wait_for_timeout(1200)
            triggered = True
    except Exception:
        pass

    # Play-like candidates, tried in this order; only the first match of each
    # selector is clicked.
    selectors = [
        "button[aria-label='Play']",
        "button[title='Play']",
        "[data-testid*='play']",
        "button:has-text('Play')",
        "video",
        "img[alt*='video' i]",
    ]

    for sel in selectors:
        try:
            loc = page.locator(sel).first
            if loc.count() > 0:
                loc.scroll_into_view_if_needed(timeout=2000)
                page.wait_for_timeout(250)
                loc.click(timeout=2500)
                page.wait_for_timeout(1200)
                return True
        except Exception:
            continue

    # Sometimes videos are inside carousels/thumbnails
    # Try clicking any element that looks like a video thumbnail
    try:
        thumb = page.locator("[class*='video' i], [aria-label*='video' i]").first
        if thumb.count() > 0:
            thumb.scroll_into_view_if_needed(timeout=2000)
            page.wait_for_timeout(250)
            thumb.click(timeout=2500)
            page.wait_for_timeout(1200)
            return True
    except Exception:
        pass

    return triggered


def capture_video_for_url(page, url: str, sku: str | None = None) -> dict:
    """
    Open PDP, click play, capture stream/file URLs.

    Media URLs come from three places: requests observed on ``page`` while
    this function runs, string values in the page's Apollo state, and (for
    descriptions only) the video structured-data script. Loading is attempted
    up to ``MAX_RETRIES`` times with a randomised pause after each failure.

    Args:
        page: Page object exposing ``on``/``remove_listener``, ``goto``,
            ``content`` and ``wait_for_timeout``.
        url: Product page URL.
        sku: Optional SKU, used to name the saved HTML file.

    Returns:
        Dict with ``best_type``, ``best_url``, ``best_description``,
        ``best_date_modified``, ``all_urls``, ``videos`` and
        ``triggered_play``. When every attempt failed, ``best_type`` is
        "none", the lists are empty and the remaining values are None/False.
    """
    captured = set()

    def on_request(req):
        u = req.url
        if MANIFEST_RE.search(u) or FILE_RE.search(u) or (CAPTURE_SEGMENTS and SEGMENT_RE.search(u)):
            captured.add(u)

    page.on("request", on_request)
    # try/finally guarantees the listener is detached on every exit path.
    try:
        for attempt in range(1, MAX_RETRIES + 1):
            try:
                logging.info(f"Open {url} (attempt {attempt}/{MAX_RETRIES})")
                page.goto(url, wait_until="domcontentloaded", timeout=NAV_TIMEOUT_MS)
                page.wait_for_timeout(2500)

                html = page.content()
                if SAVE_PDP_HTML:
                    # A SKU may contain characters that are not valid in a
                    # file name (e.g. "/"), so it is sanitised; ordinary SKUs
                    # made of letters, digits, "_", "-" and "." are unchanged.
                    if sku:
                        file_key = re.sub(r"[^\w.-]+", "_", sku)
                    else:
                        file_key = re.sub(r"\W+", "_", urlparse(url).path.strip("/"))
                    html_path = os.path.join(PDP_HTML_DIR, f"{file_key or 'pdp'}.html")
                    # Saving the HTML is a debugging aid; failing to write it
                    # must not discard an otherwise successful page load.
                    try:
                        os.makedirs(PDP_HTML_DIR, exist_ok=True)
                        with open(html_path, "w", encoding="utf-8") as f:
                            f.write(html)
                        logging.info(f"Saved PDP HTML to {html_path}")
                    except OSError as e:
                        logging.warning(f"Could not save PDP HTML to {html_path}: {e}")

                captured |= collect_media_urls_from_apollo(html)
                ld_video_desc = collect_video_descriptions_from_ld_json(html)

                # A direct file URL is already known: no need to click anything.
                if any(FILE_RE.search(u) for u in captured):
                    triggered = False
                else:
                    triggered = try_trigger_video(page)

                # Give it time to request manifest after click
                page.wait_for_timeout(4500 if triggered else 1500)

                best = choose_best_video_url(captured)
                best_meta = ld_video_desc.get(best["url"], {})
                return {
                    "best_type": best["type"],
                    "best_url": best["url"],
                    "best_description": best_meta.get("description"),
                    "best_date_modified": best_meta.get("date_modified"),
                    "all_urls": sorted(captured),
                    "videos": build_video_list(captured, ld_video_desc),
                    "triggered_play": triggered,
                }
            except Exception as e:
                # The timeout exception class of the page library is not
                # imported in this file (the former `except PWTimeoutError`
                # referred to an undefined name and turned every error into a
                # NameError), so timeouts are recognised by class name.
                if type(e).__name__ == "TimeoutError":
                    logging.warning("Timeout, retrying...")
                else:
                    logging.warning(f"Error: {e}")

            time.sleep(5 + random.uniform(1, 5))
    finally:
        page.remove_listener("request", on_request)

    return {
        "best_type": "none",
        "best_url": None,
        "best_description": None,
        "best_date_modified": None,
        "all_urls": [],
        "videos": [],
        "triggered_play": False,
    }


def _media_urls_from_performance_entries(entries) -> set[str]:
    """Return media URLs found in already-fetched performance-log entries."""
    found = set()
    for entry in entries:
        try:
            message = json.loads(entry.get("message", "{}")).get("message", {})
            method = message.get("method")
            params = message.get("params", {})
            if method == "Network.requestWillBeSent":
                candidate = params.get("request", {}).get("url")
            elif method == "Network.responseReceived":
                candidate = params.get("response", {}).get("url")
            else:
                continue
            if candidate and (
                MANIFEST_RE.search(candidate)
                or FILE_RE.search(candidate)
                or (CAPTURE_SEGMENTS and SEGMENT_RE.search(candidate))
            ):
                found.add(candidate)
        except Exception:
            continue
    return found


def capture_video_for_url_selenium(driver, url: str, sku: str | None = None) -> dict:
    """Open a product page with Selenium and collect its video URLs.

    Media URLs come from string values in the page's Apollo state and from
    the driver's "performance" log; descriptions come from the video
    structured-data script. Loading is attempted up to ``MAX_RETRIES`` times.
    When VPN rotation is available, an attempt is abandoned and the VPN is
    rotated if the document status or a DNS error page calls for it.

    Args:
        driver: Selenium-style driver with performance logging enabled.
        url: Product page URL.
        sku: Optional SKU, used to name the saved HTML file.

    Returns:
        Dict with ``best_type``, ``best_url``, ``best_description``,
        ``best_date_modified``, ``all_urls``, ``videos`` and
        ``triggered_play``. When every attempt failed, ``best_type`` is
        "none", the lists are empty and the remaining values are None/False.
    """
    captured = set()

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            logging.info(f"Open {url} (attempt {attempt}/{MAX_RETRIES})")
            # Reading the log empties it, so entries from the previous page
            # are dropped before navigating.
            try:
                driver.get_log("performance")
            except Exception:
                pass
            driver.get(url)
            time.sleep(3)
            if VPN_ENABLED and VPN_ROTATOR is not None and parse_document_status_from_performance_logs is not None:
                try:
                    load_entries = driver.get_log("performance")
                    # This read also empties the log. Harvest media URLs from
                    # these entries now, otherwise requests made while the
                    # page was loading would never reach the result.
                    captured |= _media_urls_from_performance_entries(load_entries)
                    status = parse_document_status_from_performance_logs(load_entries, url)
                    if should_rotate_on_status is not None and should_rotate_on_status(status):
                        VPN_ROTATOR.rotate(f"HTTP {status} for {url}")
                        continue
                except Exception:
                    pass

            # Fetch the source once so the DNS check, the saved file and the
            # parsed HTML all refer to the same snapshot.
            html = driver.page_source
            if VPN_ENABLED and VPN_ROTATOR is not None and is_dns_error_html is not None:
                try:
                    if is_dns_error_html(html):
                        VPN_ROTATOR.rotate(f"DNS error for {url}")
                        continue
                except Exception:
                    pass

            if SAVE_PDP_HTML:
                # A SKU may contain characters that are not valid in a file
                # name (e.g. "/"), so it is sanitised; ordinary SKUs made of
                # letters, digits, "_", "-" and "." are unchanged.
                if sku:
                    file_key = re.sub(r"[^\w.-]+", "_", sku)
                else:
                    file_key = re.sub(r"\W+", "_", urlparse(url).path.strip("/"))
                html_path = os.path.join(PDP_HTML_DIR, f"{file_key or 'pdp'}.html")
                # Saving the HTML is a debugging aid; failing to write it must
                # not discard an otherwise successful page load.
                try:
                    os.makedirs(PDP_HTML_DIR, exist_ok=True)
                    with open(html_path, "w", encoding="utf-8") as f:
                        f.write(html)
                    logging.info(f"Saved PDP HTML to {html_path}")
                except OSError as e:
                    logging.warning(f"Could not save PDP HTML to {html_path}: {e}")

            captured |= collect_media_urls_from_apollo(html)
            ld_video_desc = collect_video_descriptions_from_ld_json(html)

            # A direct file URL is already known: no need to click anything.
            if any(FILE_RE.search(u) for u in captured):
                triggered = False
            else:
                triggered = try_trigger_video_selenium(driver)
            # Leave time for the player to request its manifest after a click.
            time.sleep(4.5 if triggered else 1.5)

            captured |= collect_network_urls(driver)

            best = choose_best_video_url(captured)
            best_meta = ld_video_desc.get(best["url"], {})
            return {
                "best_type": best["type"],
                "best_url": best["url"],
                "best_description": best_meta.get("description"),
                "best_date_modified": best_meta.get("date_modified"),
                "all_urls": sorted(captured),
                "videos": build_video_list(captured, ld_video_desc),
                "triggered_play": triggered,
            }
        except Exception as e:
            # NOTE(review): errors from a dead browser session are logged and
            # swallowed here as well, so the caller only ever sees the empty
            # result below and cannot tell that the driver needs replacing.
            logging.warning(f"Error: {e}, driver session: {driver.session_id}")

        time.sleep(2 + random.uniform(0.8, 2.5))

    return {
        "best_type": "none",
        "best_url": None,
        "best_description": None,
        "best_date_modified": None,
        "all_urls": [],
        "videos": [],
        "triggered_play": False,
    }


def _quit_driver_quietly(driver) -> None:
    """Quit a driver, ignoring errors from a session that is already gone."""
    try:
        driver.quit()
    except Exception as e:
        logging.debug(f"Ignoring error while quitting driver: {e}")


def _driver_is_alive(driver) -> bool:
    """Return True when the driver still answers a trivial command."""
    try:
        driver.current_url
    except Exception:
        return False
    return True


def main():
    """Scrape video URLs for every partner URL and write them to a JSON file.

    Steps: bring up VPN rotation when available, load ``(sku, url)`` pairs
    from ``PARTNER_URLS_FILE``, visit each URL with a Chrome session that is
    replaced every ``ROTATE_CONTEXT_EVERY`` URLs (or when it stops
    responding), then write the rows to ``OUTPUT_PATH`` (environment) or
    ``OUTPUT_FILE`` and print them as JSON.

    Rows collected so far are written to the output file even when the loop
    ends with an exception; the exception is then re-raised.
    """
    global VPN_ROTATOR
    if VPN_ENABLED and WireGuardRotator is not None:
        VPN_ROTATOR = WireGuardRotator(WG_CONFIG_DIR, cooldown_sec=WG_COOLDOWN_SEC)
        VPN_ROTATOR.ensure_up()
        atexit.register(VPN_ROTATOR.shutdown)

    # URLs come from the CSV file; the database path (db_config / QUERY) is
    # currently not used.
    pairs = load_partner_urls(PARTNER_URLS_FILE, "Homedepot URL")
    if not pairs:
        logging.warning("No URLs loaded")
        return
    total = len(pairs)
    logging.info(f"Loaded {total} URLs")

    output_path = os.environ.get("OUTPUT_PATH") or OUTPUT_FILE
    out = []

    driver = ensure_driver_ready(get_uc_driver())
    try:
        for i, (sku, url) in enumerate(pairs):
            if i and i % ROTATE_CONTEXT_EVERY == 0:
                # The old session may already be dead; that must not stop the run.
                _quit_driver_quietly(driver)
                time.sleep(1.5)
                driver = ensure_driver_ready(get_uc_driver())

            try:
                video_info = capture_video_for_url_selenium(driver, url, sku=sku)
                # The capture function reports its own failures as an empty
                # result, so an empty result is followed by a liveness probe
                # to tell "no video on the page" from "browser is gone".
                driver_lost = video_info["best_type"] == "none" and not _driver_is_alive(driver)
                reason = "driver stopped responding"
            except Exception as e:
                if not _is_driver_dead(e):
                    raise
                driver_lost = True
                reason = e

            if driver_lost:
                logging.warning(f"Driver died, recreating (url {i+1}/{total}): {reason}")
                _quit_driver_quietly(driver)
                time.sleep(2)
                driver = ensure_driver_ready(get_uc_driver())
                video_info = capture_video_for_url_selenium(driver, url, sku=sku)

            row = {
                "SKU": sku,
                "page_url": url,
                "video_type": video_info["best_type"],
                "video_url": video_info["best_url"],
                "video_description": video_info.get("best_description"),
                "videos": video_info.get("videos", []),
                "triggered_play": video_info["triggered_play"],
                # keep this if you want debugging; remove to keep file smaller
                "all_captured_urls": video_info["all_urls"],
            }

            out.append(row)
            logging.info(f"Processed {i+1}/{total} -> {len(row['videos'])} videos")
            time.sleep(random.uniform(2.0, 4.5))
    finally:
        # Nothing is reused after this point, so a failing quit is ignored
        # and no replacement browser is started.
        _quit_driver_quietly(driver)
        # Persist whatever was collected, including after a failure part-way.
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(out, f, indent=2, ensure_ascii=False)
        logging.info(f"Saved -> {output_path}")

    print(json.dumps(out, indent=2, ensure_ascii=False))

if __name__ == "__main__":
    main()
