| """ |
| Flask server for real-time 3D motion generation demo (HF Space version) |
| """ |
|
|
| import argparse |
| import threading |
| import time |
|
|
| from flask import Flask, jsonify, render_template, request |
| from flask_cors import CORS |
| from model_manager import get_model_manager |
|
|
|
|
| def _coerce_value(value, reference): |
| """Coerce a value to match the type of a reference value""" |
| if isinstance(reference, bool): |
| return value if isinstance(value, bool) else str(value).lower() in ("true", "1") |
| elif isinstance(reference, int): |
| return int(value) |
| elif isinstance(reference, float): |
| return float(value) |
| return str(value) |
|
|
|
|
| app = Flask(__name__) |
| CORS(app) |
|
|
| |
| model_manager = None |
| model_name_global = None |
|
|
| |
| active_session_id = None |
| session_lock = threading.Lock() |
|
|
| |
| last_frame_consumed_time = None |
| consumption_timeout = ( |
| 5.0 |
| ) |
| consumption_monitor_thread = None |
| consumption_monitor_lock = threading.Lock() |
|
|
|
|
| def init_model(): |
| """Initialize model manager""" |
| global model_manager |
| if model_manager is None: |
| if model_name_global is None: |
| raise RuntimeError( |
| "model_name_global not set. Server not properly initialized." |
| ) |
| print(f"Initializing model manager with model: {model_name_global}") |
| model_manager = get_model_manager(model_name=model_name_global) |
| print("Model manager ready!") |
| return model_manager |
|
|
|
|
| def consumption_monitor(): |
| """Monitor frame consumption and auto-reset if client stops consuming""" |
| global last_frame_consumed_time, active_session_id, model_manager |
|
|
| while True: |
| time.sleep(2.0) |
|
|
| |
| should_reset = False |
| current_session = None |
| time_since_last_consumption = 0 |
|
|
| |
| with consumption_monitor_lock: |
| if last_frame_consumed_time is not None: |
| time_since_last_consumption = time.time() - last_frame_consumed_time |
| if time_since_last_consumption > consumption_timeout: |
| |
| if model_manager and model_manager.is_generating: |
| should_reset = True |
|
|
| |
| if should_reset: |
| with session_lock: |
| current_session = active_session_id |
|
|
| |
| if should_reset and current_session is not None: |
| print( |
| f"No frame consumed for {time_since_last_consumption:.1f}s - client disconnected, auto-resetting..." |
| ) |
|
|
| if model_manager: |
| model_manager.reset() |
| print( |
| "Generation reset due to client disconnect (no frame consumption)" |
| ) |
|
|
| |
| with session_lock: |
| if active_session_id == current_session: |
| active_session_id = None |
|
|
| with consumption_monitor_lock: |
| last_frame_consumed_time = None |
|
|
|
|
| def start_consumption_monitor(): |
| """Start the consumption monitoring thread if not already running""" |
| global consumption_monitor_thread |
|
|
| if consumption_monitor_thread is None or not consumption_monitor_thread.is_alive(): |
| consumption_monitor_thread = threading.Thread( |
| target=consumption_monitor, daemon=True |
| ) |
| consumption_monitor_thread.start() |
| print("Consumption monitor started") |
|
|
|
|
| @app.route("/") |
| def index(): |
| """Main page""" |
| return render_template("index.html") |
|
|
|
|
| @app.route("/api/config", methods=["GET"]) |
| def get_config(): |
| """Get current config""" |
| try: |
| if model_manager: |
| status = model_manager.get_buffer_status() |
| return jsonify( |
| { |
| "schedule_config": status["schedule_config"], |
| "cfg_config": status["cfg_config"], |
| "history_length": status["history_length"], |
| "smoothing_alpha": float(status["smoothing_alpha"]), |
| } |
| ) |
| else: |
| |
| return jsonify( |
| { |
| "schedule_config": {}, |
| "cfg_config": {}, |
| "history_length": 30, |
| "smoothing_alpha": 0.5, |
| } |
| ) |
| except Exception as e: |
| import traceback |
|
|
| traceback.print_exc() |
| return jsonify({"status": "error", "message": str(e)}), 500 |
|
|
|
|
| @app.route("/api/config", methods=["POST"]) |
| def update_config(): |
| """Update model config in memory""" |
| try: |
| global active_session_id, last_frame_consumed_time |
|
|
| if not model_manager or not model_manager.model: |
| return jsonify({"status": "error", "message": "Model not loaded yet"}), 400 |
|
|
| data = request.json |
| new_schedule_config = data.get("schedule_config") |
| new_cfg_config = data.get("cfg_config") |
| history_length = data.get("history_length") |
| smoothing_alpha = data.get("smoothing_alpha") |
|
|
| valid_schedule_keys = set(model_manager._base_schedule_config.keys()) |
| valid_cfg_keys = set(model_manager._base_cfg_config.keys()) |
|
|
| |
| if new_schedule_config: |
| for key in new_schedule_config: |
| if key not in valid_schedule_keys: |
| return jsonify( |
| { |
| "status": "error", |
| "message": f"Unknown schedule_config key: {key}", |
| } |
| ), 400 |
| for key, value in new_schedule_config.items(): |
| model_manager._base_schedule_config[key] = _coerce_value( |
| value, model_manager._base_schedule_config[key] |
| ) |
|
|
| |
| if new_cfg_config: |
| for key in new_cfg_config: |
| if key not in valid_cfg_keys: |
| return jsonify( |
| {"status": "error", "message": f"Unknown cfg_config key: {key}"} |
| ), 400 |
| for key, value in new_cfg_config.items(): |
| model_manager._base_cfg_config[key] = _coerce_value( |
| value, model_manager._base_cfg_config[key] |
| ) |
|
|
| |
| model_manager.reset( |
| history_length=history_length, |
| smoothing_alpha=smoothing_alpha, |
| ) |
|
|
| |
| with session_lock: |
| active_session_id = None |
| with consumption_monitor_lock: |
| last_frame_consumed_time = None |
|
|
| return jsonify({"status": "success"}) |
| except Exception as e: |
| import traceback |
|
|
| traceback.print_exc() |
| return jsonify({"status": "error", "message": str(e)}), 500 |
|
|
|
|
| @app.route("/api/start", methods=["POST"]) |
| def start_generation(): |
| """Start generation with given text""" |
| try: |
| global active_session_id, last_frame_consumed_time |
|
|
| data = request.json |
| session_id = data.get("session_id") |
| text = data.get("text", "walk in a circle.") |
| history_length = data.get("history_length") |
| smoothing_alpha = data.get( |
| "smoothing_alpha", None |
| ) |
| force = data.get("force", False) |
|
|
| if not session_id: |
| return jsonify( |
| {"status": "error", "message": "session_id is required"} |
| ), 400 |
|
|
| print( |
| f"[Session {session_id}] Starting generation with text: {text}, history_length: {history_length}, force: {force}" |
| ) |
|
|
| |
| mm = init_model() |
|
|
| |
| need_force_takeover = False |
|
|
| with session_lock: |
| if active_session_id and active_session_id != session_id: |
| if not force: |
| |
| return jsonify( |
| { |
| "status": "error", |
| "message": "Another session is already generating.", |
| "conflict": True, |
| "active_session_id": active_session_id, |
| } |
| ), 409 |
| else: |
| |
| print( |
| f"[Session {session_id}] Force takeover from session {active_session_id}" |
| ) |
| need_force_takeover = True |
|
|
| if mm.is_generating and active_session_id == session_id: |
| return jsonify( |
| { |
| "status": "error", |
| "message": "Generation is already running for this session.", |
| } |
| ), 400 |
|
|
| |
| active_session_id = session_id |
|
|
| |
| if need_force_takeover: |
| with consumption_monitor_lock: |
| last_frame_consumed_time = None |
|
|
| |
| mm.reset(history_length=history_length, smoothing_alpha=smoothing_alpha) |
| mm.start_generation(text, history_length=history_length) |
|
|
| |
| with consumption_monitor_lock: |
| last_frame_consumed_time = time.time() |
|
|
| |
| start_consumption_monitor() |
| print(f"[Session {session_id}] Consumption monitoring activated") |
|
|
| return jsonify( |
| { |
| "status": "success", |
| "message": f"Generation started with text: {text}, history_length: {history_length}", |
| "session_id": session_id, |
| } |
| ) |
| except Exception as e: |
| print(f"Error in start_generation: {e}") |
| import traceback |
|
|
| traceback.print_exc() |
| return jsonify({"status": "error", "message": str(e)}), 500 |
|
|
|
|
| @app.route("/api/update_text", methods=["POST"]) |
| def update_text(): |
| """Update the generation text""" |
| try: |
| data = request.json |
| session_id = data.get("session_id") |
| text = data.get("text", "") |
|
|
| if not session_id: |
| return jsonify( |
| {"status": "error", "message": "session_id is required"} |
| ), 400 |
|
|
| |
| with session_lock: |
| if active_session_id != session_id: |
| return jsonify( |
| {"status": "error", "message": "Not the active session"} |
| ), 403 |
|
|
| if model_manager is None: |
| return jsonify({"status": "error", "message": "Model not initialized"}), 400 |
|
|
| model_manager.update_text(text) |
|
|
| return jsonify({"status": "success", "message": f"Text updated to: {text}"}) |
| except Exception as e: |
| return jsonify({"status": "error", "message": str(e)}), 500 |
|
|
|
|
| @app.route("/api/pause", methods=["POST"]) |
| def pause_generation(): |
| """Pause generation (keeps state for resume)""" |
| try: |
| data = request.json if request.json else {} |
| session_id = data.get("session_id") |
|
|
| if not session_id: |
| return jsonify( |
| {"status": "error", "message": "session_id is required"} |
| ), 400 |
|
|
| |
| with session_lock: |
| if active_session_id != session_id: |
| return jsonify( |
| {"status": "error", "message": "Not the active session"} |
| ), 403 |
|
|
| if model_manager: |
| model_manager.pause_generation() |
|
|
| return jsonify({"status": "success", "message": "Generation paused"}) |
| except Exception as e: |
| return jsonify({"status": "error", "message": str(e)}), 500 |
|
|
|
|
| @app.route("/api/resume", methods=["POST"]) |
| def resume_generation(): |
| """Resume generation from paused state""" |
| try: |
| global last_frame_consumed_time |
|
|
| data = request.json if request.json else {} |
| session_id = data.get("session_id") |
|
|
| if not session_id: |
| return jsonify( |
| {"status": "error", "message": "session_id is required"} |
| ), 400 |
|
|
| |
| with session_lock: |
| if active_session_id != session_id: |
| return jsonify( |
| {"status": "error", "message": "Not the active session"} |
| ), 403 |
|
|
| if model_manager is None: |
| return jsonify({"status": "error", "message": "Model not initialized"}), 400 |
|
|
| model_manager.resume_generation() |
|
|
| |
| with consumption_monitor_lock: |
| last_frame_consumed_time = time.time() |
|
|
| return jsonify({"status": "success", "message": "Generation resumed"}) |
| except Exception as e: |
| return jsonify({"status": "error", "message": str(e)}), 500 |
|
|
|
|
| @app.route("/api/reset", methods=["POST"]) |
| def reset(): |
| """Reset generation state""" |
| try: |
| global active_session_id, last_frame_consumed_time |
|
|
| data = request.json if request.json else {} |
| session_id = data.get("session_id") |
| history_length = data.get("history_length") |
| smoothing_alpha = data.get("smoothing_alpha") |
|
|
| |
| if session_id: |
| with session_lock: |
| if active_session_id and active_session_id != session_id: |
| return jsonify( |
| {"status": "error", "message": "Not the active session"} |
| ), 403 |
|
|
| if model_manager: |
| model_manager.reset( |
| history_length=history_length, smoothing_alpha=smoothing_alpha |
| ) |
|
|
| |
| with session_lock: |
| if active_session_id == session_id or not session_id: |
| active_session_id = None |
|
|
| |
| with consumption_monitor_lock: |
| last_frame_consumed_time = None |
|
|
| print(f"[Session {session_id}] Reset complete, session cleared") |
|
|
| return jsonify( |
| { |
| "status": "success", |
| "message": "Reset complete", |
| } |
| ) |
| except Exception as e: |
| return jsonify({"status": "error", "message": str(e)}), 500 |
|
|
|
|
| @app.route("/api/get_frame", methods=["GET"]) |
| def get_frame(): |
| """Get the next frame""" |
| try: |
| global last_frame_consumed_time |
|
|
| session_id = request.args.get("session_id") |
|
|
| if not session_id: |
| return jsonify( |
| {"status": "error", "message": "session_id is required"} |
| ), 400 |
|
|
| if model_manager is None: |
| return jsonify({"status": "error", "message": "Model not initialized"}), 400 |
|
|
| count = min(int(request.args.get("count", 8)), 20) |
|
|
| |
| with session_lock: |
| is_active = active_session_id == session_id |
|
|
| if is_active: |
| |
| frames = [] |
| for _ in range(count): |
| joints = model_manager.get_next_frame() |
| if joints is None: |
| break |
| frames.append(joints.tolist()) |
|
|
| if frames: |
| with consumption_monitor_lock: |
| last_frame_consumed_time = time.time() |
|
|
| return jsonify( |
| { |
| "status": "success", |
| "frames": frames, |
| "buffer_size": model_manager.frame_buffer.size(), |
| } |
| ) |
| else: |
| |
| after_id = int(request.args.get("after_id", 0)) |
| broadcast = model_manager.get_broadcast_frames(after_id, count) |
| if broadcast: |
| last_id = broadcast[-1][0] |
| frames = [joints.tolist() for _, joints in broadcast] |
| return jsonify( |
| { |
| "status": "success", |
| "frames": frames, |
| "last_id": last_id, |
| "buffer_size": model_manager.frame_buffer.size(), |
| } |
| ) |
|
|
| |
| return jsonify( |
| { |
| "status": "waiting", |
| "message": "No frame available yet", |
| "buffer_size": model_manager.frame_buffer.size(), |
| } |
| ) |
| except Exception as e: |
| print(f"Error in get_frame: {e}") |
| import traceback |
|
|
| traceback.print_exc() |
| return jsonify({"status": "error", "message": str(e)}), 500 |
|
|
|
|
| @app.route("/api/status", methods=["GET"]) |
| def get_status(): |
| """Get generation status""" |
| try: |
| session_id = request.args.get("session_id") |
|
|
| with session_lock: |
| is_active_session = session_id and active_session_id == session_id |
| current_active_session = active_session_id |
|
|
| if model_manager is None: |
| return jsonify( |
| { |
| "initialized": False, |
| "buffer_size": 0, |
| "is_generating": False, |
| "is_active_session": is_active_session, |
| "active_session_id": current_active_session, |
| } |
| ) |
|
|
| status = model_manager.get_buffer_status() |
| status["initialized"] = True |
| status["is_active_session"] = is_active_session |
| status["active_session_id"] = current_active_session |
|
|
| return jsonify(status) |
| except Exception as e: |
| return jsonify({"status": "error", "message": str(e)}), 500 |
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser( |
| description="Flask server for real-time 3D motion generation" |
| ) |
| parser.add_argument( |
| "--model_name", |
| type=str, |
| default="ShandaAI/FloodDiffusionTiny", |
| help="HF Hub model name (default: ShandaAI/FloodDiffusionTiny)", |
| ) |
| parser.add_argument( |
| "--port", |
| type=int, |
| default=7860, |
| help="Port to run the server on (default: 7860)", |
| ) |
| args = parser.parse_args() |
|
|
| model_name_global = args.model_name |
|
|
| |
| print(f"Loading model: {model_name_global}") |
| init_model() |
|
|
| print("Starting Flask server...") |
| app.run(host="0.0.0.0", port=args.port, debug=False, threaded=True) |
|
|