# mdztxi2 / app.py
# Last commit e826106 ("Update app.py", verified) by Geek7 - file size 4.15 kB
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import os
from huggingface_hub import InferenceClient
from io import BytesIO
from PIL import Image
# Flask application object, with CORS enabled on every route.
app = Flask(__name__)
CORS(app)

# Shared Hugging Face inference client. The token comes from the HF_TOKEN
# environment variable; when that variable is unset, None is passed through.
HF_TOKEN = os.getenv("HF_TOKEN")
client = InferenceClient(token=HF_TOKEN)
# Default negative prompt used whenever a request does not supply one;
# it steers the model away from common hand/finger artifacts.
NEGATIVE_PROMPT_FINGERS = (
    "missing fingers, extra fingers, elongated fingers, fused fingers,\n"
    "mutated fingers, poorly drawn fingers, disfigured fingers,\n"
    "too many fingers, deformed hands, extra hands, malformed hands,\n"
    "blurry hands, disproportionate fingers"
)
@app.route('/')
def home():
    """Landing route: return a plain-text greeting for the service root."""
    greeting = "Welcome to the AI Image Generator with NSFW Detection!"
    return greeting
# Function for NSFW detection
def _prediction_field(prediction, name):
    """Read *name* from a classifier prediction given either as a dict or as an object."""
    if isinstance(prediction, dict):
        return prediction.get(name)
    return getattr(prediction, name, None)


def is_nsfw_image(image, *, threshold=0.5, fail_closed=True):
    """Return True if *image* is classified as NSFW.

    The image is PNG-encoded in memory and sent to the
    ``Falconsai/nsfw_image_detection`` classifier via the module-level
    Hugging Face ``client``.

    Args:
        image: Image object exposing ``save(fp, format=...)`` (e.g. a PIL image).
        threshold: A prediction labelled ``nsfw`` (case-insensitive) must
            score strictly above this value to flag the image. Defaults to 0.5.
        fail_closed: Value returned when classification fails (network error,
            unexpected response, encoding error, ...). Defaults to True so
            that a broken filter never lets content through unchecked.

    Returns:
        bool: True if flagged, or if detection failed and ``fail_closed`` is
        True; otherwise False.
    """
    try:
        buffer = BytesIO()
        image.save(buffer, format='PNG')
        # Pass the payload positionally: the parameter is named ``image``,
        # not ``inputs``. A wrong keyword raises TypeError, which the except
        # below would turn into a silent "not NSFW" result.
        predictions = client.image_classification(
            buffer.getvalue(),
            model="Falconsai/nsfw_image_detection",
        )
        for prediction in predictions or []:
            label = _prediction_field(prediction, 'label')
            score = _prediction_field(prediction, 'score')
            if str(label).lower() == 'nsfw' and score is not None and float(score) > threshold:
                return True
        return False
    except Exception as e:
        # Safety gate: on failure, report the configured fallback instead of
        # silently treating the image as safe.
        print(f"NSFW detection error: {e}")
        return fail_closed
