@@ -747,10 +747,10 @@ def apply_default_cascade(args):
747747 ]
748748
749749 # If both marginals and faceting are specified, faceting wins
750- if args .get ("facet_col" , None ) and args .get ("marginal_y" , None ):
750+ if args .get ("facet_col" , None ) is not None and args .get ("marginal_y" , None ):
751751 args ["marginal_y" ] = None
752752
753- if args .get ("facet_row" , None ) and args .get ("marginal_x" , None ):
753+ if args .get ("facet_row" , None ) is not None and args .get ("marginal_x" , None ):
754754 args ["marginal_x" ] = None
755755
756756
@@ -874,7 +874,7 @@ def build_dataframe(args, attrables, array_attrables):
874874 "pandas MultiIndex is not supported by plotly express "
875875 "at the moment." % field
876876 )
877- ## ----------------- argument is a col name ----------------------
877+ # ----------------- argument is a col name ----------------------
878878 if isinstance (argument , str ) or isinstance (
879879 argument , int
880880 ): # just a column name given as str or int
@@ -1042,6 +1042,13 @@ def infer_config(args, constructor, trace_patch):
10421042 args [position ] = args ["marginal" ]
10431043 args [other_position ] = None
10441044
1045+ if (
1046+ args .get ("marginal_x" , None ) is not None
1047+ or args .get ("marginal_y" , None ) is not None
1048+ or args .get ("facet_row" , None ) is not None
1049+ ):
1050+ args ["facet_col_wrap" ] = 0
1051+
10451052 # Compute applicable grouping attributes
10461053 for k in group_attrables :
10471054 if k in args :
@@ -1098,15 +1105,14 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
10981105
10991106 orders , sorted_group_names = get_orderings (args , grouper , grouped )
11001107
1101- has_marginal_x = bool (args .get ("marginal_x" , False ))
1102- has_marginal_y = bool (args .get ("marginal_y" , False ))
1103-
11041108 subplot_type = _subplot_type_for_trace_type (constructor ().type )
11051109
11061110 trace_names_by_frame = {}
11071111 frames = OrderedDict ()
11081112 trendline_rows = []
11091113 nrows = ncols = 1
1114+ col_labels = []
1115+ row_labels = []
11101116 for group_name in sorted_group_names :
11111117 group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
11121118 mapping_labels = OrderedDict ()
@@ -1188,27 +1194,36 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
11881194 # Find row for trace, handling facet_row and marginal_x
11891195 if m .facet == "row" :
11901196 row = m .val_map [val ]
1191- trace ._subplot_row_val = val
1197+ if args ["facet_row" ] and len (row_labels ) < row :
1198+ row_labels .append (args ["facet_row" ] + "=" + str (val ))
11921199 else :
1193- if has_marginal_x and trace_spec .marginal != "x" :
1200+ if (
1201+ bool (args .get ("marginal_x" , False ))
1202+ and trace_spec .marginal != "x"
1203+ ):
11941204 row = 2
11951205 else :
11961206 row = 1
11971207
1198- nrows = max (nrows , row )
1199- if row > 1 :
1200- trace ._subplot_row = row
1201-
1208+ facet_col_wrap = args .get ("facet_col_wrap" , 0 )
12021209 # Find col for trace, handling facet_col and marginal_y
12031210 if m .facet == "col" :
12041211 col = m .val_map [val ]
1205- trace ._subplot_col_val = val
1212+ if args ["facet_col" ] and len (col_labels ) < col :
1213+ col_labels .append (args ["facet_col" ] + "=" + str (val ))
1214+ if facet_col_wrap : # assumes no facet_row, no marginals
1215+ row = 1 + ((col - 1 ) // facet_col_wrap )
1216+ col = 1 + ((col - 1 ) % facet_col_wrap )
12061217 else :
12071218 if trace_spec .marginal == "y" :
12081219 col = 2
12091220 else :
12101221 col = 1
12111222
1223+ nrows = max (nrows , row )
1224+ if row > 1 :
1225+ trace ._subplot_row = row
1226+
12121227 ncols = max (ncols , col )
12131228 if col > 1 :
12141229 trace ._subplot_col = col
@@ -1238,7 +1253,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12381253 if show_colorbar :
12391254 colorvar = "z" if constructor == go .Histogram2d else "color"
12401255 range_color = args ["range_color" ] or [None , None ]
1241- d = len (args ["color_continuous_scale" ]) - 1
12421256
12431257 colorscale_validator = ColorscaleValidator ("colorscale" , "make_figure" )
12441258 layout_patch ["coloraxis1" ] = dict (
@@ -1260,7 +1274,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12601274 layout_patch ["legend" ]["itemsizing" ] = "constant"
12611275
12621276 fig = init_figure (
1263- args , subplot_type , frame_list , ncols , nrows , has_marginal_x , has_marginal_y
1277+ args , subplot_type , frame_list , nrows , ncols , col_labels , row_labels
12641278 )
12651279
12661280 # Position traces in subplots
@@ -1290,49 +1304,39 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12901304 return fig
12911305
12921306
1293- def init_figure (
1294- args , subplot_type , frame_list , ncols , nrows , has_marginal_x , has_marginal_y
1295- ):
1307+ def init_figure (args , subplot_type , frame_list , nrows , ncols , col_labels , row_labels ):
12961308 # Build subplot specs
12971309 specs = [[{}] * ncols for _ in range (nrows )]
1298- column_titles = [None ] * ncols
1299- row_titles = [None ] * nrows
13001310 for frame in frame_list :
13011311 for trace in frame ["data" ]:
13021312 row0 = trace ._subplot_row - 1
13031313 col0 = trace ._subplot_col - 1
1304-
13051314 if isinstance (trace , go .Splom ):
13061315 # Splom not compatible with make_subplots, treat as domain
13071316 specs [row0 ][col0 ] = {"type" : "domain" }
13081317 else :
13091318 specs [row0 ][col0 ] = {"type" : trace .type }
1310- if args .get ("facet_row" , None ) and hasattr (trace , "_subplot_row_val" ):
1311- row_titles [row0 ] = args ["facet_row" ] + "=" + str (trace ._subplot_row_val )
1312-
1313- if args .get ("facet_col" , None ) and hasattr (trace , "_subplot_col_val" ):
1314- column_titles [col0 ] = (
1315- args ["facet_col" ] + "=" + str (trace ._subplot_col_val )
1316- )
13171319
13181320 # Default row/column widths uniform
13191321 column_widths = [1.0 ] * ncols
13201322 row_heights = [1.0 ] * nrows
13211323
13221324 # Build column_widths/row_heights
13231325 if subplot_type == "xy" :
1324- if has_marginal_x :
1326+ if bool ( args . get ( "marginal_x" , False )) :
13251327 if args ["marginal_x" ] == "histogram" or ("color" in args and args ["color" ]):
13261328 main_size = 0.74
13271329 else :
13281330 main_size = 0.84
13291331
13301332 row_heights = [main_size ] * (nrows - 1 ) + [1 - main_size ]
13311333 vertical_spacing = 0.01
1334+ elif args .get ("facet_col_wrap" , 0 ):
1335+ vertical_spacing = 0.07
13321336 else :
13331337 vertical_spacing = 0.03
13341338
1335- if has_marginal_y :
1339+ if bool ( args . get ( "marginal_y" , False )) :
13361340 if args ["marginal_y" ] == "histogram" or ("color" in args and args ["color" ]):
13371341 main_size = 0.74
13381342 else :
@@ -1351,15 +1355,25 @@ def init_figure(
13511355 vertical_spacing = 0.1
13521356 horizontal_spacing = 0.1
13531357
1358+ facet_col_wrap = args .get ("facet_col_wrap" , 0 )
1359+ if facet_col_wrap :
1360+ subplot_labels = [None ] * nrows * ncols
1361+ while len (col_labels ) < nrows * ncols :
1362+ col_labels .append (None )
1363+ for i in range (nrows ):
1364+ for j in range (ncols ):
1365+ subplot_labels [i * ncols + j ] = col_labels [(nrows - 1 - i ) * ncols + j ]
1366+
13541367 # Create figure with subplots
13551368 fig = make_subplots (
13561369 rows = nrows ,
13571370 cols = ncols ,
13581371 specs = specs ,
13591372 shared_xaxes = "all" ,
13601373 shared_yaxes = "all" ,
1361- row_titles = list (reversed (row_titles )),
1362- column_titles = column_titles ,
1374+ row_titles = [] if facet_col_wrap else list (reversed (row_labels )),
1375+ column_titles = [] if facet_col_wrap else col_labels ,
1376+ subplot_titles = subplot_labels if facet_col_wrap else [],
13631377 horizontal_spacing = horizontal_spacing ,
13641378 vertical_spacing = vertical_spacing ,
13651379 row_heights = row_heights ,
0 commit comments