🧠

TensorFlow.jsで宿泊需要予測モデルをWebアプリに組み込む

に公開

はじめに

宿泊需要予測は収益管理の要です。Python で学習したモデルをTensorFlow.jsに変換し、ブラウザ上でリアルタイム予測を実現する方法を解説します。過去の予約データから繁忙期を予測し、動的価格設定に活用します。

モデルアーキテクチャと特徴量設計

// models/demand-predictor.js
class DemandPredictor {
  constructor() {
    this.model = null;
    this.scalers = null;
    this.featureColumns = [
      'day_of_week',      // 0-6
      'month',            // 1-12
      'day_of_month',     // 1-31
      'is_weekend',       // 0/1
      'is_holiday',       // 0/1
      'days_until_event', // 大型イベントまでの日数
      'temperature',      // 気温予報
      'lead_time',        // 予約リードタイム
      'competitor_rate',  // 競合価格指数
      'past_7d_avg',      // 過去7日平均稼働率
      'past_30d_avg',     // 過去30日平均稼働率
      'same_day_last_year' // 前年同日稼働率
    ];
  }

  async loadModel() {
    try {
      // モデルとスケーラーの読み込み
      this.model = await tf.loadLayersModel('/models/demand_forecast/model.json');
      
      // 正規化パラメータ読み込み
      const scalerResponse = await fetch('/models/demand_forecast/scalers.json');
      this.scalers = await scalerResponse.json();
      
      console.log('Model loaded successfully');
      return true;
    } catch (error) {
      console.error('Failed to load model:', error);
      return false;
    }
  }

  preprocessFeatures(rawData) {
    // 特徴量エンジニアリング
    const features = {
      day_of_week: new Date(rawData.date).getDay(),
      month: new Date(rawData.date).getMonth() + 1,
      day_of_month: new Date(rawData.date).getDate(),
      is_weekend: [0, 6].includes(new Date(rawData.date).getDay()) ? 1 : 0,
      is_holiday: this.checkHoliday(rawData.date) ? 1 : 0,
      days_until_event: this.getDaysUntilNextEvent(rawData.date),
      temperature: rawData.weather?.temperature || this.getHistoricalTemp(rawData.date),
      lead_time: this.calculateLeadTime(rawData.date),
      competitor_rate: rawData.competitorRate || 100,
      past_7d_avg: rawData.historical?.week_avg || 0.75,
      past_30d_avg: rawData.historical?.month_avg || 0.70,
      same_day_last_year: rawData.historical?.last_year || 0.72
    };

    // 正規化
    const normalizedFeatures = this.featureColumns.map(col => {
      const value = features[col];
      const scaler = this.scalers[col];
      return (value - scaler.mean) / scaler.std;
    });

    return normalizedFeatures;
  }

  async predict(dateRange, roomType = 'all') {
    if (!this.model) {
      await this.loadModel();
    }

    const predictions = [];
    
    for (let date of dateRange) {
      // 日付ごとの特徴量準備
      const rawData = await this.fetchHistoricalData(date, roomType);
      const features = this.preprocessFeatures(rawData);
      
      // テンソル変換と予測
      const inputTensor = tf.tensor2d([features]);
      const prediction = await this.model.predict(inputTensor).data();
      
      // 後処理
      const result = {
        date: date,
        predicted_occupancy: Math.min(1, Math.max(0, prediction[0])),
        confidence_interval: {
          lower: Math.max(0, prediction[0] - prediction[1] * 1.96),
          upper: Math.min(1, prediction[0] + prediction[1] * 1.96)
        },
        demand_level: this.categorizeDemand(prediction[0]),
        recommended_rate_multiplier: this.calculateRateMultiplier(prediction[0])
      };
      
      predictions.push(result);
      inputTensor.dispose();
    }

    return predictions;
  }

  categorizeDemand(occupancyRate) {
    if (occupancyRate > 0.9) return 'very_high';
    if (occupancyRate > 0.8) return 'high';
    if (occupancyRate > 0.6) return 'moderate';
    if (occupancyRate > 0.4) return 'low';
    return 'very_low';
  }

  calculateRateMultiplier(predictedOccupancy) {
    // 需要に基づく価格調整係数
    const baseMultiplier = 1.0;
    
    if (predictedOccupancy > 0.95) return baseMultiplier * 1.5;
    if (predictedOccupancy > 0.90) return baseMultiplier * 1.3;
    if (predictedOccupancy > 0.80) return baseMultiplier * 1.15;
    if (predictedOccupancy > 0.70) return baseMultiplier * 1.05;
    if (predictedOccupancy < 0.40) return baseMultiplier * 0.85;
    if (predictedOccupancy < 0.30) return baseMultiplier * 0.75;
    
    return baseMultiplier;
  }
}

Python モデルの変換とデプロイ

# convert_model.py
import tensorflowjs as tfjs
import tensorflow as tf
import json
import numpy as np