# Function to generate an image
def generate_image(
    prompt,
    negative_prompt=None,
    height=512,
    width=512,
    model="stabilityai/stable-diffusion-2-1",
    num_inference_steps=50,
    guidance_scale=7.5,
    seed=None,
):
    """Render *prompt* through the Hugging Face text-to-image inference API.

    A falsy ``negative_prompt`` (None or empty) is replaced by the module's
    ``NEGATIVE_PROMPT_FINGERS`` default. All other arguments are forwarded
    unchanged to ``client.text_to_image``.

    Returns:
        Whatever ``client.text_to_image`` returns, or None if the call raised
        (the error is printed rather than propagated).
    """
    try:
        options = dict(
            negative_prompt=negative_prompt if negative_prompt else NEGATIVE_PROMPT_FINGERS,
            height=height,
            width=width,
            model=model,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return client.text_to_image(prompt=prompt, **options)
    except Exception as err:
        print(f"Error generating image: {err}")
        return None
# Flask route for image generation API
def _read_number(data, key, default, cast, *, positive=False):
    """Return ``data[key]`` converted with *cast*, using *default* when absent or null.

    Args:
        data: The decoded JSON request body (a dict).
        key: Field name to read.
        default: Value used when *key* is missing or JSON ``null``.
        cast: Conversion callable, ``int`` or ``float``.
        positive: If True, reject values that are not greater than zero.

    Raises:
        ValueError: With a client-facing message when the value is not a
            number (booleans included) or violates ``positive``.
    """
    value = data.get(key)
    if value is None:
        value = default
    # bool is a subclass of int, so JSON ``true`` would otherwise become 1.
    if isinstance(value, bool):
        raise ValueError(f"'{key}' must be a number")
    try:
        value = cast(value)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(f"'{key}' must be a number") from None
    if positive and value <= 0:
        raise ValueError(f"'{key}' must be greater than 0")
    return value


def _png_response(image):
    """Encode *image* as PNG in memory and return it as an inline file response."""
    buffer = BytesIO()
    image.save(buffer, format='PNG')
    buffer.seek(0)
    return send_file(
        buffer,
        mimetype='image/png',
        as_attachment=False,
        download_name='generated_image.png'
    )


@app.route('/generate_image', methods=['POST'])
def generate_api():
    """Generate an image from a JSON request and return it.

    Request body (JSON object): ``prompt`` (required, non-blank string),
    ``negative_prompt``, ``height``, ``width``, ``num_inference_steps``,
    ``guidance_scale``, ``model``, ``seed``.

    Responses:
        * PNG of the generated image.
        * The ``nsfw.jpg`` placeholder (JPEG) when the result is flagged.
        * 400 with a JSON ``error`` for a malformed body or invalid fields.
        * 500 with a JSON ``error`` when generation fails or an unexpected
          exception occurs.
    """
    # silent=True makes a missing or malformed JSON body yield None instead of
    # raising, so it gets the same JSON 400 as other bad input rather than an
    # HTML error page or an AttributeError on ``data.get``.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    prompt = data.get('prompt', '')
    if not isinstance(prompt, str):
        return jsonify({"error": "Prompt must be a string"}), 400
    if not prompt.strip():
        return jsonify({"error": "Prompt is required"}), 400

    negative_prompt = data.get('negative_prompt', None)
    # NOTE(review): the caller may choose any model id, which then runs under
    # this server's HF_TOKEN; consider an allow-list.
    model_name = data.get('model', 'stabilityai/stable-diffusion-2-1')
    try:
        height = _read_number(data, 'height', 512, int, positive=True)
        width = _read_number(data, 'width', 512, int, positive=True)
        num_inference_steps = _read_number(data, 'num_inference_steps', 50, int, positive=True)
        guidance_scale = _read_number(data, 'guidance_scale', 7.5, float)
        seed = data.get('seed', None)
        if seed is not None:
            seed = _read_number(data, 'seed', None, int)
    except ValueError as e:
        return jsonify({"error": str(e)}), 400

    try:
        image = generate_image(prompt, negative_prompt, height, width, model_name, num_inference_steps, guidance_scale, seed)
        # generate_image signals failure by returning None.
        if image is None:
            return jsonify({"error": "Failed to generate image"}), 500
        if is_nsfw_image(image):
            return send_file(
                "nsfw.jpg",  # Path to the predefined NSFW placeholder image
                mimetype='image/jpeg',
                as_attachment=False,
                download_name='nsfw.jpg'
            )
        return _png_response(image)
    except Exception as e:
        print(f"Error in generate_api: {e}")
        return jsonify({"error": str(e)}), 500
# Entry point when the module is executed directly: serve on all
# interfaces, port 7860.
if __name__ == '__main__':
    server_options = {"host": '0.0.0.0', "port": 7860}
    app.run(**server_options)