#!/usr/bin/env python3
# coding: utf8
# Interface graphique pour illustrer la corrélation
# Licence libre WTFPL

import tkinter
import numpy as np
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
from matplotlib.pyplot import tight_layout
from sys import argv

# pour le lancement via ipython, et lecture du workspace
try:
    signal_perso = True
    _ = signal
except:
    print("Variable « signal » non définie, signal personnel indisponible.")
    signal_perso = None

# Définition des signaux #############################################
Fe = 50     # y en faut bien une dans ce cas
A = 1       # amplitude de la fréquence pure
Ts = 5      # durée maximale du signal
Tp = 0.2    # durée relative de la fenêtre
T = Tp*Ts   # durée de la fenêtre

N = Ts*Fe   # nombre de points jusqu'à +- 1s, pas des Hertz
ZP = 8      # facteur de zéro padding, simulation d'une f continue
K = N*ZP
n = np.arange(N)
k = np.arange(K)
t = n/Fe
axef = k*Fe/K - Fe/2

res_freq = 1/500  # pas de fréquence pour le slider
res_phase = np.pi/100  # de phase
res_durfen = 1  # 1% par 1% de la durée de fenêtre

inc = (np.sin(2*np.pi*10*t) + np.sin(2*np.pi*9.83*t+np.pi/7)
       + 1e-2*np.sin(2*np.pi*12*t) + 1e-4*np.sin(2*np.pi*20*t))

if signal_perso:
    perso_sig = signal
    perso_N = len(perso_sig)
    try:
        perso_Fe = Fesig
        print("Signal chargé, N=" + str(perso_N) + " points, durée "
              + str(perso_N/perso_Fe) + " secondes")
    except:
        print("Variable « Fesig » inexistante, Fe est fixée à 1.")
        perso_Fe = 1
    perso_t = np.arange(perso_N)/perso_Fe
    perso_axef = (np.arange(perso_N*ZP)/(perso_N*ZP) - 1/2)*perso_Fe

