| | from flask import Flask, request, jsonify, send_file |
| | from flask_cors import CORS |
| | import os |
| | from huggingface_hub import InferenceClient |
| | from io import BytesIO |
| | from PIL import Image |
| |
|
| | |
# Flask application object; CORS(app) is called with no options, i.e. the
# flask-cors defaults apply to every route registered on this app.
app = Flask(import_name=__name__)
CORS(app)

# Shared Hugging Face Inference API client used by the helpers below.
# The token comes from the HF_TOKEN environment variable and is None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")
client = InferenceClient(token=HF_TOKEN)
| |
|
| | |
# Default negative prompt: generate_image() substitutes it whenever the caller's
# negative prompt is falsy (None or ""). One string, with the term groups
# separated by newlines.
NEGATIVE_PROMPT_FINGERS = (
    "missing fingers, extra fingers, elongated fingers, fused fingers,\n"
    "mutated fingers, poorly drawn fingers, disfigured fingers,\n"
    "too many fingers, deformed hands, extra hands, malformed hands,\n"
    "blurry hands, disproportionate fingers"
)
| |
|
@app.route('/')
def home():
    """Serve the root URL with a fixed plain-text greeting."""
    greeting = "Welcome to the AI Image Generator with NSFW Detection!"
    return greeting
| |
|
| | |
def is_nsfw_image(image, threshold=0.5):
    """Decide whether *image* must be withheld as NSFW.

    The image is encoded as PNG and classified with the
    ``Falconsai/nsfw_image_detection`` model through the shared ``client``.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method
            (presumably a ``PIL.Image.Image``).
        threshold: An ``nsfw`` label must score strictly above this value to
            count as a positive. The default 0.5 matches the previous behavior.

    Returns:
        True if the classifier reports the ``nsfw`` label with a score above
        *threshold*. Also True if encoding or classification fails for any
        reason: the check fails closed, so an unverified image is never
        served. False otherwise.
    """
    try:
        buffer = BytesIO()
        image.save(buffer, format='PNG')

        # The image goes in as the first positional argument:
        # InferenceClient.image_classification has no ``inputs`` keyword.
        result = client.image_classification(
            buffer.getvalue(), model="Falconsai/nsfw_image_detection"
        )

        for item in result:
            # Depending on the huggingface_hub version, output elements are
            # dict-like or plain objects; support both access styles.
            if isinstance(item, dict):
                label, score = item['label'], item['score']
            else:
                label, score = item.label, item.score
            if str(label).lower() == 'nsfw' and score > threshold:
                return True
        return False
    except Exception as e:
        print(f"NSFW detection error: {e}")
        # Fail closed: if moderation cannot run, treat the image as NSFW.
        return True
| |
|
| | |
def generate_image(prompt, negative_prompt=None, height=512, width=512,
                   model="stabilityai/stable-diffusion-2-1", num_inference_steps=50,
                   guidance_scale=7.5, seed=None):
    """Ask the Hugging Face Inference API to render *prompt*.

    Args:
        prompt: Text prompt forwarded to ``client.text_to_image``.
        negative_prompt: Forwarded as-is when truthy. Any falsy value
            (None, "") is replaced by NEGATIVE_PROMPT_FINGERS.
        height, width: Requested output size, forwarded unchanged.
        model: Hugging Face model id, forwarded unchanged.
        num_inference_steps, guidance_scale, seed: Generation options,
            forwarded unchanged.

    Returns:
        The value returned by ``client.text_to_image`` (the route handler
        treats it as an image with a ``save`` method). Returns None if the call
        raised; the error is printed instead of propagated.
    """
    try:
        request_options = {
            "prompt": prompt,
            "negative_prompt": negative_prompt if negative_prompt else NEGATIVE_PROMPT_FINGERS,
            "height": height,
            "width": width,
            "model": model,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "seed": seed,
        }
        return client.text_to_image(**request_options)
    except Exception as err:
        print(f"Error generating image: {err}")
        return None
