Tags: python, callback, dropdown, plotly-dash

How do I use a different set of callback inputs depending on the drop-down option selected in Dash?


I am building an image-processing app in Dash. The user uploads an image, selects a smoothing filter, and then chooses between k-means clustering and thresholding to separate the foreground object from the background.

As shown below, I have a callback that uses a filters drop-down (`filters-dropdown`). This works fine because both filters take the same inputs, in this case `kernel-height-slider` and `kernel-width-slider`.

    import numpy as np
    import cv2 as cv
    from dash import Dash, dcc, html
    from dash.dependencies import Input, Output
    import dash_daq as daq
    import io
    import os
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from PIL import Image
    import base64
    import plotly.express as px
    import plotly.graph_objects as go
    
    external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

    # suppress_callback_exceptions skips Dash's check that every component id
    # referenced by a callback is present in the initial layout.
    app = Dash(__name__, external_stylesheets=external_stylesheets, suppress_callback_exceptions=True)

    # Resolve the sample image next to this script instead of against the current
    # working directory, so launching the app from another directory still finds it
    # (cv.imread does not raise on a bad path; it silently returns None).
    # Falls back to the working directory where __file__ is undefined (e.g. notebooks).
    _app_dir = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
    img_path = os.path.join(_app_dir, "assets", "lyco_WOL_00140.tif")
    
    def preprocessing(img_path):
        """Load an image from disk and return its luma (Y) channel.

        The image is read with OpenCV (BGR channel order), converted to YCrCb,
        and channel 0 (Y, the brightness component) is returned.

        Parameters
        ----------
        img_path : str
            Path of the image file to read.

        Returns
        -------
        numpy.ndarray
            2-D array holding the Y channel of the image.

        Raises
        ------
        FileNotFoundError
            If OpenCV cannot read ``img_path`` (missing file or unsupported format).
        """
        img = cv.imread(img_path)
        if img is None:
            # cv.imread signals failure by returning None rather than raising;
            # without this check the failure surfaces as an opaque cvtColor assertion.
            raise FileNotFoundError(f"Could not read image: {img_path!r}")
        img_ycc = cv.cvtColor(img, cv.COLOR_BGR2YCR_CB)
        return img_ycc[:, :, 0]
    
    preprocessed_img = preprocessing(img_path)

    # Smoothing filters, keyed by the 'filters-dropdown' option values. Each one
    # takes (image, kernel_height, kernel_width) and returns the smoothed image.
    filters = {
        # cv.GaussianBlur expects ksize as (width, height) with odd sizes; the
        # kernel sliders in the layout (min 1, step 2) only produce odd values.
        "gaussian": lambda img, kh, kw: cv.GaussianBlur(img, (kw, kh), 0),
        # Normalised box filter: dividing by the kernel area keeps overall brightness
        # unchanged for every kernel size (a constant /25 is only correct for 5x5).
        "2dconv": lambda img, kh, kw: cv.filter2D(img, -1, np.ones((kh, kw), np.float32) / (kh * kw)),
    }
    
    def otsu_thresholding(img, min_thr, max_thr):
        """Binarise ``img`` with Otsu's method.

        Returns the ``(threshold, binary_image)`` pair produced by ``cv.threshold``.
        Because the THRESH_OTSU flag is set, OpenCV computes the threshold itself
        and ignores ``min_thr``; ``max_thr`` is the value written to pixels above
        the computed threshold.
        """
        # THRESH_BINARY is 0, so this flag alone means THRESH_BINARY | THRESH_OTSU.
        mode = cv.THRESH_OTSU
        computed_thr, binary = cv.threshold(img, min_thr, max_thr, mode)
        return computed_thr, binary
    
    def kmeans_clustering(img, K, n_iter, accuracy):
        """Quantise ``img`` into ``K`` levels with k-means clustering.

        Every pixel is one sample whose features are its channel values. Each
        pixel is replaced by the centre of the cluster it was assigned to.

        Parameters
        ----------
        img : numpy.ndarray
            2-D (single-channel) or 3-D (H x W x C) image.
        K : int
            Number of clusters.
        n_iter : int
            Maximum iterations per k-means run. It is also passed as the number of
            attempts (restarts from different random centres), as before.
        accuracy : float
            Epsilon of the termination criterion.

        Returns
        -------
        numpy.ndarray
            uint8 array with the same shape as ``img``.
        """
        channels = img.shape[2] if img.ndim == 3 else 1
        # One row per pixel. The former reshape((-1, 2)) paired adjacent grayscale
        # pixels into 2-D samples and failed outright for odd pixel counts.
        samples = np.float32(img.reshape(-1, channels))
        # Values from numeric dcc.Input fields are not guaranteed to be ints, but
        # cv.kmeans requires int K / iteration counts and a float epsilon.
        n_iter = int(n_iter)
        criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, n_iter, float(accuracy))
        _, labels, centers = cv.kmeans(samples, int(K), None, criteria, n_iter, cv.KMEANS_RANDOM_CENTERS)
        quantised = np.uint8(centers)[labels.flatten()]
        return quantised.reshape(img.shape)
    
    # Maps each 'detectors-radio' option value to its segmentation function.
    detectors = dict(
        otsu=otsu_thresholding,
        kmeans=kmeans_clustering,
    )
    
    def _edge_bounding_box(edge_mask):
        """Return ``(min_row, max_row, min_col, max_col)`` of the True pixels in
        ``edge_mask``, or ``None`` when there are none (e.g. a uniform image)."""
        rows, cols = np.nonzero(edge_mask)
        if rows.size == 0:
            return None
        return rows.min(), rows.max(), cols.min(), cols.max()


    def _figure_to_data_uri(fig):
        """Encode a Matplotlib figure as a base64 PNG ``data:`` URI usable as ``html.Img.src``."""
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        encoded = base64.b64encode(buf.getvalue()).decode()
        return f"data:image/png;base64,{encoded}"


    @app.callback(
        Output('smoothed-img', 'src'),
        [Input('filters-dropdown', 'value'),
         Input('kernel-height-slider', 'value'),
         Input('kernel-width-slider', 'value'),
         Input('detectors-radio', 'value'),
         Input('otsu-min-threshold', 'value'),
         Input('otsu-max-threshold', 'value'),
         Input('kmeans-kvalue-input', 'value'),
         Input('kmeans-niter-input', 'value'),
         Input('kmeans-accuracy-input', 'value'),
         ]
    )
    def update_overlay_base64(selected_filter, kernel_height, kernel_width,
                              selected_detector, otsu_min_thr, otsu_max_thr,
                              kmeans_k, kmeans_niter, kmeans_accuracy):
        """Smooth, segment and outline the preprocessed image and return it as a PNG data URI.

        Pipeline: apply the selected smoothing filter, run the selected detector
        ('otsu' or 'kmeans'), detect edges with Canny, then draw the segmented
        image with the edges and their bounding box overlaid.

        Returns
        -------
        str or None
            ``data:image/png;base64,...`` for the ``src`` of ``smoothed-img``, or
            ``None`` (no image) when no filter is selected, a kernel size is
            missing, the detector is unknown, or one of its parameters is empty.
        """
        # Build the figure with the object-oriented API instead of pyplot, so no
        # global pyplot state is shared if the server runs callbacks on several threads.
        from matplotlib.figure import Figure

        if not selected_filter or kernel_height is None or kernel_width is None:
            return None

        filtered_img = filters[selected_filter](preprocessed_img, kernel_height, kernel_width)

        if selected_detector == "otsu" and None not in (otsu_min_thr, otsu_max_thr):
            # otsu_thresholding returns (threshold, image); keep the binary image.
            processed_img = otsu_thresholding(filtered_img, otsu_min_thr, otsu_max_thr)[1]
        elif selected_detector == "kmeans" and None not in (kmeans_k, kmeans_niter, kmeans_accuracy):
            processed_img = kmeans_clustering(filtered_img, kmeans_k, kmeans_niter, kmeans_accuracy)
        else:
            # Nothing to draw. Previously this fell through with processed_img
            # unbound and raised UnboundLocalError.
            return None

        edges = cv.Canny(processed_img, 10, 150)
        # Non-edge pixels become NaN, which imshow treats as invalid and leaves
        # unpainted, so only the edges are drawn over the segmented image.
        edge_overlay = edges.astype('float64')
        edge_overlay[edges == 0] = np.nan

        fig = Figure()
        ax = fig.subplots()
        ax.imshow(processed_img, cmap='Spectral_r')
        ax.imshow(edge_overlay, cmap='autumn', alpha=0.5)

        # Guard against "no edges found": min()/max() of an empty array raises ValueError.
        bbox = _edge_bounding_box(edges == 255)
        if bbox is not None:
            min_row, max_row, min_col, max_col = bbox
            # Rectangle with a one-pixel margin around the outermost edge pixels.
            ax.add_patch(plt.Rectangle((min_col - 1, min_row - 1),
                                       max_col - min_col + 2,
                                       max_row - min_row + 2,
                                       edgecolor='cyan',
                                       facecolor='none',
                                       linewidth=1))
        ax.axis('off')

        return _figure_to_data_uri(fig)
    
    # pil_img = Image.fromarray(edge_otsu_img)
    # buffer = io.BytesIO()
    # pil_img.save(buffer, format="PNG")
    # encoded_image = base64.b64encode(buffer.getvalue()).decode()
    
    
    
    
    # Dash app layout
    # Both detector parameter groups are always part of the layout; the callback
    # below only toggles their visibility. update_overlay_base64 listens to all five
    # parameter inputs, and Dash raises "A nonexistent object was used in an Input
    # of a Dash callback" if any of them is missing from the page.
    app.layout = html.Div(children=[
        html.H1('Test Flame Analyser v0.3'),
        html.Div(children=[
            dcc.Dropdown(
                id='filters-dropdown',
                options=[
                    {'label': 'Gaussian', 'value': 'gaussian'},
                    {'label': '2D Convolution', 'value': '2dconv'},
                ],
            ),
            html.Img(id='smoothed-img', style={"maxWidth": "100%"}),
            daq.Slider(
                id='kernel-height-slider',
                min=1, max=25, value=5,
                marks={'5': '5', '15': '15', '25': '25'},
                handleLabel={"showCurrentValue": True, "label": "Height"},
                step=2,
            ),
            daq.Slider(
                id='kernel-width-slider',
                min=1, max=25, value=5,
                marks={'5': '5', '15': '15', '25': '25'},
                handleLabel={"showCurrentValue": True, "label": "Width"},
                step=2,
            ),
            dcc.RadioItems(
                id='detectors-radio',
                options=[
                    {'label': 'Otsu Thresholding', 'value': 'otsu'},
                    {'label': 'K-means Clustering', 'value': 'kmeans'},
                ],
                value='otsu',
            ),
            html.Div(id='detector-params', children=[
                html.Div(id='otsu-params', children=[
                    dcc.Input(id='otsu-min-threshold', type='number', value=0, placeholder="Min Threshold"),
                    dcc.Input(id='otsu-max-threshold', type='number', value=255, placeholder="Max Threshold"),
                ]),
                # Starts hidden to match the default detector ('otsu') above.
                html.Div(id='kmeans-params', hidden=True, children=[
                    dcc.Input(id='kmeans-kvalue-input', type='number', value=2, placeholder="K (# clusters)"),
                    dcc.Input(id='kmeans-niter-input', type='number', value=10, placeholder="Max Iterations"),
                    dcc.Input(id='kmeans-accuracy-input', type='number', value=1.0, placeholder="Accuracy (epsilon)", step=0.01),
                ]),
            ]),
            html.Div(id='output'),
        ]),
    ])


    @app.callback(
        [Output('otsu-params', 'hidden'),
         Output('kmeans-params', 'hidden')],
        [Input('detectors-radio', 'value')]
    )
    def render_detector_params(detector_name):
        """Show only the parameter group that belongs to the selected detector.

        Returns ``(otsu_hidden, kmeans_hidden)``. Any value other than 'otsu' or
        'kmeans' hides both groups.
        """
        return detector_name != 'otsu', detector_name != 'kmeans'
    
    
    
    # def update_figure(img):
    # fig = px.imshow(img)
    # fig.show()
    #     return fig
    
    
    if __name__ == "__main__":
        # Launch the Dash development server in debug mode.
        app.run(
            debug=True,
        )

