#!/usr/bin/env python3
# coding: utf8
# Interface graphique pour illustrer la corrélation
# Licence libre WTFPL

import tkinter
import numpy as np
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
from matplotlib.pyplot import tight_layout

# resolution of the delay τ (slider step and arrow-key increment)
res_tau = 0.005

# Signal definitions #################################################
Fe = 1e4   # sampling frequency, high enough to approximate continuous time
N = 2*Fe   # number of samples, covering -1 s to +1 s
n = np.arange(-(N/2), N/2)   # sample indices, centred on 0
t = n / Fe                   # time axis in seconds
T = 1/3    # width of the gate and of the triangle
f = 13     # frequency of the sine and cosine

# gate: 1 for |t| < T/2, 0 elsewhere
por = np.zeros(t.shape)
por[np.abs(t) < T/2] = 1

# triangle of total width T, clipped at 0
tri = 1 - np.abs(t) / (T/2)
tri[tri < 0] = 0

cosin = np.cos(2*np.pi*f*t)
sinus = np.sin(2*np.pi*f*t)

# pre-built pseudo-random binary sequence (SBPA)
sbpa = np.hstack((np.ones(10), -np.ones(7), np.ones(3), -np.ones(4), 1))
# every value is held for `fac` consecutive samples so that the sequence
# is visible on the plots; the result is a flat 1-D array
fac = 50
seq = np.repeat(sbpa, fac)

ex1 = np.zeros(t.shape)
ex2 = np.zeros(t.shape)

# indices of the samples closest to t = 0 s and t = 0.25 s; argmin avoids
# both an exact float comparison and the deprecated conversion of a
# one-element array to int
n0 = int(np.argmin(np.abs(t)))
n1 = int(np.argmin(np.abs(t - 0.25)))
seqlen = len(seq)

# the same sequence, starting at t = 0 and delayed to t = 0.25
ex1[n0:n0+seqlen] = seq
ex2[n1:n1+seqlen] = seq

# dictionary of the available signals, keyed by the name shown in the menus
dic_signals = {
    "porte": por,
    "triangle": tri,
    "cosinus": cosin,
    "sinus": sinus,
    "cosinus fini": cosin*por,
    "sinus fini": sinus*por,
    "signal test": ex1,
    "test retardé": ex2,
}
# order in which the signals are listed
list_signals = ["porte", "triangle", "cosinus", "sinus", "cosinus fini",
                "sinus fini", "signal test", "test retardé"]


