Source code for pyleecan.Functions.Plot.plot_A_2D

# -*- coding: utf-8 -*-

from itertools import repeat

import matplotlib.pyplot as plt
from numpy import where, argmin, abs, squeeze, split, ndarray

from ...Functions.init_fig import init_subplot, init_fig
from ...definitions import config_dict

FONT_NAME = config_dict["PLOT"]["FONT_NAME"]


[docs]def plot_A_2D( Xdatas, Ydatas, legend_list=[""], color_list=[(0, 0, 1, 0.5)], linestyle_list=["-"], linewidth_list=[3], title="", xlabel="", ylabel="", fig=None, subplot_index=None, is_logscale_x=False, is_logscale_y=False, is_disp_title=True, is_grid=True, type="curve", is_fund=False, fund_harm=None, x_min=None, x_max=None, y_min=None, y_max=None, xticks=None, save_path=None, ): """Plots a 2D graph (curve, bargraph or barchart) comparing fields in Ydatas Parameters ---------- Xdata : ndarray array of x-axis values Ydatas : list list of y-axes values legend_list : list list of legends color_list : list list of colors to use for each curve linewidth_list : list list of line width to use for each curve title : str title of the graph xlabel : str label for the x-axis ylabel : str label for the y-axis fig : Matplotlib.figure.Figure existing figure to use if None create a new one subplot_index : int index of subplot in which to plot is_logscale_x : bool boolean indicating if the x-axis must be set in logarithmic scale is_logscale_y : bool boolean indicating if the y-axis must be set in logarithmic scale is_disp_title : bool boolean indicating if the title must be displayed is_grid : bool boolean indicating if the grid must be displayed type : str type of 2D graph : "curve", "bargraph", "barchart" or "quiver" is_fund : bool boolean indicating if the bar corresponding to the fundamental must be displayed in red fund_harm : float frequency of the fundamental harmonic y_min : float minimum value for the y-axis y_max : float maximum value for the y-axis xticks : list list of ticks to use for the x-axis """ # Set figure/subplot is_show_fig = True if fig is None else False if fig is None: (fig, axes, patch_leg, label_leg) = init_fig(None, shape="rectangle") fig, ax = init_subplot(fig=fig, subplot_index=subplot_index) # Number of curves on a axe ndatas = len(Ydatas) # Retrocompatibility if isinstance(Xdatas, ndarray): Xdatas = [Xdatas] if len(Xdatas) == 1: i_Xdatas = [0 for i in range(ndatas)] else: i_Xdatas = range(ndatas) # Expend default argument if 1 == len(color_list) < ndatas: # Set the same color for all curves color_list = list(repeat(color_list[0], ndatas)) if 1 == len(linewidth_list) < ndatas: # Set the same color for all curves linewidth_list = list(repeat(linewidth_list[0], ndatas)) if 1 == len(linestyle_list) < ndatas: # Set the same linestyles for all curves linestyle_list = list(repeat(linestyle_list[0], ndatas)) if 1 == len(legend_list) < ndatas: # Set no legend for all curves legend_list = list(repeat("", ndatas)) no_legend = True else: no_legend = False # Plot if type == "curve": for i in range(ndatas): ax.plot( Xdatas[i_Xdatas[i]], Ydatas[i], color=color_list[i], label=legend_list[i], linewidth=linewidth_list[i], ls=linestyle_list[i], ) if xticks is not None: ax.xaxis.set_ticks(xticks) elif type == "bargraph": positions = range(-ndatas + 1, ndatas, 2) for i in range(ndatas): # width = (Xdatas[i_Xdatas[i]][1] - Xdatas[i_Xdatas[i]][0]) / ndatas width = Xdatas[i_Xdatas[i]][-1] / 100 barlist = ax.bar( Xdatas[i_Xdatas[i]] + positions[i] * width / (2 * ndatas), Ydatas[i], color=color_list[i], width=width, label=legend_list[i], ) if is_fund: # Find fundamental if fund_harm is None: mag_max = max(Ydatas[i]) imax = int(where(Ydatas[i] == mag_max)[0]) else: imax = argmin(abs(Xdatas[i] - fund_harm)) barlist[imax].set_edgecolor("k") if xticks is not None: ax.xaxis.set_ticks(xticks) elif type == "barchart": for i in range(ndatas): if i == 0: ax.bar( range(len(Xdatas[i_Xdatas[i]])), Ydatas[i], color=color_list[i], width=0.5, label=legend_list[i], ) else: ax.bar( range(len(Xdatas[i_Xdatas[i]])), Ydatas[i], edgecolor=color_list[i], width=0.5, fc="None", lw=1, label=legend_list[i], ) plt.xticks( range(len(Xdatas[i_Xdatas[i]])), [str(f) for f in Xdatas[i_Xdatas[i]]], rotation=90, ) elif type == "quiver": for i in range(ndatas): x = [e[0] for e in Xdatas[i_Xdatas[i]]] y = [e[1] for e in Xdatas[i_Xdatas[i]]] vect_list = split(Ydatas[i], 2) ax.quiver(x, y, squeeze(vect_list[0]), squeeze(vect_list[1])) ax.axis("equal") ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_xlim([x_min, x_max]) ax.set_ylim([y_min, y_max]) if is_logscale_x: ax.set_xscale("log") if is_logscale_y: ax.set_yscale("log") if is_disp_title: ax.set_title(title) if is_grid: ax.grid() if ndatas > 1 and not no_legend: ax.legend(prop={"family": FONT_NAME, "size": 22}) plt.tight_layout() for item in ( [ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels() ): item.set_fontname(FONT_NAME) item.set_fontsize(22) ax.title.set_fontname(FONT_NAME) ax.title.set_fontsize(24) if save_path is not None: fig.savefig(save_path) plt.close() if is_show_fig: fig.show() return ax