| |
|
| | |
def _parse_number(data, key, default, cast, minimum=None):
    """Read ``data[key]`` as an int or float taken from an untrusted JSON body.

    Args:
        data: The decoded JSON object (a dict).
        key: Field name to read.
        default: Value used when *key* is absent.
        cast: ``int`` or ``float``.
        minimum: Optional inclusive lower bound.

    Returns:
        The converted number. If the field is absent and *default* is None,
        or the field is an explicit JSON ``null``, returns None so the value
        is forwarded as "unset", exactly as before.

    Raises:
        ValueError: With a client-facing message when the value is a
            boolean, not a number or numeric string, non-integral for an int
            field, not finite, or below *minimum*.
    """
    value = data.get(key, default)
    if value is None:
        return None

    message = f"'{key}' must be {'an integer' if cast is int else 'a finite number'}"
    # bool is a subclass of int, so reject JSON true/false explicitly.
    if isinstance(value, bool) or not isinstance(value, (int, float, str)):
        raise ValueError(message)
    # int(7.5) would silently truncate; refuse non-integral floats for int fields.
    if cast is int and isinstance(value, float) and not value.is_integer():
        raise ValueError(message)
    try:
        number = cast(value)
    except (ValueError, OverflowError):
        raise ValueError(message) from None
    # Rejects NaN (every comparison is False) and +/-infinity.
    if not (float("-inf") < number < float("inf")):
        raise ValueError(message)
    if minimum is not None and number < minimum:
        raise ValueError(f"'{key}' must be at least {minimum}")
    return number


@app.route('/generate_image', methods=['POST'])
def generate_api():
    """POST /generate_image: generate an image from a JSON request.

    Expected body: a JSON object with a required non-empty string ``prompt``
    and optional ``negative_prompt`` (string), ``height``/``width``/
    ``num_inference_steps`` (positive integers, defaults 512/512/50),
    ``guidance_scale`` (number >= 0, default 7.5), ``model`` (string, default
    ``stabilityai/stable-diffusion-2-1``) and ``seed`` (integer).

    Returns:
        The generated PNG, or the static ``nsfw.jpg`` placeholder when the NSFW
        check flags the image. On invalid input: a 400 JSON error. If
        generation returned nothing or an unexpected error occurred: a 500
        JSON error.
    """
    # silent=True: malformed or non-JSON bodies yield None instead of raising,
    # so every bad request gets the same JSON 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    prompt = data.get('prompt', '')
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    if not isinstance(prompt, str):
        return jsonify({"error": "'prompt' must be a string"}), 400

    negative_prompt = data.get('negative_prompt', None)
    if negative_prompt is not None and not isinstance(negative_prompt, str):
        return jsonify({"error": "'negative_prompt' must be a string"}), 400

    model_name = data.get('model', 'stabilityai/stable-diffusion-2-1')
    if model_name is not None and not isinstance(model_name, str):
        return jsonify({"error": "'model' must be a string"}), 400

    try:
        height = _parse_number(data, 'height', 512, int, minimum=1)
        width = _parse_number(data, 'width', 512, int, minimum=1)
        num_inference_steps = _parse_number(data, 'num_inference_steps', 50, int, minimum=1)
        guidance_scale = _parse_number(data, 'guidance_scale', 7.5, float, minimum=0)
        seed = _parse_number(data, 'seed', None, int)
    except ValueError as err:
        return jsonify({"error": str(err)}), 400

    try:
        image = generate_image(prompt, negative_prompt, height, width, model_name,
                               num_inference_steps, guidance_scale, seed)
        if image is None:
            return jsonify({"error": "Failed to generate image"}), 500

        if is_nsfw_image(image):
            return send_file(
                "nsfw.jpg",
                mimetype='image/jpeg',
                as_attachment=False,
                download_name='nsfw.jpg'
            )

        img_byte_arr = BytesIO()
        image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)  # send_file streams from the current position
        return send_file(
            img_byte_arr,
            mimetype='image/png',
            as_attachment=False,
            download_name='generated_image.png'
        )
    except Exception as e:
        # Log the detail server-side only; exception text can expose file
        # paths and upstream API errors to the client.
        print(f"Error in generate_api: {e}")
        return jsonify({"error": "Internal server error"}), 500
| |
|
| | |
if __name__ == '__main__':
    # Development server bound to all interfaces on port 7860.
    app.run(port=7860, host='0.0.0.0')