Update app.py
app.py CHANGED
@@ -21,7 +21,10 @@ EMBEDDING_DIM = 50
 IMAGE_SIZE = 160
 BATCH_SIZE = 64
 
-def load_and_preprocess_data(subset_size=10000):
+# Store model examples
+model_examples = {}
+
+def load_and_preprocess_data(subset_size=20000):
     # Load dataset
     dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
     dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
@@ -29,13 +32,17 @@ def load_and_preprocess_data(subset_size=10000):
     # Filter out NSFW content
     dataset_subset = dataset_subset.filter(lambda x: not x['nsfw'])
 
+    # Store an example image URL for each model
+    for item in dataset_subset:
+        if item['Model'] not in model_examples:
+            model_examples[item['Model']] = item['url']
+
     return dataset_subset
 
 def process_text_data(dataset_subset):
-    # Combine prompt and negative prompt
-    text_data = [ …
+    # Use a constant placeholder prompt; user text no longer feeds the model
+    text_data = ["default prompt" for _ in dataset_subset]
 
-    # Tokenize text
     tokenizer = Tokenizer(num_words=10000)
     tokenizer.fit_on_texts(text_data)
     sequences = tokenizer.texts_to_sequences(text_data)
@@ -43,6 +50,14 @@ def process_text_data(dataset_subset):
 
     return text_data_padded, tokenizer
 
+def download_image(url):
+    try:
+        response = requests.get(url, stream=True, timeout=5)
+        response.raise_for_status()
+        return Image.open(response.raw)
+    except (requests.RequestException, OSError):
+        return None
+
 def process_image_data(dataset_subset):
     image_dir = 'civitai_images'
     os.makedirs(image_dir, exist_ok=True)
@@ -55,7 +70,6 @@ def process_image_data(dataset_subset):
         img_path = os.path.join(image_dir, os.path.basename(img_url))
 
         try:
-            # Download and save image
             response = requests.get(img_url, timeout=5)
             response.raise_for_status()
 
@@ -65,7 +79,6 @@ def process_image_data(dataset_subset):
             with open(img_path, 'wb') as f:
                 f.write(response.content)
 
-            # Load and preprocess image
             img = image.load_img(img_path, target_size=(IMAGE_SIZE, IMAGE_SIZE))
             img_array = image.img_to_array(img)
             img_array = preprocess_input(img_array)
@@ -79,26 +92,21 @@ def process_image_data(dataset_subset):
     return np.array(image_data), valid_indices
 
 def create_multimodal_model(num_words, num_classes):
-    # Image input branch (CNN)
     image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
     cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
 
-    # Freeze most of the ResNet50 layers
     for layer in cnn_base.layers[:-10]:
         layer.trainable = False
 
     cnn_features = cnn_base(image_input)
 
-    # Text input branch
     text_input = Input(shape=(MAX_TEXT_LENGTH,))
     embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
     flatten_text = Flatten()(embedding_layer)
     text_features = Dense(128, activation='relu')(flatten_text)
 
-    # Combine features
     combined = Concatenate()([cnn_features, text_features])
 
-    # Simplified fully connected layers
     x = Dense(256, activation='relu')(combined)
     output = Dense(num_classes, activation='softmax')(x)
 
@@ -106,24 +114,18 @@ def create_multimodal_model(num_words, num_classes):
     return model
 
 def train_model():
-    # Load and preprocess data
     dataset_subset = load_and_preprocess_data()
 
-    # Process text data
     text_data_padded, tokenizer = process_text_data(dataset_subset)
 
-    # Process image data
     image_data, valid_indices = process_image_data(dataset_subset)
 
-    # Get valid text data and labels
     text_data_padded = text_data_padded[valid_indices]
     model_names = [dataset_subset[i]['Model'] for i in valid_indices]
 
-    # Encode labels
     label_encoder = LabelEncoder()
     encoded_labels = label_encoder.fit_transform(model_names)
 
-    # Create and compile model
     model = create_multimodal_model(
         num_words=10000,
         num_classes=len(label_encoder.classes_)
@@ -135,7 +137,6 @@ def train_model():
         metrics=['accuracy']
     )
 
-    # Train model
     history = model.fit(
         [image_data, text_data_padded],
         encoded_labels,
@@ -144,68 +145,75 @@ def train_model():
         validation_split=0.2
     )
 
-
-    model.save('multimodal_model.keras')  # Changed from 'multimodal_model'
+    model.save('multimodal_model.keras')
     joblib.dump(tokenizer, 'tokenizer.pkl')
     joblib.dump(label_encoder, 'label_encoder.pkl')
 
+    # Save model examples
+    joblib.dump(model_examples, 'model_examples.pkl')
+
     return model, tokenizer, label_encoder
 
-def get_recommendations(image_input, text_input, model, tokenizer, label_encoder, top_k=5):
-    # Preprocess image
+def get_recommendations(image_input, model, tokenizer, label_encoder, top_k=5):
     img_array = image.img_to_array(image_input)
     img_array = tf.image.resize(img_array, (IMAGE_SIZE, IMAGE_SIZE))
     img_array = preprocess_input(img_array)
     img_array = np.expand_dims(img_array, axis=0)
 
-    # Preprocess text
-    text_sequence = tokenizer.texts_to_sequences([text_input])
+    # Use default text input
+    text_sequence = tokenizer.texts_to_sequences(["default prompt"])
     text_padded = pad_sequences(text_sequence, maxlen=MAX_TEXT_LENGTH)
 
-    # Get predictions
     predictions = model.predict([img_array, text_padded])
     top_indices = np.argsort(predictions[0])[-top_k:][::-1]
 
-    …
-    …
-    …
-    …
-    …
+    recommendations = []
+    for idx in top_indices:
+        model_name = label_encoder.inverse_transform([idx])[0]
+        confidence = predictions[0][idx]
+        if model_name in model_examples:
+            example_image = download_image(model_examples[model_name])
+            if example_image:
+                recommendations.append((model_name, confidence, example_image))
 
     return recommendations
 
 def create_gradio_interface():
-
-    model = tf.keras.models.load_model('multimodal_model.keras')  # Changed from 'multimodal_model'
+    model = tf.keras.models.load_model('multimodal_model.keras')
     tokenizer = joblib.load('tokenizer.pkl')
     label_encoder = joblib.load('label_encoder.pkl')
+    model_examples.update(joblib.load('model_examples.pkl'))
+
+    def predict(img):
+        recommendations = get_recommendations(img, model, tokenizer, label_encoder)
+        result_text = ""
+        result_images = []
+
+        for model_name, conf, example_img in recommendations:
+            result_text += f"Model: {model_name}, Confidence: {conf:.2f}\n"
+            result_images.append(example_img)
+
+        return [result_text] + result_images + [None] * (5 - len(result_images))
 
-    def predict(img, text):
-        recommendations = get_recommendations(img, text, model, tokenizer, label_encoder)
-        return "\n".join([f"Model: {name}, Confidence: {conf:.2f}" for name, conf in recommendations])
+    outputs = [gr.Textbox(label="Recommended Models")] + [gr.Image(label=f"Example {i+1}") for i in range(5)]
 
     interface = gr.Interface(
         fn=predict,
-        inputs=[
-            …
-            …
-        ],
-        outputs=gr.Textbox(label="Recommended Models"),
-        title="Multimodal Model Recommendation System",
-        description="Upload an image and enter a prompt to get model recommendations"
+        inputs=gr.Image(type="pil", label="Upload Image"),
+        outputs=outputs,
+        title="AI Model Recommendation System",
+        description="Upload an image to get model recommendations with examples"
    )
 
     return interface
 
 if __name__ == "__main__":
-
-    if not os.path.exists('multimodal_model.keras'):  # Changed from 'multimodal_model'
+    if not os.path.exists('multimodal_model.keras'):
         print("Training new model...")
         model, tokenizer, label_encoder = train_model()
         print("Training completed!")
     else:
         print("Loading existing model...")
 
-    # Launch Gradio interface
     interface = create_gradio_interface()
     interface.launch()
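The retrained pipeline can be sanity-checked offline once the commit's artifacts exist. The sketch below is illustrative rather than part of the commit: it assumes train_model() has already written multimodal_model.keras, tokenizer.pkl, and label_encoder.pkl, that MAX_TEXT_LENGTH is 100 (its definition sits outside these hunks, as does the `from PIL import Image` that download_image relies on), and that test.jpg stands in for any local image. It mirrors the preprocessing path of get_recommendations().

# Offline sanity check (illustrative; not part of the commit).
import joblib
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.applications.resnet50 import preprocess_input

IMAGE_SIZE = 160        # matches IMAGE_SIZE in app.py
MAX_TEXT_LENGTH = 100   # assumption: the diff does not show this constant

model = tf.keras.models.load_model('multimodal_model.keras')
tokenizer = joblib.load('tokenizer.pkl')
label_encoder = joblib.load('label_encoder.pkl')

# Image branch: same preprocessing as get_recommendations()
img = image.load_img('test.jpg', target_size=(IMAGE_SIZE, IMAGE_SIZE))
img_array = np.expand_dims(preprocess_input(image.img_to_array(img)), axis=0)

# Text branch: the constant placeholder this commit hard-codes
text_padded = pad_sequences(tokenizer.texts_to_sequences(["default prompt"]),
                            maxlen=MAX_TEXT_LENGTH)

preds = model.predict([img_array, text_padded])[0]
for idx in np.argsort(preds)[-5:][::-1]:
    print(label_encoder.inverse_transform([idx])[0], f"{preds[idx]:.2f}")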
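A side effect of the hard-coded "default prompt" in both process_text_data() and get_recommendations() is that the text branch receives an identical padded vector on every request, so the ranking depends on the uploaded image alone. A minimal standalone demonstration (same MAX_TEXT_LENGTH assumption as the sketch above):

# Shows that the hard-coded prompt makes the text input constant.
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

MAX_TEXT_LENGTH = 100  # assumption; must match app.py

tokenizer = Tokenizer(num_words=10000)
tokenizer.fit_on_texts(["default prompt"])  # the only text the commit trains on

vec1 = pad_sequences(tokenizer.texts_to_sequences(["default prompt"]),
                     maxlen=MAX_TEXT_LENGTH)
vec2 = pad_sequences(tokenizer.texts_to_sequences(["default prompt"]),
                     maxlen=MAX_TEXT_LENGTH)
print(np.array_equal(vec1, vec2))  # True: the model's text input never varies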