# Classe de base #####################################################
class Correlation:
    """Interactive Tk / matplotlib window illustrating cross-correlation.

    Four stacked plots are drawn: x(t), y(t-τ), their product, and the
    cross-correlation of x and y with a red marker at the current delay τ.
    A side panel lets the user pick both signals, move τ, and add noise.

    Keyboard shortcuts handled by key_bindings():
        x / X            next / previous signal for x(t)
        y / Y            next / previous signal for y(t)
        left / right     decrease / increase τ by one step
        ctrl+left/right  decrease / increase τ by ten steps
        q / Q            quit

    The class reads the module-level names ``Fe`` (sampling frequency),
    ``n`` (sample indices) and ``res_tau`` (step of τ).
    """

    def __init__(self, t, dic_signals, list_signals):
        """Build the window, the four axes and the side-panel widgets.

        Args:
            t: time axis shared by all the signals.
            dic_signals: mapping from signal name to its samples; it must
                contain the key "porte", used as the initial x and y.
            list_signals: signal names offered in the menus, in order.
        """
        # Data #
        self.t = t
        self.dic_signals = dic_signals
        self.list_signals = list_signals
        self.sig1 = self.dic_signals["porte"]
        self.sig2 = self.dic_signals["porte"]
        self.tau = 0
        self.list_rsb = ["10 dB", "3 dB", "0 dB", "-3 dB", "-10 dB"]
        # (sig1, sig2, correlation) of the last computation, so that the
        # costly correlation is not redone when only τ changes
        self._corr_cache = None

        # Display #
        self.root = tkinter.Tk()
        self.root.wm_title("Étude de la corrélation")

        # Full screen by default #
        xmax = self.root.winfo_screenwidth()
        ymax = self.root.winfo_screenheight()-20
        self.root.geometry("{}x{}+0+0".format(xmax, ymax))

        self.fig = Figure()
        self.a1 = self.fig.add_subplot(411)
        self.a2 = self.fig.add_subplot(412)
        self.a3 = self.fig.add_subplot(413)
        self.a4 = self.fig.add_subplot(414)
        self.fig.subplots_adjust()

        menu_width = 120

        self.canvas = FigureCanvasTkAgg(self.fig, master=self.root)
        self.canvas.get_tk_widget().place(width=xmax-menu_width, height=ymax)

        self.canvas.mpl_connect("key_press_event", self.key_bindings)

        # Widgets #
        self.frame = tkinter.Frame(self.root, width=0.5)
        self.frame.place(x=xmax-menu_width, width=menu_width, height=ymax)

        self.scale_tau = tkinter.Scale(master=self.frame, orient='horizontal',
                                       from_=-0.5, to=0.5, resolution=res_tau,
                                       length=menu_width-20,
                                       command=self.change_tau)

        self.scale_tau.set(self.tau)
        labeltau = tkinter.Label(self.frame, text="retard τ")
        labeltau.place(rely=0.83, x=30)
        self.scale_tau.place(rely=0.83, x=10, y=20)  # 20: height of the label

        # Signal selection menus #
        self.x = tkinter.StringVar()
        self.choix_x = tkinter.OptionMenu(self.frame, self.x,
                                          *self.list_signals,
                                          command=self.change_x)
        self.x.set(self.list_signals[0])
        labelx = tkinter.Label(self.frame, text="signal x(t)")
        labelx.place(rely=0.1, x=10)
        self.choix_x.place(rely=0.1, x=5, y=20)  # 20: height of the label

        self.y = tkinter.StringVar()
        self.choix_y = tkinter.OptionMenu(self.frame, self.y,
                                          *self.list_signals,
                                          command=self.change_y)
        self.y.set(self.list_signals[0])
        labely = tkinter.Label(self.frame, text="signal y(t)")
        labely.place(rely=0.3, x=10)
        self.choix_y.place(rely=0.3, x=5, y=20)  # 20: height of the label

        # Noise handling #
        self.bruit = tkinter.IntVar()
        self.check_bruit = tkinter.Checkbutton(master=self.frame,
                                               text="Signal bruité",
                                               variable=self.bruit,
                                               onvalue=True,
                                               offvalue=False,
                                               command=self.ajoute_bruit)
        self.check_bruit.place(rely=0.4, x=5)
        self.bruit.set(False)

        self.rsb = tkinter.StringVar()
        self.choix_rsb = tkinter.OptionMenu(self.frame, self.rsb,
                                            *self.list_rsb,
                                            command=self.change_rsb)
        self.rsb.set(self.list_rsb[0])
        self.rsb_lin = self._db_to_linear(self.list_rsb[0])
        self.labelrsb = tkinter.Label(self.frame,
                                      text="RSB")
        # the RSB widgets are only placed once "Signal bruité" is ticked,
        # which is done in self.ajoute_bruit

        self.memebruit = tkinter.IntVar()
        self.meme_bruit = tkinter.Checkbutton(master=self.frame,
                                              text="Même bruit",
                                              variable=self.memebruit,
                                              onvalue=True,
                                              offvalue=False,
                                              command=self.ajoute_bruit)
        self.memebruit.set(False)

        self.pointzero = tkinter.IntVar()
        self.point_zero = tkinter.Checkbutton(master=self.frame,
                                              text="Point central",
                                              variable=self.pointzero,
                                              onvalue=True,
                                              offvalue=False,
                                              command=self.affiche)
        self.pointzero.set(True)

        # Initial display #
        self.affiche()
        # layout and draw are repeated after affiche() so that the first
        # display is initialised correctly
        self.fig.tight_layout()
        self.canvas.draw()

    # helpers #
    @staticmethod
    def _db_to_linear(label):
        """Convert a menu label such as "-3 dB" to a linear ratio."""
        # the first whitespace-separated token is the value in dB
        return 10**(float(label.split()[0])/10)

    @staticmethod
    def _padded_limits(values, margin=1.1):
        """Return (low, high) y-limits: the NaN-ignoring extrema times margin.

        When both bounds coincide (flat curve, e.g. an all-zero product)
        they are widened so that the axis never gets identical limits.
        """
        lo = margin * np.nanmin(values)
        hi = margin * np.nanmax(values)
        if lo == hi:
            lo, hi = lo - 0.05, hi + 0.05
        return lo, hi

    def _raw_correlation(self):
        """Return the full cross-correlation of sig1 and sig2, cached.

        The result only depends on the two signals, so it is recomputed
        only when one of the arrays has been replaced.  Callers must not
        modify the returned array in place.
        """
        cache = self._corr_cache
        if (cache is None or cache[0] is not self.sig1
                or cache[1] is not self.sig2):
            r = np.correlate(self.sig1, self.sig2, mode="full")
            # the signals are kept in the cache so that identity
            # comparison stays valid
            cache = self._corr_cache = (self.sig1, self.sig2, r)
        return cache[2]

    def _draw_curve(self, ax, x, y, title):
        """Clear ax and draw one curve with the decorations shared by the
        three time-domain plots."""
        ax.clear()
        ax.plot(x, y, linewidth=2)
        ax.set_xlim((-0.5, 0.5))
        ax.set_ylim(self._padded_limits(y))
        ax.set_title(title)
        ax.grid(True)
        ax.set_ylabel('amplitude')
        ax.set_xlabel('temps t')

    def _cycle_signal(self, var, callback, step):
        """Select the signal `step` places further in list_signals
        (wrapping around), store it in var and notify callback."""
        names = self.list_signals
        name = names[(names.index(var.get()) + step) % len(names)]
        var.set(name)
        callback(name)

    # callbacks of the different choices #
    def change_tau(self, tau):
        """Slider callback: store the new delay (received as text) and
        redraw."""
        self.tau = float(tau)
        self.affiche()

    def change_x(self, signal):
        """Menu callback for x(t); the selection is read from self.x."""
        self.bruite()

    def change_y(self, signal):
        """Menu callback for y(t); the selection is read from self.y."""
        self.bruite()

    def bruite(self):
        """Rebuild sig1 and sig2 from the selected signals, adding Gaussian
        noise when "Signal bruité" is ticked, then redraw.

        With "Même bruit" ticked the same noise realisation is added to
        both signals; otherwise each signal gets its own draw.
        """
        noisy = self.bruit.get()
        # NOTE(review): the noise amplitude is scaled by 1/rsb_lin, while
        # rsb_lin is derived from dB with a /10 (power) convention and the
        # signal power is not taken into account -- confirm this is the
        # intended definition of the RSB.
        if noisy:
            bruit = np.random.randn(len(self.sig1))/self.rsb_lin
        else:
            bruit = 0
        self.sig1 = self.dic_signals[self.x.get()] + bruit
        if noisy and not self.memebruit.get():
            bruit = np.random.randn(len(self.sig1))/self.rsb_lin
        self.sig2 = self.dic_signals[self.y.get()] + bruit
        self.affiche()

    def ajoute_bruit(self):
        """Checkbox callback: show or hide the noise-related widgets
        according to the checkboxes, then regenerate the signals."""
        if self.bruit.get():
            self.meme_bruit.place(rely=0.45, x=5)
            self.labelrsb.place(rely=0.5, x=10)
            self.choix_rsb.place(rely=0.5, x=5, y=20)  # 20: label height
            # "Point central" is only offered with identical noise
            if self.memebruit.get():
                self.point_zero.place(rely=0.75, x=5)
            else:
                self.point_zero.place_forget()
        else:
            for widget in (self.meme_bruit, self.labelrsb,
                           self.choix_rsb, self.point_zero):
                widget.place_forget()
        self.bruite()

    def change_rsb(self, data):
        """Menu callback: update the linear RSB from a label like "3 dB"
        and regenerate the noise."""
        self.rsb_lin = self._db_to_linear(data)
        self.bruite()

    # display functions #
    def affiche(self):
        """Redraw the four plots and refresh the canvas."""
        self.affiche1()
        self.affiche2()
        self.affiche3()
        self.affiche4()
        self.fig.tight_layout()
        self.canvas.draw()

    def affiche1(self):
        """Draw x(t)."""
        self._draw_curve(self.a1, self.t, self.sig1, 'x(t)')

    def affiche2(self):
        """Draw y(t-τ): the samples of y plotted on a time axis shifted
        by τ."""
        self._draw_curve(self.a2, self.t+self.tau, self.sig2,
                         r'$y(t-\tau)$')

    def affiche3(self):
        """Draw the product x(t)·y(t-τ) on ]-0.5, 0.5[ with its area
        shaded."""
        ntau = np.round(self.tau*Fe)
        s1 = self.sig1[np.abs(n) < (Fe/2)]
        # working on sample indices is equivalent to abs(t+tau) < 0.5,
        # without rounding errors
        s2 = self.sig2[np.abs(n + ntau) < (Fe/2)]
        axet = self.t[abs(self.t) < 0.5]
        produit = s1*s2
        self._draw_curve(self.a3, axet, produit,
                         r'$x(t) \times y(t-\tau)$')
        self.a3.fill_between(axet, 0, produit, color="green", alpha=0.2)

    def affiche4(self):
        """Draw the cross-correlation of x and y and a red vertical marker
        at the current τ.

        When "Même bruit" is set and "Point central" is unticked, the
        zero-lag sample is left out of the curve.
        """
        self.a4.clear()
        # dividing by Fe normalises the discrete sum for the continuous
        # case; it also yields a new array, so the cache stays untouched
        r = self._raw_correlation() / Fe
        lags = np.arange(-len(self.sig1)+1, len(self.sig1))
        if self.memebruit.get() and not self.pointzero.get():
            r[lags == 0] = np.nan  # NaN samples are not drawn
        lo, hi = np.nanmin(r), np.nanmax(r)
        self.a4.plot(lags/Fe, r, linewidth=2)
        # the marker overshoots the axis limits so that it spans the
        # whole height of the plot
        self.a4.plot((self.tau, self.tau), (1.5*lo, 1.5*hi),
                     color='r', linewidth=2)
        self.a4.set_xlim((-0.5, 0.5))
        self.a4.set_ylim(self._padded_limits(r))
        self.a4.set_title(r'$\gamma_{xy}(\tau) = \int x(t)y(t-\tau)dt$')
        self.a4.grid(True)
        self.a4.set_ylabel('amplitude')
        self.a4.set_xlabel(r'retard $\tau$')

    def key_bindings(self, event):
        """Handle the keyboard shortcuts received from the canvas."""
        key = event.key

        if key in ("q", "Q"):  # quit
            self.root.quit()
            self.root.destroy()
            return

        # key -> (menu variable, its callback, step in list_signals)
        signal_keys = {
            "x": (self.x, self.change_x, 1),    # next x(t)
            "X": (self.x, self.change_x, -1),   # previous x(t)
            "y": (self.y, self.change_y, 1),    # next y(t)
            "Y": (self.y, self.change_y, -1),   # previous y(t)
        }
        # key -> number of res_tau steps added to τ
        tau_keys = {
            "left": -1,
            "right": 1,
            "ctrl+left": -10,
            "ctrl+right": 10,
        }

        if key in signal_keys:
            self._cycle_signal(*signal_keys[key])
        elif key in tau_keys:
            # moving the slider updates self.tau through change_tau
            self.scale_tau.set(self.tau + tau_keys[key]*res_tau)


# Build the application window, then hand control to the Tk event loop,
# which runs until the window is closed or the user quits.
corr = Correlation(t, dic_signals, list_signals)
corr.root.mainloop()