# Classe de base #####################################################
class Spectre:
    def __init__(self, t, Fe, axef, inc):
        # Initialisation des données #
        self.t_default = t
        self.Fe_default = Fe
        self.axef_default = axef
        self.sig_inc = inc
        self.list_signals = ["fréquence pure réelle",
                             "fréquence pure complexe",
                             "bruit blanc",
                             "signal à analyser"]
        if signal_perso:
            self.list_signals.append("signal personnel")
        self.list_fens = ["porte", "hanning", "hamming", "blackman"]
        self.list_rsb = ["10 dB", "3 dB", "0 dB", "-3 dB", "-10 dB"]
        self.freq = 0.13*Fe
        self.phase = 0
        self.durfen = 50 # %

        # Initialisation de l'affichage #
        self.root = tkinter.Tk()
        self.root.wm_title("Étude de la transformée de Fourier")

        # Défaut en plein écran #
        xmax = self.root.winfo_screenwidth()
        ymax = self.root.winfo_screenheight()-20
        self.root.geometry(str(xmax) + "x" + str(ymax) + "+0+0")

        self.fig = Figure()
        self.a1 = self.fig.add_subplot(311)
        self.a2 = self.fig.add_subplot(312)
        self.a3 = self.fig.add_subplot(313)
        self.fig.subplots_adjust()

        menu_width = 200

        self.canvas = FigureCanvasTkAgg(self.fig, master=self.root)
        self.canvas.get_tk_widget().place(width=xmax-menu_width, height=ymax)

        self.canvas.mpl_connect("key_press_event", self.key_bindings)

        # Création de la frame de placement de widgets #
        self.frame = tkinter.Frame(self.root, width=0.5)
        self.frame.place(x=xmax-menu_width, width=menu_width, height=ymax)

        # Création du menu de choix de signal et de choix de fenêtre #
        self.x = tkinter.StringVar()
        self.choix_x = tkinter.OptionMenu(self.frame, self.x,
                                          *self.list_signals,
                                          command=self.cree_signal)
        self.x.set(self.list_signals[0])
        labelx = tkinter.Label(self.frame, text="signal x(t)")
        labelx.place(rely=0.05, x=10)
        self.choix_x.place(rely=0.05, x=5, y=20)  # 20 : taille du label

        self.fen = tkinter.StringVar()
        self.choix_fen = tkinter.OptionMenu(self.frame, self.fen,
                                            *self.list_fens,
                                            command=self.cree_signal)
        self.fen.set(self.list_fens[0])
        labelfen = tkinter.Label(self.frame, text="Choix du fenêtrage")
        labelfen.place(rely=0.8, x=10)
        self.choix_fen.place(rely=0.8, x=5, y=20)  # 20 : taille du label

        self.rsb = tkinter.StringVar()
        self.choix_rsb = tkinter.OptionMenu(self.frame, self.rsb,
                                            *self.list_rsb,
                                            command=self.change_rsb)
        self.rsb.set(self.list_rsb[0])
        # il faut supprimer « dB » à la fin et conversion linéaire
        self.rsb_lin = 10**(float(self.list_rsb[0].split()[0])/10)
        self.labelrsb = tkinter.Label(self.frame,
                                      text="Rapport Signal à Bruit")
        # placement fait uniquement lorsque « signal bruité » est coché
        # self.labelrsb.place(rely=0.45, x=10)
        # self.choix_rsb.place(rely=0.45, x=5, y=20)  # 20 : taille du label

        # Création des choix de paramètres #
        self.scale_freq = tkinter.Scale(master=self.frame, orient='horizontal',
                                        from_=-0.1*self.Fe_default, to=0.6*self.Fe_default,
                                        resolution=res_freq*self.Fe_default,
                                        length=menu_width-20,
                                        command=self.change_freq)
        self.scale_freq.set(self.freq)
        labelfreq = tkinter.Label(self.frame,
                                  text="Fréquence de la fréq. pure")
        labelfreq.place(rely=0.15, x=10)
        self.scale_freq.place(rely=0.15, x=5, y=20)  # 20 : taille du label

        self.scale_phase = tkinter.Scale(master=self.frame,
                                         orient='horizontal',
                                         from_=-np.pi, to=np.pi,
                                         resolution=res_phase,
                                         length=menu_width-20,
                                         command=self.change_phase)

        self.scale_phase.set(self.phase)
        labelphase = tkinter.Label(self.frame,
                                   text="Phase de la fréq. pure")
        labelphase.place(rely=0.25, x=10)
        self.scale_phase.place(rely=0.25, x=10, y=20)  # 20 : taille du label

        self.scale_durfen = tkinter.Scale(master=self.frame,
                                          orient='horizontal',
                                          from_=0, to=100,
                                          resolution=res_durfen,
                                          length=menu_width-20,
                                          command=self.change_durfen)
        self.scale_durfen.set(self.durfen)
        labeldurfen = tkinter.Label(self.frame, text="Durée de la fenêtre (%)")
        labeldurfen.place(rely=0.9, x=30)
        self.scale_durfen.place(rely=0.9, x=10, y=20)  # 20 : taille du label

        # case des décibels et du bruitage #
        self.dB = tkinter.IntVar()
        self.check_dB = tkinter.Checkbutton(master=self.frame,
                                            text="Affichage en dB",
                                            variable=self.dB,
                                            onvalue=True,
                                            offvalue=False,
                                            command=self.cree_signal)
        self.check_dB.place(rely=0.65, x=10)
        self.dB.set(False)

        self.bruit = tkinter.IntVar()
        self.check_bruit = tkinter.Checkbutton(master=self.frame,
                                               text="Signal bruité",
                                               variable=self.bruit,
                                               onvalue=True,
                                               offvalue=False,
                                               command=self.ajoute_bruit)
        self.check_bruit.place(rely=0.4, x=10)
        self.bruit.set(False)

        # Affichage initial #
        self.cree_signal()
        self.fig.tight_layout()  # répétition de affiche()
        self.canvas.draw()       # pour initialiser correctement

    # callback des différents choix #
    def change_freq(self, data):
        self.freq = float(data)
        self.cree_signal()

    def change_phase(self, data):
        self.phase = float(data)
        self.cree_signal()

    def change_durfen(self, data):
        self.durfen = float(data)/100  # pourcentage
        self.cree_signal()

    def change_rsb(self, data):
        # il faut supprimer « dB » à la fin et conversion linéaire
        self.rsb_lin = 10**(float(data.split()[0])/10)
        self.cree_signal()

    def ajoute_bruit(self):
        if self.bruit.get():
            self.labelrsb.place(rely=0.45, x=10)
            self.choix_rsb.place(rely=0.45, x=5, y=20)  # 20 : taille du label
        else:
            self.labelrsb.place_forget()
            self.choix_rsb.place_forget()
        self.cree_signal()

    def cree_signal(self, data="osef"):
        self.t = self.t_default
        self.Fe = self.Fe_default
        self.axef = self.axef_default

        if self.x.get() == self.list_signals[2]:   # réalisation de bruit blanc
            self.sig = np.random.randn(len(self.t))
        elif self.x.get() == self.list_signals[3]:  # signal inconnu
            self.sig = self.sig_inc
        elif signal_perso and self.x.get() == self.list_signals[4]:
            self.sig = perso_sig
            self.t = perso_t
            self.Fe = perso_Fe
            self.axef = perso_axef
        else:
            self.sig = np.exp(2*np.pi*1j*self.freq*t+1j*self.phase)

        if self.bruit.get():
            r = np.random.randn(2, len(self.sig))/np.sqrt(2)  # puissance unité
            bruit = (r[0, :] + 1j*r[1, :]) / self.rsb_lin
            self.sig = self.sig*(1+0j) + bruit

        N = len(self.sig)
        h = np.zeros(self.sig.shape)
        Nh = int(max(2, min(N, np.round(self.durfen*N))))

        if self.fen.get() == "porte":
            fen = np.ones(Nh)
        elif self.fen.get() == "hanning":
            fen = np.hanning(Nh)
        elif self.fen.get() == "hamming":
            fen = np.hamming(Nh)
        elif self.fen.get() == "blackman":
            fen = np.blackman(Nh)
        h[:len(fen)] = fen
        self.sig = self.sig*h

        if self.x.get() != self.list_signals[1]:  # signal réel
            self.sig = np.real(self.sig)

        # calcul de la TF centrée en 0, normalisée pour TFD → TF
        self.SIG = np.fft.fftshift(np.fft.fft(self.sig, N*ZP))/self.Fe

        self.affiche()

    # fonction d'affichage #
    def affiche(self):
        self.affiche1()
        self.affiche2_3()
        self.fig.tight_layout()
        self.canvas.draw()

    def affiche1(self):
        self.a1.clear()
        if self.x.get() == self.list_signals[1]:  # complexe
            self.a1.plot(self.t, np.real(self.sig), linewidth=2,
                         label="partie réelle")
            self.a1.plot(self.t, np.imag(self.sig), linewidth=2,
                         label="partie imaginaire")
            self.a1.legend()
        else:
            self.a1.plot(self.t, self.sig, linewidth=2)
        self.a1.set_xlim((self.t[0], self.t[-1]))
        maxy = np.max(np.abs(self.sig))
        self.a1.set_ylim((-1.1*maxy, 1.1*maxy))

        self.a1.set_title('x(t), échantilloné à '+str(self.Fe)+' Hz')
        self.a1.grid("on")
        self.a1.set_ylabel('amplitude')
        self.a1.set_xlabel('temps (s)')

    def affiche2_3(self):
        self.a2.clear()
        self.a3.clear()
        if self.dB.get():
            spectre = 20*np.log10(np.abs(self.SIG))
            self.a2.set_ylabel('amplitude (dB)')
        else:
            spectre = np.abs(self.SIG)
            self.a2.set_ylabel('amplitude')

        self.a2.plot(self.axef, spectre, linewidth=2)
        self.a3.plot(self.axef, np.real(self.SIG), linewidth=2,
                     label="partie réelle")
        self.a3.plot(self.axef, np.imag(self.SIG), linewidth=2,
                     label="partie imaginaire")

        if self.dB.get():
            self.a2.set_ylim((-100+np.max(spectre), 3+np.max(spectre)))
        else:
            self.a2.set_ylim((-0.1*np.max(spectre), 1.1*np.max(spectre)))

        self.a2.set_xlim((self.axef[0], self.axef[-1]))
        self.a2.set_title('|X(f)|')
        self.a2.grid("on")
        self.a2.set_xlabel('fréquences (Hz)')

        self.a3.legend()
        self.a3.set_xlim((self.axef[0], self.axef[-1]))
        maxy = np.max(np.abs(self.SIG))
        self.a3.set_ylim((-1.1*maxy, 1.1*maxy))
        self.a3.set_title('Parties réelles et imaginaires de X(f)')
        self.a3.grid("on")
        self.a3.set_xlabel('fréquences (Hz)')

    def key_bindings(self, event):
        if event.key in ["x"]:  # x(t) suivant
            idx = self.list_signals.index(self.x.get())
            new_idx = (idx + 1) % len(self.list_signals)
            self.x.set(self.list_signals[new_idx])
        if event.key in ["X"]:  # x(t) précédent
            idx = self.list_signals.index(self.x.get())
            new_idx = (idx - 1) % len(self.list_signals)
            self.x.set(self.list_signals[new_idx])

        if event.key in ["F"]:  # diminuer la freq
            self.scale_freq.set(self.freq-res_freq*self.Fe)
        if event.key in ["f"]:  # augmenter la freq
            self.scale_freq.set(self.freq+res_freq*self.Fe)
        if event.key in ["ctrl+F"]:  # diminuer la freq
            self.scale_freq.set(self.freq-10*res_freq*self.Fe)
        if event.key in ["ctrl+f"]:  # augmenter la freq
            self.scale_freq.set(self.freq+10*res_freq*self.Fe)

        if event.key in ["P"]:  # diminuer la phase
            self.scale_phase.set(self.phase-res_phase)
        if event.key in ["p"]:  # augmenter la phase
            self.scale_phase.set(self.phase+res_phase)
        if event.key in ["ctrl+P"]:  # diminuer la phase
            self.scale_phase.set(self.phase-10*res_phase)
        if event.key in ["ctrl+p"]:  # augmenter la phase
            self.scale_phase.set(self.phase+10*res_phase)

        if event.key in ["left"]:  # diminuer la durée de fenêtre
            self.scale_durfen.set(self.durfen*100-res_durfen)  # en pourcentage
        if event.key in ["right"]:  # augmenter la durée de fenêtre
            self.scale_durfen.set(self.durfen*100+res_durfen)
        if event.key in ["ctrl+left"]:  # diminuer la durée de fenêtre
            self.scale_durfen.set(self.durfen*100-10*res_durfen)
        if event.key in ["ctrl+right"]:  # augmenter la durée de fenêtre
            self.scale_durfen.set(self.durfen*100+10*res_durfen)


        if event.key in ["d", "D"]:  # basculer les décibels
            self.dB.set(not(self.dB.get()))

        if event.key in ["q", "Q"]:  # quitter
            self.root.quit()
            self.root.destroy()
        else:
            self.cree_signal()


spec = Spectre(t, Fe, axef, inc)
tkinter.mainloop()
