Skip to content

Commit 797621b

Browse files
aliabid94Ali Abidgradio-pr-bot
authored
Improved plot rendering (#8580)
* changes * add changeset * add changeset * add changeset * changes * changes * restore altair * changes * changes * changes * changes * changes * changes * Update twenty-jokes-argue.md * changes * chanegs * changes * changes * changes * changes --------- Co-authored-by: Ali Abid <[email protected]> Co-authored-by: gradio-pr-bot <[email protected]>
1 parent 1e61644 commit 797621b

21 files changed

+562
-352
lines changed

.changeset/twenty-jokes-argue.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
"@gradio/plot": patch
3+
"gradio": patch
4+
---
5+
6+
feat:Improved plot rendering to thematically match
7+
highlight:Expect visual changes in gr.Plot, gr.BarPlot, gr.LinePlot, gr.ScatterPlot, including changes to color and width sizing.

demo/native_plots/bar_plot_demo.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,14 @@ def bar_plot_fn(display):
9898

9999

100100
with gr.Blocks() as bar_plot:
101-
with gr.Row():
102-
with gr.Column():
103-
display = gr.Dropdown(
104-
choices=["simple", "stacked", "grouped", "simple-horizontal", "stacked-horizontal", "grouped-horizontal"],
105-
value="simple",
106-
label="Type of Bar Plot"
107-
)
108-
with gr.Column():
109-
plot = gr.BarPlot(show_label=False, show_actions_button=True)
101+
display = gr.Dropdown(
102+
choices=["simple", "stacked", "grouped", "simple-horizontal", "stacked-horizontal", "grouped-horizontal"],
103+
value="simple",
104+
label="Type of Bar Plot"
105+
)
106+
plot = gr.BarPlot(show_label=False)
110107
display.change(bar_plot_fn, inputs=display, outputs=plot)
111108
bar_plot.load(fn=bar_plot_fn, inputs=display, outputs=plot)
109+
110+
if __name__ == "__main__":
111+
bar_plot.launch()

demo/native_plots/line_plot_demo.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ def line_plot_fn(dataset):
2424
overlay_point=False,
2525
title="Stock Prices",
2626
stroke_dash_legend_title=None,
27-
height=300,
28-
width=500
2927
)
3028
elif dataset == "climate":
3129
return gr.LinePlot(
@@ -40,8 +38,6 @@ def line_plot_fn(dataset):
4038
overlay_point=False,
4139
title="Climate",
4240
stroke_dash_legend_title=None,
43-
height=300,
44-
width=500
4541
)
4642
elif dataset == "seattle_weather":
4743
return gr.LinePlot(
@@ -56,8 +52,6 @@ def line_plot_fn(dataset):
5652
overlay_point=True,
5753
title="Seattle Weather",
5854
stroke_dash_legend_title=None,
59-
height=300,
60-
width=500
6155
)
6256
elif dataset == "gapminder":
6357
return gr.LinePlot(
@@ -72,20 +66,15 @@ def line_plot_fn(dataset):
7266
overlay_point=False,
7367
title="Life expectancy for countries",
7468
stroke_dash_legend_title="Country Cluster",
75-
height=300,
76-
width=500
7769
)
7870

7971

8072
with gr.Blocks() as line_plot:
81-
with gr.Row():
82-
with gr.Column():
83-
dataset = gr.Dropdown(
84-
choices=["stocks", "climate", "seattle_weather", "gapminder"],
85-
value="stocks",
86-
)
87-
with gr.Column():
88-
plot = gr.LinePlot()
73+
dataset = gr.Dropdown(
74+
choices=["stocks", "climate", "seattle_weather", "gapminder"],
75+
value="stocks",
76+
)
77+
plot = gr.LinePlot()
8978
dataset.change(line_plot_fn, inputs=dataset, outputs=plot)
9079
line_plot.load(fn=line_plot_fn, inputs=dataset, outputs=plot)
9180

demo/native_plots/scatter_plot_demo.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@ def scatter_plot_fn(dataset):
1212
value=iris,
1313
x="petalWidth",
1414
y="petalLength",
15-
color="species",
15+
color=None,
1616
title="Iris Dataset",
17-
color_legend_title="Species",
1817
x_title="Petal Width",
1918
y_title="Petal Length",
2019
tooltip=["petalWidth", "petalLength", "species"],
2120
caption="",
21+
height=600,
22+
width=600,
2223
)
2324
else:
2425
return gr.ScatterPlot(
@@ -29,17 +30,15 @@ def scatter_plot_fn(dataset):
2930
tooltip="Name",
3031
title="Car Data",
3132
y_title="Miles per Gallon",
32-
color_legend_title="Origin of Car",
3333
caption="MPG vs Horsepower of various cars",
34+
height=None,
35+
width=None,
3436
)
3537

3638