# 学習済みモデルの読み込み
model = tf.keras.models.load_model('demand_forecast_model.h5')

# TensorFlow.js形式に変換
tfjs.converters.save_keras_model(
    model,
    'public/models/demand_forecast',
    quantization_dtype=np.uint8  # モデルサイズ削減
)

# スケーラーパラメータの保存
scalers = {
    'day_of_week': {'mean': 3.0, 'std': 2.0},
    'month': {'mean': 6.5, 'std': 3.5},
    'temperature': {'mean': 20.0, 'std': 8.0},
    # ... 他の特徴量
}

with open('public/models/demand_forecast/scalers.json', 'w') as f:
    json.dump(scalers, f)

リアルタイム予測UIの実装

// components/demand-forecast.js
import { DemandPredictor } from '../models/demand-predictor.js';
import Chart from 'chart.js/auto';

class DemandForecastDashboard {
  constructor() {
    this.predictor = new DemandPredictor();
    this.chart = null;
    this.initialize();
  }

  async initialize() {
    await this.predictor.loadModel();
    this.setupEventListeners();
    this.runInitialForecast();
  }

  async runInitialForecast() {
    const startDate = new Date();
    const dateRange = this.generateDateRange(startDate, 30); // 30日予測
    
    this.showLoadingState();
    
    try {
      const predictions = await this.predictor.predict(dateRange);
      this.updateChart(predictions);
      this.updateMetrics(predictions);
      this.generateRecommendations(predictions);
    } catch (error) {
      console.error('Prediction failed:', error);
      this.showError('予測の実行に失敗しました');
    }
  }

  updateChart(predictions) {
    const ctx = document.getElementById('forecastChart').getContext('2d');
    
    if (this.chart) {
      this.chart.destroy();
    }

    this.chart = new Chart(ctx, {
      type: 'line',
      data: {
        labels: predictions.map(p => p.date),
        datasets: [
          {
            label: '予測稼働率',
            data: predictions.map(p => p.predicted_occupancy * 100),
            borderColor: 'rgb(75, 192, 192)',
            backgroundColor: 'rgba(75, 192, 192, 0.2)',
            tension: 0.4
          },
          {
            label: '信頼区間上限',
            data: predictions.map(p => p.confidence_interval.upper * 100),
            borderColor: 'rgba(75, 192, 192, 0.3)',
            borderDash: [5, 5],
            fill: false
          },
          {
            label: '信頼区間下限',
            data: predictions.map(p => p.confidence_interval.lower * 100),
            borderColor: 'rgba(75, 192, 192, 0.3)',
            borderDash: [5, 5],
            fill: false
          }
        ]
      },
      options: {
        responsive: true,
        plugins: {
          tooltip: {
            callbacks: {
              afterLabel: (context) => {
                const prediction = predictions[context.dataIndex];
                return `需要レベル: ${prediction.demand_level}\n推奨価格調整: ${(prediction.recommended_rate_multiplier * 100).toFixed(0)}%`;
              }
            }
          }
        },
        scales: {
          y: {
            beginAtZero: true,
            max: 100,
            ticks: {
              callback: (value) => value + '%'
            }
          }
        }
      }
    });
  }

  generateRecommendations(predictions) {
    const highDemandDates = predictions.filter(p => p.demand_level === 'very_high');
    const lowDemandDates = predictions.filter(p => p.demand_level === 'low' || p.demand_level === 'very_low');
    
    const recommendations = [];
    
    if (highDemandDates.length > 0) {
      recommendations.push({
        type: 'high_demand',
        message: `${highDemandDates[0].date}から${highDemandDates.length}日間は高需要が予測されます。最低宿泊日数の設定を検討してください。`,
        action: 'set_minimum_stay'
      });
    }
    
    if (lowDemandDates.length > 0) {
      recommendations.push({
        type: 'low_demand',
        message: `${lowDemandDates[0].date}は需要が低い予測です。プロモーション実施を推奨します。`,
        action: 'create_promotion'
      });
    }
    
    this.displayRecommendations(recommendations);
  }

  async exportPredictions(format = 'csv') {
    const predictions = await this.predictor.predict(this.generateDateRange(new Date(), 90));
    
    if (format === 'csv') {
      const csv = this.convertToCSV(predictions);
      this.downloadFile(csv, 'demand_forecast.csv', 'text/csv');
    } else if (format === 'json') {
      const json = JSON.stringify(predictions, null, 2);
      this.downloadFile(json, 'demand_forecast.json', 'application/json');
    }
  }
}

// 初期化
document.addEventListener('DOMContentLoaded', () => {
  new DemandForecastDashboard();
});

まとめ

TensorFlow.jsを使用することで、サーバーへの負荷なくブラウザ上で需要予測が実行できます。リアルタイムな特徴量更新と即座の予測により、収益管理担当者が機動的な価格戦略を立てられるようになります。

次回は、このシステムと連携する決済処理の実装について解説します。

Discussion