#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
COPYRIGHT NOTICE: THIS SCRIPT IS FOR THE SOLE USE OF STUDENTS 
ATTENDING THE MATH3474 COURSE. ANY PERSON WISHING TO COPY, USE OR 
PROCESS THIS MATERIAL IN ANY FORM FOR ANY OTHER PURPOSES SHOULD 
CONTACT THE AUTHOR VIA EMAIL.

THIS SCRIPT WAS ORIGINALLY PRODUCED IN MAPLE BY PROF MARK KELMANSON, AND
ADAPTED INTO PYTHON BY AUTHOR.

Joseph Elmes : NERC-Funded PhD Researcher in Applied Mathematics
University of Leeds : Leeds LS2 9JT : ml14je@leeds.ac.uk

Python 3.7 : Sat Aug 10 11:47:33 2019

MATH3474 : Section 1 : Lectures 7 to 11
SCRIPT MATH3474_3.py : CHEBYSHEV APPROXIMATION
"""

import numpy as np
import matplotlib.pyplot as pt
from MATH3474 import plot_setup, Chebyshev_grid, T_n, x_to_t, t_to_x
from math import pi

def y_coord(r, m, x_coord):
    """
    Height of the upper half of a circle whose centre lies on the x-axis.

    The circle has radius r and centre (m, 0). For a given abscissa the
    function returns the non-negative ordinate of the matching point on the
    circle, i.e. sqrt(r^2 - (x_coord - m)^2).

    Parameters
    ----------
    r : Float
        Radius of the circle.
    m : Float
        x-position of the centre of the circle.
    x_coord : Float or numpy array
        Abscissa(e), expected to lie in [m-r, m+r].

    Returns
    -------
    Float or numpy array
        The positive y co-ordinate(s) on the circle.
    """
    offset = x_coord - m  # signed horizontal distance from the centre
    radicand = r**2 - offset**2
    return np.sqrt(radicand)

def Cheb_interp_poly(f, n, sup_output=False):
    """
    Builds the Chebyshev interpolating polynomial of f from its first n+1
    Chebyshev modes (k = 0, ..., n).

    For each k the integrand f(cos(z))*cos(k*z) is handed to
    MATH3474.trapezium_rule as trapezium_rule(g, -pi, pi, 128); the result is
    divided by pi, and the k = 0 value is additionally halved. Unless
    sup_output is True, every coefficient (rounded to 5 decimal places) and
    the assembled expansion are printed.

    Parameters
    ----------
    f : Function
        Function to approximate. It is only ever evaluated at np.cos(z), i.e.
        at arguments in [-1, 1].
    n : Integer
        Index of the highest Chebyshev polynomial used in the expansion.
    sup_output : Boolean
        If True, the function will suppress printing any output.

    Returns
    -------
    numpy.polynomial.chebyshev.Chebyshev
        Chebyshev series whose coefficients are the computed c_0, ..., c_n.
    """
    from MATH3474 import trapezium_rule

    verbose = not sup_output
    Tn_coeff = np.zeros(n+1)
    if verbose:
        print('\nChebyshev interpolation:')
        terms = []  # printable pieces 'c_k T_k(x)+' of the expansion

    for k in range(n+1):
        def integrand(z):
            return f(np.cos(z))*np.cos(k*z)

        c = trapezium_rule(integrand, -pi, pi, 128)/pi
        if k == 0:
            c /= 2  # the zeroth mode carries a factor of one half

        if verbose:
            rounded = round(c, 5)
            print('\tc_{} = {}'.format(k, rounded))
        Tn_coeff[k] = c
        if verbose:
            terms.append('{} T_{}(x)+'.format(rounded, k))

    if verbose:
        series = ''.join(terms)
        # drop the trailing '+' left by the last term
        print('\n\tC_{}(x) = {}\n\n'.format(n, series[:-1]))

    return np.polynomial.chebyshev.Chebyshev(Tn_coeff)

def forced_osc_approx(f, n, sup_output=False):
    """
    Returns the forced-oscillation approximation of function f, expressed in
    the Chebyshev modes T_0, ..., T_n. Unless sup_output is True, the
    coefficients (rounded to 5 decimal places) and the assembled expansion
    are printed.

    The function f is sampled once at each of the n+2 nodes
    x_i = cos(i*pi/(n+1)), i = 0, ..., n+1. Coefficient j is
    2/(n+1) times the sum over i of f(x_i)*cos(j*i*pi/(n+1)), where the first
    and last terms of the sum are halved; the j = 0 coefficient is then
    halved once more.

    Parameters
    ----------
    f : Function
        Function of which we are constructing the forced-oscillation
        approximation. It is called with scalar arguments in [-1, 1].
    n : Integer
        Degree of the approximating polynomial.
    sup_output : Boolean
        If True, the function will suppress printing any output.

    Returns
    -------
    numpy.polynomial.chebyshev.Chebyshev
        Chebyshev series with the computed coefficients c_0, ..., c_n.
    """
    Tn_coeff = np.zeros(n+1)
    x_k = np.cos(np.arange(n+2)*pi/(n+1))
    # Sample f once per node; the samples are shared by every coefficient, so
    # f is called n+2 times in total rather than once per (mode, node) pair.
    f_k = [f(x_i) for x_i in x_k]

    def coefficient(j):
        S = 0
        for i, f_i in enumerate(f_k):
            c_i = f_i*np.cos(j*i*pi/(n+1))
            if i == 0 or i == n+1:
                c_i = c_i/2  # end nodes carry half weight
            S += c_i
        return 2*S/(n+1)

    if not sup_output:
        print('\nForced-oscillation approximation:')
    terms = []
    for k in range(n+1):
        c = coefficient(k)
        if k == 0:
            c /= 2  # the zeroth mode carries a factor of one half
        if not sup_output:
            print('\tc_{} = {}'.format(k, round(c, 5)))
        Tn_coeff[k] = c
        if not sup_output:
            terms.append('{} T_{}(x)'.format(round(c, 5), k))

    if not sup_output:
        print('\n\tF_{}(x) = {}\n\n'.format(n, '+'.join(terms)))

    return np.polynomial.chebyshev.Chebyshev(Tn_coeff)

def forced_osc_error(f, n):
    """
    Returns the error of the n-degree forced-oscillation polynomial of
    function f.

    The value is the alternating-sign sum of f over the n+2 nodes
    x_k = cos(k*pi/(n+1)), k = 0, ..., n+1, with the first and last terms
    weighted by one half, divided by n+1.

    Parameters
    ----------
    f : Function
        The function whose forced-oscillation polynomial error is evaluated.
    n : Integer
        The degree of the forced-oscillation polynomial.

    Returns
    -------
    Float
        The signed error value described above.
    """
    last = n + 1
    x_k = np.cos(np.array(range(0, n+2))*pi/(n+1))

    total = 0
    for k, node in enumerate(x_k):
        # interior nodes have unit weight, the two end nodes half weight
        weight = 1 if 0 < k < last else 1/2
        total += weight*((-1)**k)*f(node)

    return total/last

def Example_11a(a, b, n):
    """
    Example 1.11a:
    Plots Chebyshev polynomials of the first kind on [a, b].

    A 500-point grid on [a, b] is mapped with x_to_t and passed to
    T_n(n, t_arr); entries 0, ..., n-1 of the result are each drawn in black
    against x. The plot scale is read from the module-level global s.

    Parameters
    ----------
    a : Float
        The lower limit of the domain over which we plot the Chebyshev
        polynomials.
    b : Float
        The upper limit of the domain over which we plot the Chebyshev
        polynomials.
    n : Integer
        Passed to T_n as the degree; also the number of curves drawn.
    """
    fig, ax = plot_setup('$x$', scale=s)
    x_arr = np.linspace(a, b, 500)
    # NOTE(review): assumes T_n returns an indexable collection with one
    # entry per polynomial, evaluated on the grid -- confirm in MATH3474.
    T = T_n(n, x_to_t(x_arr, a, b))

    for degree in range(n):
        curve = T[degree]
        ax.plot(x_arr, curve, 'k')

    pt.show()
    
def Example_11(a, b, n):
    """
    Example 1.11:
    Shows the construction of the Chebyshev grid, and how one physically
    observes both the roots and extreme points of the Chebyshev polynomial,
    T_n(x), for x in [a, b].

    The figure contains: the upper semicircle of radius (b-a)/2 centred at
    ((a+b)/2, 0); for grid points X[1], ..., X[n-1], a green line from the
    centre up to the circle and straight down to the x-axis; the curve
    T_n; green open circles at the grid points returned by Chebyshev_grid;
    and red open circles at the roots cos((2k-1)*pi/(2n)), k = 1, ..., n,
    mapped to [a, b] by t_to_x. The plot scale is read from the module-level
    global s.

    Parameters
    ----------
    a : Float
        The lower limit of the domain over which we construct the Chebyshev
        grid.
    b : Float
        The upper limit of the domain over which we construct the Chebyshev
        grid.
    n : Integer
        Degree of T_n; also passed to Chebyshev_grid to size the grid.
    """
    from MATH3474 import T_n_func

    fig, ax = plot_setup('$x$', scale=s)

    x_arr = np.linspace(a, b, int(5E3))
    t_arr = x_to_t(x_arr, a, b)
    m = (a+b)/2  # centre of the semicircle
    r = (b-a)/2  # radius of the semicircle
    X = Chebyshev_grid(a, b, n)

    T = T_n_func(n)  # T_n(x)
    k = np.array(range(1, n+1))
    angles = ((2*k-1)/(2*n))*np.pi
    roots = t_to_x(np.cos(angles), a, b)

    ax.plot(x_arr, y_coord(r, m, x_arr), color='k')
    for i in range(1, n):
        node = X[i]
        # centre -> point on the circle above the node -> the node itself
        ax.plot([m, node, node], [0, y_coord(r, m, node), 0], color='g')
    ax.plot(x_arr, T(t_arr), label='$T_{{{}}}(x)$'.format(n))

    ax.plot(X, np.zeros(len(X)), 'og', markersize=s*10,
            fillstyle='none', label='Chebyshev grid points')
    ax.plot(roots, np.zeros(len(roots)), 'or', markersize=s*10,
            fillstyle='none', label='Roots to $T_{{{}}}(x)=0$'.format(n))

    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2),
              shadow=True, ncol=3, fontsize=16*s)
    pt.show()
    
def Example_14(f_symb, a, b, n):
    """
    Example 1.14:
    Compares the errors of f(x) using the Chebyshev interpolation C_n(x),
    and the minimax expansion q_n(x).

    For each approximation the pointwise error on a 500-point grid is
    plotted, together with horizontal lines at plus/minus the maximum
    absolute error (solid) and plus/minus the value returned by
    root_sqr_mean_err (dotted); that value is also printed, rounded to 5
    decimal places. The plot scale is read from the module-level global s.

    Parameters
    ----------
    f_symb : sympy.core.function
        Symbolic function which we approximate and compare with the
        Chebyshev interpolating function and the minimax interpolation
        polynomial.
    a : Float
        The lower limit of the domain over which we consider f(x).
    b : Float
        The upper limit of the domain over which we consider f(x).
    n : Integer
        The degree of the Chebyshev and minimax interpolating polynomial
        function.
    """
    from sympy import lambdify, latex
    from sympy.abc import x
    from MATH3474_2 import minimax
    from MATH3474 import p_n, root_sqr_mean_err

    x_arr = np.linspace(a, b, int(5E2))
    f = lambdify(x, f_symb, 'numpy')
    t_arr = x_to_t(x_arr, a, b)

    def g(y):  # f expressed in the mapped variable t
        return f(t_to_x(y, a, b))

    C_n = Cheb_interp_poly(g, n, True)
    MM = minimax(f_symb, a, b, n, True)

    def q(y):  # minimax polynomial in the mapped variable t
        # NOTE(review): MM[n:-1] is taken to be the coefficient slice of the
        # minimax output -- confirm against MATH3474_2.minimax.
        return p_n(MM[n:-1], t_to_x(y, a, b))

    F = f(x_arr)
    C = C_n(t_arr)
    Q = q(t_arr)

    fig, ax = plot_setup('$x$', 'Error', scale=s)
    curves = ((C_n, C, 'C', 'r'), (q, Q, 'q^{{\\ast}}', 'g'))
    for func, vals, lab, c in curves:
        residual = F - vals
        ax.plot(x_arr, residual, color=c,
                label='${}-{}_{{{}}}(x)$'.format(latex(f_symb), lab, n))
        E = np.max(np.abs(residual))*np.ones(2)
        E2 = root_sqr_mean_err(g, func, a, b)
        print('Root-mean-square error on {}_n: {}'.format(lab,
              round(E2, 5)))
        for sign in (1, -1):
            ax.plot([a, b], sign*E*np.ones(2), color=c)
            ax.plot([a, b], sign*E2*np.ones(2), color=c, linestyle=':')

    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2),
              shadow=True, ncol=2, fontsize=16*s)
    pt.show()

def Example_14b(f_symb, a, b, n):
    """
    Example 1.14b:
    Compares f(x) and the Chebyshev interpolating polynomial C_n(x).

    Both curves are drawn on a 5000-point grid on [a, b]; C_n is drawn with
    a dotted line. Cheb_interp_poly is called without suppressing output, so
    its coefficients are printed. The plot scale is read from the
    module-level global s.

    Parameters
    ----------
    f_symb : sympy.core.function
        Symbolic function which we approximate using the n-th degree
        Chebyshev interpolating function.
    a : Float
        The lower limit of the domain over which we consider f(x).
    b : Float
        The upper limit of the domain over which we consider f(x).
    n : Integer
        The degree of the Chebyshev interpolating polynomial function.
    """
    from sympy import lambdify, latex
    from sympy.abc import x

    x_arr = np.linspace(a, b, int(5E3))
    t_arr = x_to_t(x_arr, a, b)
    f = lambdify(x, f_symb, 'numpy')

    def g(y):  # f expressed in the mapped variable t
        return f(t_to_x(y, a, b))

    C_n = Cheb_interp_poly(g, n)

    fig, ax = plot_setup('$x$', scale=s)
    exact_label = '${}$'.format(latex(f_symb))
    ax.plot(x_arr, f(x_arr), label=exact_label)
    approx_label = '$C_{{{}}}(x)$'.format(n)
    ax.plot(x_arr, C_n(t_arr), label=approx_label, linestyle=':')

    ax.legend(fontsize=16*s)
    pt.show()

def Example_17(f_symb, a, b, n):
    """
    Example 1.17:
    Compares the errors of f(x) using the Chebyshev interpolation C_n(x),
    minimax expansion q_n(x), and the Chebyshev interpolating polynomial
    I_n(x).

    I_n is the Lagrange interpolant (lagrange_interp) of f through the n+1
    points cos((2k+1)*pi/(2(n+1))), k = 0, ..., n, mapped to [a, b] by
    t_to_x. For each approximation the pointwise error on a 500-point grid
    is plotted, with horizontal lines at plus/minus the maximum absolute
    error (solid) and plus/minus the value returned by root_sqr_mean_err
    (dotted); that value is also printed. The plot scale is read from the
    module-level global s.

    Parameters
    ----------
    f_symb : sympy.core.function
        Symbolic function which we approximate.
    a : Float
        The lower limit of the domain over which we consider f(x).
    b : Float
        The upper limit of the domain over which we consider f(x).
    n : Integer
        The degree of the interpolating polynomial functions.
    """
    from sympy import lambdify, latex
    from sympy.abc import x
    from MATH3474_1 import lagrange_interp
    from MATH3474_2 import minimax
    from MATH3474 import p_n, root_sqr_mean_err

    x_arr = np.linspace(a, b, int(5E2))
    f = lambdify(x, f_symb, 'numpy')
    k = np.array(range(0, n+1))
    t_arr = x_to_t(x_arr, a, b)
    X = t_to_x(np.cos(((2*k+1)/(2*(n+1)))*np.pi), a, b)
    MM = minimax(f_symb, a, b, n)

    def g(y):  # f(t)
        return f(t_to_x(y, a, b))

    def q(y):  # q_n(t)
        # NOTE(review): MM[n:-1] is taken to be the coefficient slice of the
        # minimax output -- confirm against MATH3474_2.minimax.
        return p_n(MM[n:-1], t_to_x(y, a, b))

    def interp(y):  # I_n(t)
        return lagrange_interp(t_to_x(y, a, b), X, f(X))

    C_n = Cheb_interp_poly(g, n)  # C_n(t)
    F = f(x_arr)
    C = C_n(t_arr)
    Q = q(t_arr)
    I = interp(t_arr)

    fig, ax = plot_setup('$x$', 'Error', scale=s)
    curves = ((C_n, C, 'C', 'r'),
              (q, Q, 'q^{{\\ast}}', 'g'),
              (interp, I, 'I', 'b'))
    for func, vals, lab, c in curves:
        residual = F - vals
        ax.plot(x_arr, residual, color=c,
                label='${}-{}_{{{}}}(x)$'.format(latex(f_symb), lab, n))
        E = np.max(np.abs(residual))
        E2 = root_sqr_mean_err(g, func, a, b)
        print('Root-mean-square error on {}_n: {}'.format(lab,
              round(E2, 5)))
        for sign in (1, -1):
            ax.plot([a, b], sign*E*np.ones(2), color=c)
            ax.plot([a, b], sign*E2*np.ones(2), color=c, linestyle=':')

    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2),
              shadow=True, ncol=3, fontsize=s*16)
    pt.show()

def Example_18(f_symb, a, b, n):
    """
    Example 1.18:
    Compares the errors of f(x) for x in [a, b] using the Chebyshev
    interpolation C_n(x), and the forced-oscillation interpolating polynomial,
    F_n(x).

    Prints phi (from forced_osc_error), rho (the last entry of the minimax
    output) and p_inf (the maximum of |F_n - f| on a 500-point grid), each
    rounded to 5 decimal places. The figure shows both pointwise errors,
    solid horizontal lines at plus/minus max|f - C_n| (red), plus/minus phi
    (blue) and plus/minus rho (green), and dotted lines at plus/minus the
    value returned by root_sqr_mean_err for each approximation. The plot
    scale is read from the module-level global s.

    Parameters
    ----------
    f_symb : sympy.core.function
        Symbolic function which we approximate.
    a : Float
        The lower limit of the domain over which we consider f(x).
    b : Float
        The upper limit of the domain over which we consider f(x).
    n : Integer
        The degree of the interpolating polynomial functions.
    """
    from sympy import lambdify, latex
    from sympy.abc import x
    from MATH3474 import root_sqr_mean_err
    from MATH3474_2 import minimax

    x_arr = np.linspace(a, b, int(5E2))
    f = lambdify(x, f_symb, 'numpy')
    t_arr = x_to_t(x_arr, a, b)

    def g(y):  # f expressed in the mapped variable t
        return f(t_to_x(y, a, b))

    C_n = Cheb_interp_poly(g, n)
    F_n = forced_osc_approx(g, n)
    G = g(t_arr)
    C = C_n(t_arr)
    F = F_n(t_arr)
    phi = forced_osc_error(g, n)
    rho = minimax(f_symb, a, b, n)[-1]
    p_inf = np.max(np.abs(F-G))

    summary = 'n={}:\n\t\\phi={}\n\t\\rho={}\n\tp_inf={}\n\n'.format(
            n, round(phi, 5), round(rho, 5), round(p_inf, 5))
    print(summary)

    fig, ax = plot_setup('$x$', 'Error', scale=s)
    cheb_bound = np.max(np.abs(G-C))
    curves = ((C_n, C, 'C', 'r', cheb_bound), (F_n, F, 'F', 'b', phi))
    for func, vals, lab, c, err in curves:
        ax.plot(x_arr, G-vals, color=c,
                label='${}-{}_{{{}}}(x)$'.format(latex(f_symb), lab, n))
        E2 = root_sqr_mean_err(g, func, a, b)  # root-mean-squared error
        print('Root-mean-square error on {}_n: {}'.format(lab,
              round(E2, 5)))
        for sign in (1, -1):  # error levels of C_n and F_n
            ax.plot([a, b], sign*err*np.ones(2), color=c)
            ax.plot([a, b], sign*E2*np.ones(2), color=c, linestyle=':')

    for sign in (-1, 1):  # plotting \rho_n
        ax.plot([a, b], sign*rho*np.ones(2), color='g')

    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2),
              shadow=True, ncol=2, fontsize=s*16)
    pt.show()

def Example_19(f_symb, c=6):
    """
    Example 1.19:
    Compares the numerical and symbolic integration of 2π-periodic functions
    when calculating N-point trapezium rule integration, indicating both the
    computed trapezoidal-rule errors |T_N − I| and 2π|a_N|, the latter of
    which is the modulus of the first term in the theoretical error expansion
    given in eqn (1.91) in the notes. A 2π/N^2 reference curve (labelled
    N^{-2}) is drawn alongside on log-log axes.

    The sympy names used here (x, cos, N) are imported inside the function so
    that it does not depend on names defined by the calling script. The plot
    scale is still read from the module-level global s.

    Parameters
    ----------
    f_symb : sympy.core.function
        2π-periodic symbolic function of x which we are integrating.
    c : Integer
        Sets the largest number of grid points compared: N runs over the even
        integers 2, 4, 6, ..., 2^c. The default value is 6.
    """
    from sympy import N as Numeric
    from sympy import cos, integrate, lambdify
    from sympy.abc import x
    from MATH3474 import trapezium_rule

    fig, ax = plot_setup('$N$', 'Error', y_log=True, x_log=True, scale=s)
    # Cast to a plain float so the error array stays a float numpy array.
    Exact = float(Numeric(integrate(f_symb, (x, -pi, pi))))
    f = lambdify(x, f_symb, 'numpy')

    def a(m):
        # m-th cosine coefficient: (1/pi) * integral of f(x) cos(m x)
        return Numeric(integrate(f_symb*cos(m*x), (x, -pi, pi))/pi)

    NN_arr = np.arange(2, 1+2**c, 2)  # even N: 2, 4, ..., 2^c

    # Absolute error, matching the |T_N - I| label; signed values cannot be
    # shown on the logarithmic y-axis.
    Err = np.array([abs(trapezium_rule(f, -pi, pi, NN) - Exact)
                    for NN in NN_arr], dtype=float)

    # int(NN) hands sympy a plain Python integer rather than a numpy scalar.
    TMI = 2*pi*np.array([float(abs(a(int(NN)))) for NN in NN_arr])
    TYP = 2*pi*(1/(NN_arr**2))
    labs = ['$|\\mathcal{{T}}_N-\\mathcal{{I}}|$', '$2\\pi|a_N|$',
            '$N^{{-2}}$']
    for vals, col, lab in zip([Err, TMI, TYP], ['bx', 'ro', 'g:'], labs):
        ax.plot(NN_arr, vals, col, label=lab, fillstyle='none')

    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2),
              shadow=True, ncol=3, fontsize=s*16)
    pt.show()

# Scale of plots on screen. This lives at module level, not inside the
# __main__ guard, because every Example_* function reads ``s`` as a global;
# defining it here keeps those functions usable when the module is imported.
s = .5

if __name__ == '__main__':
    # cos is required by the Example 1.19 lines below (f_expr2 = exp(cos(x))).
    from sympy import exp, cos #, sin, tan
    from sympy import N as Numeric
    from sympy.abc import x

    a, b, n = -1, 1, 1 #a, b, and n are parameter values as defined in notes
    f_expr = exp(x) # function on x in [a, b] used in the following examples

    # Uncomment the example(s) you wish to run.

#    #Example 1.11a: Chebyshev Polynomials
#    Example_11a(a, b, n)
#    
#    #Example 1.11: Chebyshev Grid Points and Roots to T_n(x)=0
#    Example_11(a, b, n)
#
#    #Example 1.14: Chebyshev Interpolating Polynomial
#    Example_14(f_expr, a, b, n)
#    Example_14b(f_expr, a, b, n)
#
#    #Example 1.17: Near-Minimax Interpolation
#    Example_17(f_expr, a, b, n)
#    
#    #Example 1.18: Forced-Oscillation Approximation
#    Example_18(f_expr, a, b, n)
#
#    #Example 1.19: Spectrally accurate trapezoidal-rule integration
#    # Note that depending on the chosen f(x) this requires a little while to
#    #execute!
#    f_expr2 = exp(cos(x))
#    #f_expr2 must be a 2pi-periodic function!!!
#    assert f_expr2.subs(x, pi)==f_expr2.subs(x, -pi)
#    Example_19(f_expr2)