#!/usr/bin/env python3
"""
Main execution script for Dynamic Traffic Signal Optimization using
reinforcement learning (RL).

Supports four modes selected with --mode: train, test, evaluate, benchmark.

M.Tech Project Implementation
"""
|
|
import argparse
import contextlib
import logging
import os
import sys
from datetime import datetime
from typing import Optional

import numpy as np
import torch
import yaml

# Add src to path
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

from src.training.trainer import TrafficTrainer  # noqa: E402
from src.environment.traffic_environment import AdvancedTrafficEnv  # noqa: E402
from src.agents.advanced_dqn_agent import AdvancedDQNAgent  # noqa: E402
|
|
def setup_logging():
    """Configure root logging to write to ``main.log`` and to stdout.

    Both handlers share one format and the root level is INFO. This is done
    through ``logging.basicConfig``, which does nothing if the root logger
    already has handlers attached.
    """
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    file_handler = logging.FileHandler('main.log')
    console_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[file_handler, console_handler],
    )
|
|
def create_directories():
    """Ensure the project's working directory tree exists.

    Paths are relative to the current working directory. Existing
    directories are left untouched, so calling this repeatedly is safe.
    """
    # Top-level directory -> subdirectories created beneath it, in order.
    layout = {
        'models': ('checkpoints',),
        'data': (),
        'logs': ('tensorboard',),
        'results': ('plots', 'analysis'),
        'sumo_configs': (),
    }
    for parent, children in layout.items():
        os.makedirs(parent, exist_ok=True)
        for child in children:
            os.makedirs(f"{parent}/{child}", exist_ok=True)
