#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Joseph Elmes : NERC Funded PhD Researcher in Applied Mathematics
University of Leeds : Leeds LS2 9JT : ml14je@leeds.ac.uk

Python 3.7 : Sat Aug 17 14:19:52 2019
"""

import numpy as np
import matplotlib.pyplot as pt
from functions import plot_setup, newton_raphson
from math import pi

def theta_solver(f, x_arr, dx, init_vals, theta):
    """
    Solves dy/dx = f(x, y) using the theta method and plots the solutions.

    Each step from x_old = x_arr[i-1] to x_new = x_arr[i] finds y_new from

        y_new - y_old = dx*(theta*f(x_new, y_new)
                            + (1 - theta)*f(x_old, y_old))

    so theta = 0 is explicit Euler, theta = 1 is implicit Euler and
    theta = 0.5 is the trapezoidal rule. The (generally implicit) equation
    for y_new is solved with newton_raphson, started from y_old.

    Parameters
    ----------
    f : sympy expression
        Right-hand side written in the symbols x and y from sympy.abc.
    x_arr : sequence of float
        Grid points; x_arr[0] is where the initial values apply.
    dx : float
        Step size used in the update formula (assumed to equal the spacing
        of x_arr -- it is not checked here).
    init_vals : sequence of float
        Initial values y(x_arr[0]); one solution curve is computed for each.
    theta : float
        Weight given to the implicit (new-point) evaluation of f.

    Returns
    -------
    None. The curves are drawn on the axes returned by plot_setup.
    """
    from sympy import lambdify
    from sympy.abc import x, y

    g = lambdify((x, y), f, 'numpy')

    # Row i holds the solution at x_arr[i]; column j follows init_vals[j].
    solution = np.zeros((len(x_arr), len(init_vals)))
    solution[0] = init_vals

    for i in range(1, len(x_arr)):
        x_old, x_new = x_arr[i-1], x_arr[i]
        for j in range(len(init_vals)):
            y_old = solution[i-1, j]
            # The explicit part belongs to the previous grid point and does
            # not change during the Newton iteration, so evaluate it once.
            explicit = (1 - theta)*g(x_old, y_old)
            # Residual of the theta-method update as a function of y_new.
            # The lambda is consumed immediately, so capturing the loop
            # variables here is safe.
            h = lambda y_new: (y_new - y_old) - dx*(theta*g(x_new, y_new)
                                                    + explicit)
            solution[i, j] = newton_raphson(h, y_old, max_iteration=5)

    fig, ax = plot_setup('$x$', '$y$', scale=0.75)
    # One thin red line per initial value.
    for curve in solution.T:
        ax.plot(x_arr, curve, 'r', linewidth=0.4)
    
if __name__ == '__main__':
    from sympy import sin, exp, cos, tan
    from sympy.abc import x, y

    # Right-hand side of dy/dx = f(x, y) as a sympy expression.
    f = exp(sin(x))

    # Initial values y(x0): one solution curve is computed per entry.
    y0, yn, dy = -2*pi, 2*pi, pi/4
    init_vals = np.arange(y0, yn+dy, dy)

    # Integration grid (previously tried: xn = 2*pi, dx = 0.01).
    x0, xn, dx = 0, 50, 0.05
    # Count steps over the actual interval length and round rather than
    # truncate, so a quotient like 999.9999... does not lose a grid point
    # and a non-zero x0 is handled correctly.
    n_steps = int(round((xn - x0)/dx))
    x_arr = np.round(np.linspace(x0, xn, n_steps + 1), 5)

    theta_solver(f, x_arr, dx, init_vals, 0.5)

    # Display the figure when run as a plain script; without this the
    # plot is created but never shown outside an interactive session.
    pt.show()