3739
with gr.Blocks() as scatter_plot:
38-
with gr.Row():
39-
with gr.Column():
40-
dataset = gr.Dropdown(choices=["cars", "iris"], value="cars")
41-
with gr.Column():
42-
plot = gr.ScatterPlot(show_label=False)
40+
dataset = gr.Dropdown(choices=["cars", "iris"], value="cars")
41+
plot = gr.ScatterPlot(show_label=False)
4342
dataset.change(scatter_plot_fn, inputs=dataset, outputs=plot)
4443
scatter_plot.load(fn=scatter_plot_fn, inputs=dataset, outputs=plot)
4544

gradio/components/bar_plot.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import warnings
56
from typing import TYPE_CHECKING, Any, Callable, Literal
67

78
from gradio_client.documentation import document
@@ -52,8 +53,8 @@ def __init__(
5253
"none",
5354
]
5455
| None = None,
55-
height: int | str | None = None,
56-
width: int | str | None = None,
56+
height: int | None = None,
57+
width: int | None = None,
5758
y_lim: list[int] | None = None,
5859
caption: str | None = None,
5960
interactive: bool | None = True,
@@ -88,8 +89,8 @@ def __init__(
8889
color_legend_title: The title given to the color legend. By default, uses the value of color parameter.
8990
group_title: The label displayed on top of the subplot columns (or rows if vertical=True). Use an empty string to omit.
9091
color_legend_position: The position of the color legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
91-
height: The height of the plot, specified in pixels if a number is passed, or in CSS units if a string is passed.
92-
width: The width of the plot, specified in pixels if a number is passed, or in CSS units if a string is passed.
92+
height: The height of the plot in pixels.
93+
width: The width of the plot in pixels. If None, expands to fit.
9394
y_lim: A tuple of list containing the limits for the y-axis, specified as [y_min, y_max].
9495
caption: The (optional) caption to display below the plot.
9596
interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad.
@@ -122,10 +123,22 @@ def __init__(
122123
self.y_lim = y_lim
123124
self.caption = caption
124125
self.interactive_chart = interactive
126+
if isinstance(width, str):
127+
width = None
128+
warnings.warn(
129+
"Width should be an integer, not a string. Setting width to None."
130+
)
131+
if isinstance(height, str):
132+
warnings.warn(
133+
"Height should be an integer, not a string. Setting height to None."
134+
)
135+
height = None
125136
self.width = width
126137
self.height = height
127138
self.sort = sort
128139
self.show_actions_button = show_actions_button
140+
if label is None and show_label is None:
141+
show_label = False
129142
super().__init__(
130143
value=value,
131144
label=label,
@@ -172,8 +185,8 @@ def create_plot(
172185
"none",
173186
]
174187
| None = None,
175-
height: int | str | None = None,
176-
width: int | str | None = None,
188+
height: int | None = None,
189+
width: int | None = None,
177190
y_lim: list[int] | None = None,
178191
interactive: bool | None = True,
179192
sort: Literal["x", "y", "-x", "-y"] | None = None,
@@ -182,11 +195,7 @@ def create_plot(
182195
import altair as alt
183196

184197
interactive = True if interactive is None else interactive
185-
orientation = (
186-
{"field": group, "title": group_title if group_title is not None else group}
187-
if group
188-
else {}
189-
)
198+
orientation = {"field": group, "title": group_title} if group else {}
190199

191200
x_title = x_title or x
192201
y_title = y_title or y
@@ -234,14 +243,15 @@ def create_plot(
234243
properties["width"] = width
235244

236245
if color:
246+
color_legend_position = color_legend_position or "bottom"
237247
domain = value[color].unique().tolist()
238248
range_ = list(range(len(domain)))
239249
encodings["color"] = {
240250
"field": color,
241251
"type": "nominal",
242252
"scale": {"domain": domain, "range": range_},
243253
"legend": AltairPlot.create_legend(
244-
position=color_legend_position, title=color_legend_title or color
254+
position=color_legend_position, title=color_legend_title
245255
),
246256
}
247257

gradio/components/line_plot.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import warnings
56
from typing import TYPE_CHECKING, Any, Callable, Literal
67

78
from gradio_client.documentation import document
@@ -64,8 +65,8 @@ def __init__(
6465
"none",
6566
]
6667
| None = None,
67-
height: int | str | None = None,
68-
width: int | str | None = None,
68+
height: int | None = None,
69+
width: int | None = None,
6970
x_lim: list[int] | None = None,
7071
y_lim: list[int] | None = None,
7172
caption: str | None = None,
@@ -101,8 +102,8 @@ def __init__(
101102
stroke_dash_legend_title: The title given to the stroke_dash legend. By default, uses the value of the stroke_dash parameter.
102103
color_legend_position: The position of the color legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
103104
stroke_dash_legend_position: The position of the stoke_dash legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
104-
height: The height of the plot, specified in pixels if a number is passed, or in CSS units if a string is passed.
105-
width: The width of the plot, specified in pixels if a number is passed, or in CSS units if a string is passed.
105+
height: The height of the plot in pixels.
106+
width: The width of the plot in pixels. If None, expands to fit.
106107
x_lim: A tuple or list containing the limits for the x-axis, specified as [x_min, x_max].
107108
y_lim: A tuple of list containing the limits for the y-axis, specified as [y_min, y_max].
108109
caption: The (optional) caption to display below the plot.
@@ -136,9 +137,21 @@ def __init__(
136137
self.y_lim = y_lim
137138
self.caption = caption
138139
self.interactive_chart = interactive
140+
if isinstance(width, str):
141+
width = None
142+
warnings.warn(
143+
"Width should be an integer, not a string. Setting width to None."
144+
)
145+
if isinstance(height, str):
146+
warnings.warn(
147+
"Height should be an integer, not a string. Setting height to None."
148+
)
149+
height = None
139150
self.width = width
140151
self.height = height
141152
self.show_actions_button = show_actions_button
153+
if label is None and show_label is None:
154+
show_label = False
142155
super().__init__(
143156
value=value,
144157
label=label,
@@ -234,14 +247,15 @@ def create_plot(
234247
properties["width"] = width
235248

236249
if color:
250+
color_legend_position = color_legend_position or "bottom"
237251
domain = value[color].unique().tolist()
238252
range_ = list(range(len(domain)))
239253
encodings["color"] = {
240254
"field": color,
241255
"type": "nominal",
242256
"scale": {"domain": domain, "range": range_},
243257
"legend": AltairPlot.create_legend(
244-
position=color_legend_position, title=color_legend_title or color
258+
position=color_legend_position, title=color_legend_title
245259
),
246260
}
247261

gradio/components/scatter_plot.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import warnings
56
from typing import TYPE_CHECKING, Any, Callable, Literal
67

78
from gradio_client.documentation import document
@@ -77,8 +78,8 @@ def __init__(
7778
"none",
7879
]
7980
| None = None,
80-
height: int | str | None = None,
81-
width: int | str | None = None,
81+
height: int | None = None,
82+
width: int | None = None,
8283
x_lim: list[int | float] | None = None,
8384
y_lim: list[int | float] | None = None,
8485
caption: str | None = None,
@@ -116,8 +117,8 @@ def __init__(
116117
color_legend_position: The position of the color legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
117118
size_legend_position: The position of the size legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
118119
shape_legend_position: The position of the shape legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
119-
height: The height of the plot, specified in pixels if a number is passed, or in CSS units if a string is passed.
120-
width: The width of the plot, specified in pixels if a number is passed, or in CSS units if a string is passed.
120+
height: The height of the plot in pixels.
121+
width: The width of the plot in pixels. If None, expands to fit.
121122
x_lim: A tuple or list containing the limits for the x-axis, specified as [x_min, x_max].
122123
y_lim: A tuple of list containing the limits for the y-axis, specified as [y_min, y_max].
123124
caption: The (optional) caption to display below the plot.
@@ -151,11 +152,23 @@ def __init__(
151152
self.shape_legend_position = shape_legend_position
152153
self.caption = caption
153154
self.interactive_chart = interactive
155+
if isinstance(width, str):
156+
width = None
157+
warnings.warn(
158+
"Width should be an integer, not a string. Setting width to None."
159+
)
160+
if isinstance(height, str):
161+
warnings.warn(
162+
"Height should be an integer, not a string. Setting height to None."
163+
)
164+
height = None
154165
self.width = width
155166
self.height = height
156167
self.x_lim = x_lim
157168
self.y_lim = y_lim
158169
self.show_actions_button = show_actions_button
170+
if label is None and show_label is None:
171+
show_label = False
159172
super().__init__(
160173
value=value,
161174
label=label,
@@ -273,11 +286,12 @@ def create_plot(
273286
range_ = list(range(len(domain)))
274287
type_ = "nominal"
275288

289+
color_legend_position = color_legend_position or "bottom"
276290
encodings["color"] = {
277291
"field": color,
278292
"type": type_,
279293
"legend": AltairPlot.create_legend(
280-
position=color_legend_position, title=color_legend_title or color
294+
position=color_legend_position, title=color_legend_title
281295
),
282296
"scale": {"domain": domain, "range": range_},
283297
}
@@ -288,15 +302,15 @@ def create_plot(
288302
"field": size,
289303
"type": "quantitative" if is_numeric_dtype(value[size]) else "nominal",
290304
"legend": AltairPlot.create_legend(
291-
position=size_legend_position, title=size_legend_title or size
305+
position=size_legend_position, title=size_legend_title
292306
),
293307
}
294308
if shape:
295309
encodings["shape"] = {
296310
"field": shape,
297311
"type": "quantitative" if is_numeric_dtype(value[shape]) else "nominal",
298312
"legend": AltairPlot.create_legend(
299-
position=shape_legend_position, title=shape_legend_title or shape
313+
position=shape_legend_position, title=shape_legend_title
300314
),
301315
}
302316
chart = (

0 commit comments

Comments
 (0)