diff options
Diffstat (limited to 'prediction_controller.py')
| -rw-r--r-- | prediction_controller.py | 249 |
1 files changed, 215 insertions, 34 deletions
diff --git a/prediction_controller.py b/prediction_controller.py index c3be8cc..916a325 100644 --- a/prediction_controller.py +++ b/prediction_controller.py @@ -1,48 +1,229 @@ # prediction_controller.py from dateutil.relativedelta import relativedelta from datetime import datetime, timedelta +from typing import List, Tuple, Dict, Optional import math +from calendar_manager import CalendarManager +from date_calculator import DateCalculator +from event_type_handler import EventTypeHandler +from config import EventConfig + class PredictionController: - def __init__(self, calendar_manager, date_calculator, keyword_list): + """Handles prediction calculation with proper business logic orchestration""" + + def __init__(self, calendar_manager: CalendarManager, date_calculator: DateCalculator): self.calendar_manager = calendar_manager self.date_calculator = date_calculator - self.launch_date = None - self.duration = None - self.prediction = None - self.keyword_list = keyword_list - for keyword in keyword_list: - self.keyword = [] - - def set_parameters(self, launch_date, duration_years): - self.launch_date = datetime.fromisoformat(launch_date) - self.duration = relativedelta(years=duration_years) - - def make_prediction(self, launch_date, duration_years): - self.set_parameters(launch_date, duration_years) - - prediction = self.launch_date + self.duration - timedelta(days=1) - - keyword_args = {} - - for entry in self.calendar_manager.entries: - for keyword in self.keyword_list: - if entry.keyword == keyword: - if keyword not in keyword_args: - keyword_args[keyword] = [] - keyword_args[keyword].append((entry.start_date, entry.end_date, entry.id)) - break - # print(keyword_args) - prediction, corrected_events = self.date_calculator.calculate_prediction(self.launch_date, self.duration, **keyword_args) - self.prediction = prediction - self.calendar_manager.correct_dates(corrected_events) - - def get_launch_date(self): + self.launch_date: Optional[datetime] = None + self.duration: Optional[relativedelta] = None + self.prediction: Optional[datetime] = None + self.corrected_events: List[Tuple[datetime, datetime, str]] = [] + + def set_parameters(self, launch_date: str, duration_years: int) -> bool: + """Set launch date and duration with validation""" + try: + self.launch_date = datetime.fromisoformat(launch_date) + self.duration = relativedelta(years=duration_years) + + # Validate parameters + if duration_years <= 0: + raise ValueError("Duration must be positive") + if self.launch_date < datetime(1900, 1, 1): + raise ValueError("Launch date too far in the past") + if self.launch_date > datetime(2100, 1, 1): + raise ValueError("Launch date too far in the future") + + return True + except ValueError as e: + print(f"Error setting parameters: {e}") + return False + + def make_prediction(self, launch_date: str, duration_years: int) -> bool: + """Calculate prediction with proper business logic orchestration""" + if not self.set_parameters(launch_date, duration_years): + return False + + try: + # Calculate base prediction (launch + duration - 1 day) + prediction_start = self.launch_date + self.duration - timedelta(days=1) + + # Categorize events by type + categorized_events = EventTypeHandler.categorize_events(self.calendar_manager.entries) + + # Process events according to business rules + processed_events = self._process_events_by_type(categorized_events) + + # Calculate final prediction + self.prediction = self._calculate_final_prediction(prediction_start, processed_events) + + # Apply corrections to calendar entries + self._apply_corrections_to_calendar() + + return True + + except Exception as e: + print(f"Error calculating prediction: {e}") + return False + + def _process_events_by_type(self, categorized_events: Dict[str, List[Tuple[datetime, datetime, str]]]) -> Dict[str, List[Tuple[datetime, datetime, str]]]: + """Process events according to business rules""" + processed = {} + + # Process full-time projects first (EZ 100% and EZ pauschal) + full_projects = [] + for event_type in ["EZ 100%", "EZ pauschal"]: + if event_type in categorized_events: + full_projects.extend(categorized_events[event_type]) + + if full_projects: + processed["full_projects"] = self._process_full_projects(full_projects) + + # Process half-time projects (EZ 50%) + if "EZ 50%" in categorized_events: + processed["half_projects"] = self._process_half_projects( + categorized_events["EZ 50%"], + processed.get("full_projects", []) + ) + + # Process other events (Sonstige) + if "Sonstige" in categorized_events: + processed["other_events"] = self._process_other_events( + categorized_events["Sonstige"], + processed.get("full_projects", []), + processed.get("half_projects", []) + ) + + return processed + + def _process_full_projects(self, full_projects: List[Tuple[datetime, datetime, str]]) -> List[Tuple[datetime, datetime, str]]: + """Process full-time projects""" + # Sort and truncate periods + sorted_projects = DateCalculator.sort_periods(full_projects) + considered_projects = DateCalculator.truncate_periods(sorted_projects, self.launch_date) + + # Round to month boundaries + rounded_projects, total_months = DateCalculator.round_periods(considered_projects) + + return rounded_projects + + def _process_half_projects(self, half_projects: List[Tuple[datetime, datetime, str]], + full_projects: List[Tuple[datetime, datetime, str]]) -> List[Tuple[datetime, datetime, str]]: + """Process half-time projects""" + # Sort and truncate periods + sorted_projects = DateCalculator.sort_periods(half_projects) + considered_projects = DateCalculator.truncate_periods(sorted_projects, self.launch_date) + + # Find non-overlapping periods with full projects + non_overlapping_projects = [] + for test_interval in considered_projects: + non_overlapping_projects.extend( + DateCalculator.find_non_overlapping_periods(full_projects, test_interval) + ) + + # Round to month boundaries + rounded_projects, total_months = DateCalculator.round_periods(non_overlapping_projects) + + return rounded_projects + + def _process_other_events(self, other_events: List[Tuple[datetime, datetime, str]], + full_projects: List[Tuple[datetime, datetime, str]], + half_projects: List[Tuple[datetime, datetime, str]]) -> List[Tuple[datetime, datetime, str]]: + """Process other events (Sonstige)""" + # Sort and truncate periods + sorted_events = DateCalculator.sort_periods(other_events) + considered_events = DateCalculator.truncate_periods(sorted_events, self.launch_date) + + # Adjust overlapping periods + adjusted_events = DateCalculator.adjust_periods(considered_events) + + # Find non-overlapping periods with all projects + all_projects = DateCalculator.sort_periods(full_projects + half_projects) + non_overlapping_events = [] + for test_interval in adjusted_events: + non_overlapping_events.extend( + DateCalculator.find_non_overlapping_periods(all_projects, test_interval) + ) + + return non_overlapping_events + + def _calculate_final_prediction(self, prediction_start: datetime, + processed_events: Dict[str, List[Tuple[datetime, datetime, str]]]) -> datetime: + """Calculate the final prediction date""" + # Calculate months from projects + total_months = 0 + + # Full projects count as full months + if "full_projects" in processed_events: + total_months += DateCalculator.calculate_total_months(processed_events["full_projects"]) + + # Half projects count as half months + if "half_projects" in processed_events: + half_months = DateCalculator.calculate_total_months(processed_events["half_projects"]) + total_months += math.ceil(half_months / 2) + + # Other events count as days + total_days = 0 + if "other_events" in processed_events: + total_days = DateCalculator.calculate_total_days(processed_events["other_events"]) + + # Calculate final prediction + final_prediction = (prediction_start + + relativedelta(months=total_months) + + timedelta(days=total_days)) + + # Apply maximum limit + max_prediction = prediction_start + EventConfig.get_max_prediction_duration() + final_prediction = DateCalculator.min_date(final_prediction, max_prediction) + + return final_prediction + + def _apply_corrections_to_calendar(self): + """Apply corrected dates to calendar entries""" + # Collect all corrected events from processed results + all_corrected_events = [] + + # Get processed events from the prediction calculation + categorized_events = EventTypeHandler.categorize_events(self.calendar_manager.entries) + processed_events = self._process_events_by_type(categorized_events) + + # Collect corrected events from all categories + for event_type, events in processed_events.items(): + all_corrected_events.extend(events) + + # Apply corrections to calendar entries + self.calendar_manager.correct_dates(all_corrected_events) + + def get_launch_date(self) -> Optional[datetime]: + """Get the launch date""" return self.launch_date - def get_duration(self): + def get_duration(self) -> Optional[relativedelta]: + """Get the duration""" return self.duration - def get_prediction(self): + def get_prediction(self) -> Optional[datetime]: + """Get the prediction""" return self.prediction + + def validate_prediction_inputs(self, launch_date: str, duration_years: int) -> List[str]: + """Validate prediction inputs and return list of errors""" + errors = [] + + try: + launch_dt = datetime.fromisoformat(launch_date) + except ValueError: + errors.append("Invalid launch date format") + return errors + + if duration_years <= 0: + errors.append("Duration must be positive") + + if launch_dt < datetime(1900, 1, 1): + errors.append("Launch date too far in the past") + + if launch_dt > datetime(2100, 1, 1): + errors.append("Launch date too far in the future") + + return errors |
