from concurrent.futures import ProcessPoolExecutor, as_completed
from facefusion.ffmpeg import concat_video
from facefusion.filesystem import are_images, are_videos, move_file, remove_file, is_file
from facefusion.jobs import job_helper, job_manager
from facefusion.types import JobOutputSet, JobStep, ProcessStep
from typing import List
from facefusion import state_manager


def run_job(job_id: str, process_step: ProcessStep) -> bool:
	"""
	Execute a queued job from start to finish.

	All steps are attempted even if some of them fail. Afterwards any outputs that do exist
	are finalized and the temporary step files are cleaned up. The job file is moved to
	'completed' only when every step and the finalization succeeded; otherwise it is moved to 'failed'.

	:param job_id: identifier of the job; it must currently be in the 'queued' state
	:param process_step: callable executing a single step
	:return: True when the job was completed and its file moved to 'completed', False otherwise
	"""
	if job_id not in job_manager.find_job_ids('queued'):
		return False

	# Run everything first, then salvage whatever outputs were produced
	all_steps_ok = run_steps(job_id, process_step)
	outputs_finalized = finalize_steps(job_id)
	clean_steps(job_id)

	if not (all_steps_ok and outputs_finalized):
		job_manager.move_job_file(job_id, 'failed')
		return False
	return job_manager.move_job_file(job_id, 'completed')


def run_jobs(process_step: ProcessStep, halt_on_error: bool) -> bool:
	"""
	Run every queued job, spreading them over a pool of worker processes.

	The queued jobs are split round-robin into one batch per worker. Each worker runs its
	batch sequentially through run_job_batch, so models can be reused between jobs.

	:param process_step: callable executing one step; it is sent to worker processes and must therefore be picklable
	:param halt_on_error: forwarded to run_job_batch; it only stops the batch that hit the error, other batches keep running
	:return: False when nothing is queued or any batch failed or raised, True otherwise
	"""
	queued_job_ids: List[str] = job_manager.find_job_ids('queued')
	has_error = False

	if not queued_job_ids:
		return False

	# Fixed-size pool, never larger than the number of jobs to run
	max_workers = state_manager.get_item('execution_queue_count') or 1
	max_workers = max(1, int(max_workers))
	max_workers = min(max_workers, len(queued_job_ids))

	# Round-robin split: batch i receives jobs i, i + max_workers, i + 2 * max_workers, ...
	job_batches: List[List[str]] = [queued_job_ids[offset::max_workers] for offset in range(max_workers)]

	# NOTE(review): under the 'spawn' start method workers re-import modules instead of inheriting
	# the parent's memory; confirm state_manager is populated inside the workers.
	with ProcessPoolExecutor(max_workers=max_workers) as executor:
		futures = [executor.submit(run_job_batch, job_batch, process_step, halt_on_error) for job_batch in job_batches if job_batch]
		for future in as_completed(futures):
			try:
				if not future.result():
					has_error = True
			except Exception:
				# A crashed worker or an unpicklable argument surfaces here; count it as a failed batch
				has_error = True
	return not has_error


def run_job_batch(job_ids: List[str], process_step: ProcessStep, halt_on_error: bool) -> bool:
	"""
	Run the given jobs one after another; intended as the task of a single pool worker.

	For the duration of the batch the 'video_memory_strategy' state item is forced to 'tolerant'
	(the author's intent: keep caches alive between jobs). The previous value is restored afterwards,
	even if the loop is interrupted.

	:param job_ids: queued job identifiers to run in order
	:param process_step: callable executing a single step
	:param halt_on_error: stop at the first job that fails or raises
	:return: True when every attempted job succeeded, False otherwise
	"""
	has_error = False

	try:
		previous_strategy = state_manager.get_item('video_memory_strategy')
		state_manager.set_item('video_memory_strategy', 'tolerant')
	except Exception:
		previous_strategy = None

	try:
		for job_id in job_ids:
			try:
				job_succeeded = run_job(job_id, process_step)
			except Exception:
				job_succeeded = False
			if not job_succeeded:
				has_error = True
				if halt_on_error:
					break
	finally:
		# Restore on every exit path, including BaseException (e.g. KeyboardInterrupt) escaping run_job
		if previous_strategy is not None:
			state_manager.set_item('video_memory_strategy', previous_strategy)
	return not has_error


def retry_job(job_id: str, process_step: ProcessStep) -> bool:
	"""
	Requeue a failed job and run it again.

	Each stage only happens if the previous one succeeded: reset all step statuses to 'queued',
	move the job file back to 'queued', then execute it with run_job.

	:param job_id: identifier of a job currently in the 'failed' state
	:param process_step: callable executing a single step
	:return: the run_job result, or False when the job is not failed or requeueing fails
	"""
	if job_id not in job_manager.find_job_ids('failed'):
		return False
	if not job_manager.set_steps_status(job_id, 'queued'):
		return False
	if not job_manager.move_job_file(job_id, 'queued'):
		return False
	return run_job(job_id, process_step)


def retry_jobs(process_step: ProcessStep, halt_on_error: bool) -> bool:
	"""
	Retry every failed job, spreading them over a pool of worker processes.

	The failed jobs are split round-robin into one batch per worker. Each worker retries its
	batch sequentially through retry_job_batch.

	:param process_step: callable executing one step; it is sent to worker processes and must therefore be picklable
	:param halt_on_error: forwarded to retry_job_batch; it only stops the batch that hit the error, other batches keep running
	:return: False when nothing has failed or any batch failed or raised, True otherwise
	"""
	failed_job_ids: List[str] = job_manager.find_job_ids('failed')
	has_error = False

	if not failed_job_ids:
		return False

	# Fixed-size pool, never larger than the number of jobs to retry
	max_workers = state_manager.get_item('execution_queue_count') or 1
	max_workers = max(1, int(max_workers))
	max_workers = min(max_workers, len(failed_job_ids))

	# Round-robin split: batch i receives jobs i, i + max_workers, i + 2 * max_workers, ...
	job_batches: List[List[str]] = [failed_job_ids[offset::max_workers] for offset in range(max_workers)]

	# NOTE(review): under the 'spawn' start method workers re-import modules instead of inheriting
	# the parent's memory; confirm state_manager is populated inside the workers.
	with ProcessPoolExecutor(max_workers=max_workers) as executor:
		futures = [executor.submit(retry_job_batch, job_batch, process_step, halt_on_error) for job_batch in job_batches if job_batch]
		for future in as_completed(futures):
			try:
				if not future.result():
					has_error = True
			except Exception:
				# A crashed worker or an unpicklable argument surfaces here; count it as a failed batch
				has_error = True
	return not has_error


def retry_job_batch(job_ids: List[str], process_step: ProcessStep, halt_on_error: bool) -> bool:
	"""
	Retry the given failed jobs one after another; intended as the task of a single pool worker.

	An exception raised while retrying a job is treated the same as an unsuccessful retry.

	:param job_ids: failed job identifiers to retry in order
	:param process_step: callable executing a single step
	:param halt_on_error: stop at the first job whose retry fails or raises
	:return: True when every attempted retry succeeded, False otherwise
	"""
	all_retried = True
	for failed_job_id in job_ids:
		try:
			retried = retry_job(failed_job_id, process_step)
		except Exception:
			retried = False
		if retried:
			continue
		all_retried = False
		if halt_on_error:
			break
	return all_retried


def run_step(job_id: str, step_index: int, step: JobStep, process_step: ProcessStep) -> bool:
	"""
	Execute one job step and move its output to the per-step temporary path.

	The step is marked 'started', then processed. On success the output file is moved to
	job_helper.get_step_output_path(...) and the step is marked 'completed'. Any failure along
	the way, including the move or the final status update, marks the step 'failed'.

	:param job_id: identifier of the owning job
	:param step_index: position of the step within the job
	:param step: step definition; its 'args' are passed to process_step
	:param process_step: callable executing the step, returning truthy on success
	:return: True when the step completed, False otherwise
	"""
	step_args = step.get('args')

	if job_manager.set_step_status(job_id, step_index, 'started') and process_step(job_id, step_index, step_args):
		output_path = step_args.get('output_path')
		step_output_path = job_helper.get_step_output_path(job_id, step_index, output_path)

		if move_file(output_path, step_output_path) and job_manager.set_step_status(job_id, step_index, 'completed'):
			return True
	# Reached on processing failure and on move/status failure, so no step stays stuck in 'started'
	job_manager.set_step_status(job_id, step_index, 'failed')
	return False


def run_steps(job_id: str, process_step: ProcessStep) -> bool:
	"""
	Run every step of a job in order, without stopping at the first failure.

	:param job_id: identifier of the job whose steps are executed
	:param process_step: callable executing a single step
	:return: True when the job has steps and all of them succeeded, False otherwise
	"""
	steps = job_manager.get_steps(job_id)

	if not steps:
		return False
	# A list (not a generator) so every step runs before the results are combined
	step_results = [run_step(job_id, step_index, step, process_step) for step_index, step in enumerate(steps)]
	return all(step_results)


def finalize_steps(job_id: str) -> bool:
	"""
	Turn the temporary per-step outputs of a job into its final output files.

	For each final output path, only the temp files that exist are considered, since failed steps
	leave none behind. Video temp files are concatenated into the output; image temp files are
	moved onto the output path one by one. Outputs with no existing temp file are skipped.

	:param job_id: identifier of the job to finalize
	:return: False as soon as a concatenation or move fails, True otherwise
	"""
	for output_path, temp_output_paths in collect_output_set(job_id).items():
		existing_paths = list(filter(is_file, temp_output_paths))

		if not existing_paths:
			continue

		if are_videos(existing_paths):
			if not concat_video(output_path, existing_paths):
				return False
			continue

		if are_images(existing_paths):
			# all() short-circuits, so moving stops at the first failure
			if not all(move_file(temp_path, output_path) for temp_path in existing_paths):
				return False
	return True


def clean_steps(job_id: str) -> bool:
	"""
	Remove the temporary per-step output files of a job that are still on disk.

	Missing temp files are expected rather than an error: a failed step never produces one,
	and finalize_steps moves image temp files onto the final output path. Every existing file
	is attempted even if an earlier removal fails, so one failure does not leave the rest behind.

	:param job_id: identifier of the job to clean
	:return: True when every existing temp file was removed, False if any removal failed
	"""
	output_set = collect_output_set(job_id)
	is_clean = True

	for temp_output_paths in output_set.values():
		for temp_output_path in temp_output_paths:
			if is_file(temp_output_path) and not remove_file(temp_output_path):
				is_clean = False
	return is_clean


def collect_output_set(job_id: str) -> JobOutputSet:
	"""
	Map each final output path of a job to the temporary per-step output paths that feed it.

	Steps without an 'output_path' argument are ignored. Several steps sharing one output path
	contribute several temp paths, kept in step order.

	:param job_id: identifier of the job
	:return: dict of final output path -> list of per-step temp output paths
	"""
	steps = job_manager.get_steps(job_id)
	job_output_set: JobOutputSet = {}

	for index, step in enumerate(steps or []):
		output_path = (step.get('args') or {}).get('output_path')

		if output_path:
			# Same helper run_step uses when writing the temp file, so the collected paths match the files on disk
			step_output_path = job_helper.get_step_output_path(job_id, index, output_path)
			job_output_set.setdefault(output_path, []).append(step_output_path)
	return job_output_set