However, the detector is chosen with the `detectors-radio` radio buttons, and the inputs to the image-processing function depend on which option is selected. If Otsu is selected, they are `otsu-min-threshold` and `otsu-max-threshold`. If k-means is selected, they are `kmeans-kvalue-input`, `kmeans-niter-input` and `kmeans-accuracy-input`.

The issue is that the callback needs all of its Inputs to exist, and one set of inputs is inevitably missing whenever the other option is selected. As a result I get the error “A nonexistent object was used in an Input of a Dash callback”. Is there a way to work around this? Can I make a callback conditional, or do something similar?

I found a related thread on the Plotly community forum that I am about to try, but I'm not sure it will work.

Thank you!


Solution

  • You need to keep all the components in the layout. That is, instead of adding and removing the otsu/kmeans parameters, include all of them in the initial layout and wrap each group of parameters in its own div:

    html.Div(id='detector-params', children=[
        html.Div(id='otsu-params', children=[
            dcc.Input(id='otsu-min-threshold', type='number', value=0, placeholder="Min Threshold"),
            dcc.Input(id='otsu-max-threshold', type='number', value=255, placeholder="Max Threshold"),
        ]),
        # Start hidden to match the default detector ('otsu'); otherwise both groups
        # are visible until the show/hide callback has run once.
        html.Div(id='kmeans-params', hidden=True, children=[
            dcc.Input(id='kmeans-kvalue-input', type='number', value=2, placeholder="K (# clusters)"),
            dcc.Input(id='kmeans-niter-input', type='number', value=10, placeholder="Max Iterations"),
            dcc.Input(id='kmeans-accuracy-input', type='number', value=1.0, placeholder="Accuracy (epsilon)", step=0.01),
        ])
    ])
    

    Then, in the render_detector_params callback, simply show or hide each group using its `hidden` property, based on the value of the detectors-radio input:

    @app.callback(
        Output('otsu-params', 'hidden'),
        Output('kmeans-params', 'hidden'),
        Input('detectors-radio', 'value'),
        prevent_initial_call=False  # also run on page load to set the initial visibility
    )
    def render_detector_params(detector_name):
        """Show only the parameter group that belongs to the selected detector.

        Returns ``(otsu_hidden, kmeans_hidden)``. Any value other than 'otsu' or
        'kmeans' (e.g. ``None``) hides both groups. Returning ``None`` instead,
        as the if/elif version did, is rejected by Dash for a two-output callback.
        """
        return detector_name != 'otsu', detector_name != 'kmeans'