|
|
def train_model(config_path: str, resume_checkpoint: str = None):
|
|
"""Train the RL model"""
|
|
print("="*80)
|
|
print("STARTING TRAINING MODE")
|
|
print("="*80)
|
|
|
|
# Initialize components
|
|
env = AdvancedTrafficEnv(config_path)
|
|
agent = AdvancedDQNAgent(config_path)
|
|
trainer = TrafficTrainer(config_path)
|
|
|
|
if resume_checkpoint:
|
|
print(f"Resuming training from checkpoint: {resume_checkpoint}")
|
|
agent.load(resume_checkpoint)
|
|
else:
|
|
print("Starting fresh training...")
|
|
|
|
# Start training
|
|
training_results = trainer.train(env, agent)
|
|
|
|
# Cleanup
|
|
env.close()
|
|
|
|
print("\n" + "="*80)
|
|
print("BASELINE BENCHMARKING COMPLETED")
|
|
print("="*80)
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(
|
|
description='Dynamic Traffic Signal Optimization using RL - M.Tech Project'
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--mode',
|
|
choices=['train', 'test', 'evaluate', 'benchmark'],
|
|
required=True,
|
|
help='Execution mode'
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--config',
|
|
type=str,
|
|
default='config/config.yaml',
|
|
help='Path to configuration file'
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--model',
|
|
type=str,
|
|
default='models/final_model.pth',
|
|
help='Path to model file (for test/evaluate modes)'
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--episodes',
|
|
type=int,
|
|
default=10,
|
|
help='Number of test episodes'
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--resume',
|
|
type=str,
|
|
default=None,
|
|
help='Path to checkpoint for resuming training'
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--gpu',
|
|
action='store_true',
|
|
help='Force GPU usage if available'
|
|
)
|
|
|
|
parser.add_argument(
|
|
'--debug',
|
|
action='store_true',
|
|
help='Enable debug logging'
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Setup logging
|
|
setup_logging()
|
|
|
|
if args.debug:
|
|
logging.getLogger().setLevel(logging.DEBUG)
|
|
|
|
# Create directories
|
|
create_directories()
|
|
|
|
# Check configuration file
|
|
if not os.path.exists(args.config):
|
|
print(f"Error: Configuration file not found at {args.config}")
|
|
print("Please create config/config.yaml or specify correct path with --config")
|
|
sys.exit(1)
|
|
|
|
# GPU setup
|
|
if args.gpu and torch.cuda.is_available():
|
|
print(f"Using GPU: {torch.cuda.get_device_name()}")
|
|
elif args.gpu:
|
|
print("GPU requested but not available, using CPU")
|
|
else:
|
|
print("Using CPU")
|
|
|
|
# Print system information
|
|
print(f"\nStarting execution at: {datetime.now()}")
|
|
print(f"Mode: {args.mode}")
|
|
print(f"Config: {args.config}")
|
|
if args.mode in ['test', 'evaluate']:
|
|
print(f"Model: {args.model}")
|
|
print(f"Python version: {sys.version}")
|
|
print(f"PyTorch version: {torch.__version__}")
|
|
|
|
# Execute based on mode
|
|
try:
|
|
if args.mode == 'train':
|
|
train_model(args.config, args.resume)
|
|
|
|
elif args.mode == 'test':
|
|
test_model(args.config, args.model, args.episodes)
|
|
|
|
elif args.mode == 'evaluate':
|
|
evaluate_model(args.config, args.model)
|
|
|
|
elif args.mode == 'benchmark':
|
|
benchmark_baselines(args.config)
|
|
|
|
except KeyboardInterrupt:
|
|
print("\n\nExecution interrupted by user")
|
|
sys.exit(0)
|
|
|
|
except Exception as e:
|
|
print(f"\nError during execution: {e}")
|
|
if args.debug:
|
|
import traceback
|
|
traceback.print_exc()
|
|
sys.exit(1)
|
|
|
|
print(f"\nExecution completed at: {datetime.now()}")
|
|
|
|
if __name__ == "__main__":
|
|
main()agent.close()
|
|
|
|
print("\n" + "="*80)
|
|
print("TRAINING COMPLETED")
|
|
print("="*80)
|
|
print(f"Total Episodes: {training_results['total_episodes']}")
|
|
print(f"Training Time: {training_results['total_training_time']:.2f} seconds")
|
|
print(f"Best Reward: {training_results['best_reward']:.2f}")
|
|
print(f"Best Eval Score: {training_results['best_eval_score']:.2f}")
|
|
print(f"Final Epsilon: {training_results.get('final_epsilon', 0):.4f}")
|
|
|
|
def test_model(config_path: str, model_path: str, episodes: int = 10):
|
|
"""Test a trained model"""
|
|
print("="*80)
|
|
print("STARTING TESTING MODE")
|
|
print("="*80)
|
|
|
|
env = AdvancedTrafficEnv(config_path)
|
|
agent = AdvancedDQNAgent(config_path)
|
|
|
|
if not os.path.exists(model_path):
|
|
print(f"Error: Model file not found at {model_path}")
|
|
return
|
|
|
|
agent.load(model_path)
|
|
print(f"Model loaded from: {model_path}")
|
|
|
|
total_rewards = []
|
|
episode_summaries = []
|
|
|
|
for episode in range(episodes):
|
|
print(f"\nTesting Episode {episode + 1}/{episodes}")
|
|
|
|
state = env.reset()
|
|
total_reward = 0
|
|
steps = 0
|
|
|
|
while True:
|
|
action = agent.act(state, training=False)
|
|
next_state, reward, done, info = env.step(action)
|
|
|
|
state = next_state
|
|
total_reward += reward
|
|
steps += 1
|
|
|
|
if done:
|
|
break
|
|
|
|
episode_summary = env.get_episode_summary()
|
|
episode_summary['total_reward'] = total_reward
|
|
episode_summary['steps'] = steps
|
|
|
|
total_rewards.append(total_reward)
|
|
episode_summaries.append(episode_summary)
|
|
|
|
print(f" Reward: {total_reward:.2f}")
|
|
print(f" Steps: {steps}")
|
|
print(f" Avg Delay: {episode_summary.get('average_delay', 0):.2f}s")
|
|
print(f" Throughput: {episode_summary.get('total_throughput', 0):.0f}")
|
|
|
|
env.close()
|
|
agent.close()
|
|
|
|
# Print summary statistics
|
|
print("\n" + "="*80)
|
|
print("TESTING RESULTS SUMMARY")
|
|
print("="*80)
|
|
print(f"Average Reward: {np.mean(total_rewards):.2f} ± {np.std(total_rewards):.2f}")
|
|
print(f"Average Delay: {np.mean([s.get('average_delay', 0) for s in episode_summaries]):.2f}s")
|
|
print(f"Average Throughput: {np.mean([s.get('total_throughput', 0) for s in episode_summaries]):.0f}")
|
|
print(f"Average Queue Length: {np.mean([s.get('average_queue_length', 0) for s in episode_summaries]):.2f}")
|
|
|
|
def evaluate_model(config_path: str, model_path: str):
|
|
"""Comprehensive model evaluation"""
|
|
print("="*80)
|
|
print("STARTING COMPREHENSIVE EVALUATION")
|
|
print("="*80)
|
|
|
|
env = AdvancedTrafficEnv(config_path)
|
|
agent = AdvancedDQNAgent(config_path)
|
|
|
|
if not os.path.exists(model_path):
|
|
print(f"Error: Model file not found at {model_path}")
|
|
return
|
|
|
|
agent.load(model_path)
|
|
print(f"Model loaded from: {model_path}")
|
|
|
|
# Run comprehensive evaluation
|
|
# This would call the evaluator component
|
|
print("Running comprehensive evaluation...")
|
|
|
|
env.close()
|
|
agent.close()
|
|
|
|
print("\n" + "="*80)
|
|
print("EVALUATION COMPLETED")
|
|
print("="*80)
|
|
print("Results saved to results/ directory")
|
|
|
|
def benchmark_baselines(config_path: str):
|
|
"""Benchmark baseline methods only"""
|
|
print("="*80)
|
|
print("BENCHMARKING BASELINE METHODS")
|
|
print("="*80)
|
|
|
|
env = AdvancedTrafficEnv(config_path)
|
|
|
|
with open(config_path, 'r') as f:
|
|
config = yaml.safe_load(f)
|
|
|
|
baseline_methods = config['evaluation']['baseline_methods']
|
|
baseline_results = {}
|
|
|
|
for baseline in baseline_methods:
|
|
print(f"\nEvaluating baseline: {baseline}")
|
|
# This would implement baseline evaluation logic
|
|
print(f" Results for {baseline}: [Implementation needed]")
|
|
|
|
env.close()
|