#!/usr/bin/env python
# coding: utf-8

# In[17]:


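#example problem: solve a boundary value problem (BVP) with the finite difference method
#ODE: y'' + y' = x on a <= x <= b
#boundary conditions: y(a) = ya (value at the left end), y'(b) = dyb (slope at the right end)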


# In[18]:


#BVP data: interval [a, b], left value y(a) = ya, right slope y'(b) = dyb
a, b = 0, 15
ya, dyb = 1, 2


# In[51]:


n = 5 #number of intervals; the grid has n+1 points x_0..x_n
h = (b-a) / n #spacing between points
from numpy import zeros
A = zeros((n+1, n+1))
v = zeros(n+1)

#row 0: Dirichlet condition y(a) = ya
A[0, 0] = 1
v[0] = ya
#rows 1..n-1: central differences for y'' and y' at the interior points
for i in range(1, n):
    A[i, i-1] = (1 / h**2) - (1/ (2*h))
    A[i, i] = (-2 / h**2)
    A[i, i+1] = (1 / h**2) + (1/ (2*h))
    v[i] = a + i*h #right-hand side x evaluated at x_i = a + i*h
#row n: second-order backward difference (3y_n - 4y_{n-1} + y_{n-2}) / (2h) = y'(b)
A[n, n-2] = 1 / (2*h)
A[n, n-1] = -2 / (h)
A[n, n] = 3 / (2*h)
v[n] = dyb

from numpy.linalg import solve
y1 = solve(A, v)

from numpy import linspace
x1 = linspace(a, b, n+1)


# In[54]:


#inspect the n=5 system matrix and right-hand side
(A, v)


# In[33]:


n = 10 #number of intervals; the grid has n+1 points x_0..x_n
h = (b-a) / n #spacing between points
from numpy import zeros
A = zeros((n+1, n+1))
v = zeros(n+1)
#row 0: Dirichlet condition y(a) = ya
A[0, 0] = 1
v[0] = ya
#rows 1..n-1: central differences for y'' and y' at the interior points
for i in range(1, n):
    A[i, i-1] = (1 / h**2) - (1/ (2*h))
    A[i, i] = (-2 / h**2)
    A[i, i+1] = (1 / h**2) + (1/ (2*h))
    v[i] = a + i*h #right-hand side x evaluated at x_i = a + i*h
#row n: second-order backward difference (3y_n - 4y_{n-1} + y_{n-2}) / (2h) = y'(b)
A[n, n-2] = 1 / (2*h)
A[n, n-1] = -2 / (h)
A[n, n] = 3 / (2*h)
v[n] = dyb
y2 = solve(A, v)
x2 = linspace(a, b, n+1)


# In[34]:


n = 100 #number of intervals; the grid has n+1 points x_0..x_n
h = (b-a) / n #spacing between points
from numpy import zeros
A = zeros((n+1, n+1))
v = zeros(n+1)
#row 0: Dirichlet condition y(a) = ya
A[0, 0] = 1
v[0] = ya
#rows 1..n-1: central differences for y'' and y' at the interior points
for i in range(1, n):
    A[i, i-1] = (1 / h**2) - (1/ (2*h))
    A[i, i] = (-2 / h**2)
    A[i, i+1] = (1 / h**2) + (1/ (2*h))
    v[i] = a + i*h #right-hand side x evaluated at x_i = a + i*h
#row n: second-order backward difference (3y_n - 4y_{n-1} + y_{n-2}) / (2h) = y'(b)
A[n, n-2] = 1 / (2*h)
A[n, n-1] = -2 / (h)
A[n, n] = 3 / (2*h)
v[n] = dyb
y3 = solve(A, v)
x3 = linspace(a, b, n+1)


# In[39]:


n = 1000 #number of intervals; the grid has n+1 points x_0..x_n
h = (b-a) / n #spacing between points
from numpy import zeros
A = zeros((n+1, n+1))
v = zeros(n+1)
#row 0: Dirichlet condition y(a) = ya
A[0, 0] = 1
v[0] = ya
#rows 1..n-1: central differences for y'' and y' at the interior points
for i in range(1, n):
    A[i, i-1] = (1 / h**2) - (1/ (2*h))
    A[i, i] = (-2 / h**2)
    A[i, i+1] = (1 / h**2) + (1/ (2*h))
    v[i] = a + i*h #right-hand side x evaluated at x_i = a + i*h
#row n: second-order backward difference (3y_n - 4y_{n-1} + y_{n-2}) / (2h) = y'(b)
A[n, n-2] = 1 / (2*h)
A[n, n-1] = -2 / (h)
A[n, n] = 3 / (2*h)
v[n] = dyb
y4 = solve(A, v)
x4 = linspace(a, b, n+1)


# In[40]:


from matplotlib.pyplot import plot, legend, xlabel, ylabel, title
#finite difference solutions of y'' + y' = x at increasing resolution
plot(x1, y1, 's', x2, y2, 'o', x3, y3, '-', x4, y4, '-')
xlabel('x')
ylabel('y')
title("Finite difference solutions of y'' + y' = x")
#trailing ';' suppresses the Legend object repr in the notebook output
legend(['n=5', 'n=10', 'n=100', 'n=1000']);


# In[44]:


#check if ODE is reasonably approximated:
#evaluate y'' + y' on the n=1000 solution with central differences and
#compare it to the right-hand side x at the interior points
#grid spacing derived from x4 so the check does not depend on whichever cell last set the global h
h4 = x4[1] - x4[0]
ddy = (y4[2:] - 2*y4[1:-1] + y4[:-2]) / (h4**2)
dy = (y4[2:] - y4[:-2]) / (2*h4)
plot(x4[1:-1], ddy+dy, x4[1:-1], x4[1:-1])
legend(["y'' + y' (finite differences)", 'x']);
#yes -- y'' + y' is close to x
#(these are the same stencils used to build A, so agreement is expected up to round-off)


# In[50]:


#check if the boundary values are approximated
#y'(b) is estimated with a first-order backward difference on the n=1000 grid;
#the spacing is taken from x4 rather than the global h, which is hidden state
y4[0], (y4[-1] - y4[-2]) / (x4[1] - x4[0])
#should be ya and approximately dyb respectively


# In[45]:


#solution with SciPy function
#rewrite y'' + y' = x as a first-order system with z = [y, y']
from scipy.integrate import solve_bvp
f = lambda x, z: [z[1], x - z[1]]
#boundary residuals: y(a) - ya = 0 and y'(b) - dyb = 0
bc = lambda za, zb: [za[0] - ya, zb[1] - dyb]
x = [a, b] #initial mesh: just the two endpoints
z = zeros((2, 2))
z[:, 0] = [ya, dyb] #guess for z[0] and z[1] at a
z[:, 1] = [ya+(b-a)*dyb, dyb] #guess for z[0] and z[1] at b
s = solve_bvp(f, bc, x, z)
#solve_bvp reports failure through s.success instead of raising; stop before using a bad result
if not s.success:
    raise RuntimeError("solve_bvp did not converge: " + s.message)


# In[49]:


#evaluate solve_bvp's continuous interpolant s.sol on the finite difference grid
#(s.x holds only the adaptive mesh nodes, which plot as a coarse polyline)
plot(x4, s.sol(x4)[0], x4, y4, '--')
legend(['solve_bvp', 'finite difference (n=1000)']);


# In[55]:


#boundary check for solve_bvp: y at the first mesh node and y' at the last
s.y[0][0], s.y[1][-1]
#should be approximately ya and dyb respectively


# In[38]:


#example adapted from the solve_bvp documentation: y'' = -y on [0, 1] with y(1) = 0 and y'(0) = 1
import numpy as np
def fun(x, y):
    """Right-hand side of the first-order system for y'' = -y.

    Row 0 of y is the solution and row 1 its derivative; the result stacks
    their derivatives [y', -y] row-wise. x is accepted but unused.
    """
    d_value = y[1]
    d_slope = -y[0]
    return np.vstack([d_value, d_slope])
def bc(ya, yb):
    """Boundary residuals: y = 0 at the right end and y' = 1 at the left end."""
    right_residual = yb[0]
    left_residual = ya[1] - 1
    return np.array([right_residual, left_residual])
x = np.linspace(0, 1, 5)
#initial guess: y is +1 at x=0.25 and -1 at x=0.75, zero elsewhere; y' guessed as zero
y = np.zeros((2, x.size))
y[0, [1, 3]] = [1, -1]
sol = solve_bvp(fun, bc, x, y)
#evaluate the continuous solution on a fine grid for plotting
x_plot = np.linspace(0, 1, 100)
y_plot = sol.sol(x_plot)[0]
import matplotlib.pyplot as plt
plt.plot(x_plot, y_plot)
plt.gca().set_xlabel("x")
plt.gca().set_ylabel("y")
plt.show()


# In[51]:


#problem from textbook -- beam deflection
EI = 1.8E7
w0 = 15E3
L = 5
f = lambda x, z: [z[1], (w0/(2*EI))*(L*x - x*x)*((1+z[1]**2)**(3/2))]
bc = lambda za, zb: [za[0], zb[0]]
x = np.linspace(0, L, 5)
z = np.zeros((2, x.size))
s = solve_bvp(f, bc, x, z)


# In[55]:


x_plot = np.linspace(0, L, 100)
y_plot = s.sol(x_plot)[0] #approximates the solution at the generated points
#solve_bvp mesh nodes as squares, then the dense interpolant as a line
plt.plot(s.x, s.y[0], 's')
plt.plot(x_plot, y_plot)
plt.xlabel("x")
plt.ylabel("y")
plt.show()


# In[56]:


#displacement and slope at the beam midpoint, from the continuous solution s.sol
s.sol(0.5 * L)


# In[ ]:




