from datetime import datetime, timedelta
from typing import List, Tuple

from dateutil.relativedelta import relativedelta

# A period is (start, end, period_id); both bounds are inclusive.
_Period = Tuple[datetime, datetime, str]

_ONE_DAY = timedelta(days=1)


class DateCalculator:
    """Pure mathematical operations for date and period calculations.

    Every period is a ``(start, end, period_id)`` tuple whose bounds are
    treated as inclusive: the day count is ``(end - start).days + 1`` and the
    successor of a period starts at ``end + 1 day``.

    NOTE(review): all arithmetic steps in whole days, so the datetimes are
    presumably day-granular (e.g. midnight-normalised) -- confirm with callers.
    """

    @staticmethod
    def sort_periods(periods: List[_Period]) -> List[_Period]:
        """Return a new list of periods ordered by start date, then end date."""
        return sorted(periods, key=lambda p: (p[0], p[1]))

    @staticmethod
    def truncate_periods(
        periods: List[_Period], launch_date: datetime
    ) -> List[_Period]:
        """Clip every period so that it starts no earlier than ``launch_date``.

        Periods that end before ``launch_date`` are dropped entirely.
        """
        considered_periods = []
        for start, end, period_id in periods:
            truncated_start = max(start, launch_date)
            if truncated_start <= end:
                considered_periods.append((truncated_start, end, period_id))
        return considered_periods

    @staticmethod
    def _months_spanned(start: datetime, end: datetime) -> int:
        """Return the month count used for one period.

        This is the calendar-month difference between ``start`` and ``end``,
        plus one when the day-of-month of ``end`` is on or after the
        day-of-month of ``start``.
        """
        months = (end.year - start.year) * 12 + (end.month - start.month)
        if end.day >= start.day:
            months += 1
        return months

    @staticmethod
    def round_periods(periods: List[_Period]) -> Tuple[List[_Period], int]:
        """Round periods to month boundaries and calculate total months.

        Each period's end is replaced by ``start + N months - 1 day`` where
        ``N`` comes from ``_months_spanned``. A period that starts on or
        before the previous *rounded* end is first pushed to the day after
        it; if nothing is left of the period it is skipped. The pass compares
        each period only with the one processed before it, so the input is
        presumably already sorted by start date -- confirm with callers.

        Returns:
            ``(rounded_periods, total_months)``.
        """
        rounded_periods = []
        total_months = 0
        last_end = None
        for start, end, period_id in periods:
            if last_end is not None and start <= last_end:
                start = last_end + _ONE_DAY
            if start > end:
                continue
            months = DateCalculator._months_spanned(start, end)
            # NOTE(review): relativedelta clamps to the last day of the target
            # month, so for starts on day 29-31 the rounded end can land
            # before the original end (Jan 31 -> Feb 28 rounds to Feb 27).
            rounded_end = start + relativedelta(months=months) - _ONE_DAY
            rounded_periods.append((start, rounded_end, period_id))
            total_months += months
            last_end = rounded_end
        return rounded_periods, total_months

    @staticmethod
    def adjust_periods(periods: List[_Period]) -> List[_Period]:
        """Adjust overlapping periods without merging.

        - A period overlapping the previously kept one has its start moved to
          the previous end + 1 day.
        - A period fully contained in the previously kept one is discarded.

        Each period is compared only with the last kept period, so the input
        is presumably sorted by start date -- confirm with callers.
        """
        adjusted = []
        for start, end, period_id in periods:
            if adjusted and start <= adjusted[-1][1]:
                last_end = adjusted[-1][1]
                if end <= last_end:
                    # Fully contained in the previous period: discard.
                    continue
                # Overlaps the tail of the previous period: start the day after.
                start = last_end + _ONE_DAY
                if start > end:
                    # Nothing left after the shift: discard.
                    continue
            adjusted.append((start, end, period_id))
        return adjusted

    @staticmethod
    def find_non_overlapping_periods(
        existing_periods: List[_Period], test_period: _Period
    ) -> List[_Period]:
        """Return the parts of ``test_period`` not covered by ``existing_periods``.

        The scan stops at the first existing period that starts after the
        test period ends, so ``existing_periods`` is presumably sorted and
        non-overlapping -- confirm with callers.

        Every returned piece carries the test period's id. A test period that
        is fully covered yields an empty list; an inverted remainder
        (start after end) is never emitted.
        """
        test_start, test_end, period_id = test_period
        non_overlapping_periods = []
        for start, end, _ in existing_periods:
            if test_start > test_end:
                # The test period is already fully consumed. Without this
                # check, a later existing period would emit an inverted tuple.
                break
            if test_end < start:
                # The remainder lies wholly before this existing period; it is
                # emitted by the tail check below.
                break
            if test_start > end:
                continue
            if test_start < start:
                non_overlapping_periods.append(
                    (test_start, start - _ONE_DAY, period_id)
                )
            test_start = end + _ONE_DAY
        if test_start <= test_end:
            non_overlapping_periods.append((test_start, test_end, period_id))
        return non_overlapping_periods

    @staticmethod
    def filter_valid_periods(periods: List[_Period]) -> List[_Period]:
        """Filter out periods where the start date is after the end date."""
        return [
            (start, end, period_id)
            for start, end, period_id in periods
            if start <= end
        ]

    @staticmethod
    def calculate_total_days(periods: List[_Period]) -> int:
        """Return the total number of days across all periods, bounds inclusive."""
        return sum((end - start).days + 1 for start, end, _ in periods)

    @staticmethod
    def calculate_total_months(periods: List[_Period]) -> int:
        """Return the sum of the per-period month counts.

        Uses the same counting rule as ``round_periods``; overlaps between
        periods are not removed here.
        """
        return sum(
            DateCalculator._months_spanned(start, end)
            for start, end, _ in periods
        )

    @staticmethod
    def add_months_to_date(date: datetime, months: int) -> datetime:
        """Return ``date`` shifted by ``months`` calendar months."""
        return date + relativedelta(months=months)

    @staticmethod
    def add_days_to_date(date: datetime, days: int) -> datetime:
        """Return ``date`` shifted by ``days`` days."""
        return date + timedelta(days=days)

    @staticmethod
    def min_date(date1: datetime, date2: datetime) -> datetime:
        """Return the earlier of two dates."""
        return min(date1, date2)