I have a function Black_Cox() which calls other functions as shown below:
import numpy as np
from scipy import stats
# Parameters
D = 100
r = 0.05
γ = 0.1
# Normal CDF
N = lambda x: stats.norm.cdf(x)
H = lambda V, T, L, σ: np.exp(-r*T) * N( (np.log(V/L) + (r-0.5*σ**2)*T) / (σ*np.sqrt(T)) )
# Black-Scholes
def C_BS(V, K, T, σ):
    d1 = (np.log(V/K) + (r + 0.5*σ**2)*T ) / ( σ*np.sqrt(T) )
    d2 = d1 - σ*np.sqrt(T)
    return V*N(d1) - np.exp(-r*T)*K*N(d2)
def BL(V, T, D, L, σ):
    return L * H(V, T, L, σ) - L * (L/V)**(2*r/σ**2-1) * H(L**2/V, T, L, σ) \
        + C_BS(V, L, T, σ) - (L/V)**(2*r/σ**2-1) * C_BS(L**2/V, L, T, σ) \
        - C_BS(V, D, T, σ) + (L/V)**(2*r/σ**2-1) * C_BS(L**2/V, D, T, σ)
def Bb(V, T, C, γ, σ, a):
    b = (np.log(C/V) - γ*T) / σ
    μ = (r - a - 0.5*σ**2 - γ) / σ
    m = np.sqrt(μ**2 + 2*r)
    return C*np.exp(b*(μ-m)) * ( N((b-m*T)/np.sqrt(T)) + np.exp(2*m*b)*N((b+m*T)/np.sqrt(T)) )
def Black_Cox(V, T, C=160, σ=0.1, a=0):
    return np.exp(γ*T)*BL(V*np.exp(-γ*T), T, D*np.exp(-γ*T), C*np.exp(-γ*T), σ) + Bb(V, T, C, γ, σ, a)
I need to work with the derivative of the Black_Cox function w.r.t. V. More precisely, I need to evaluate this derivative across thousands of paths where I change other arguments, find the derivative and evaluate at some V.
What is the best way to proceed?
Should I use sympy to find this derivative and then evaluate at my V of choice, as I would do in Mathematica: D[BlackCox[V, 10, 100, 160], V] /. V -> 180, or
Should I just use jax?
If sympy, how would you advise me to do this?
With jax I understand that I need to do the following imports:
import jax.numpy as np
from jax.scipy import stats
from jax import grad
and re-evaluate my functions before getting the gradient:
func = lambda x: Black_Cox(x,10,160,0.1)
grad(func)(180.0)
If I still need to work with the numpy versions of the functions, will I have to create two copies of each function, or is there an elegant way to duplicate a function for jax purposes?
JAX does not provide any built-in way to recompile a NumPy function against the JAX versions of numpy and scipy. But you can use a snippet like the following one to do it automatically:
import inspect
from functools import wraps
import numpy as np
import jax.numpy
def replace_globals(func, globals_):
    """Recompile a function with replaced global values."""
    namespace = func.__globals__.copy()
    namespace.update(globals_)
    source = inspect.getsource(func)
    exec(source, namespace)
    return wraps(func)(namespace[func.__name__])
It works like this:
def numpy_func(N):
    return np.arange(N) ** 2
jax_func = replace_globals(numpy_func, {"np": jax.numpy})
Now you can evaluate the numpy version:
numpy_func(10)
# array([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81])
and the jax version:
jax_func(10)
# DeviceArray([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)
# (recent JAX releases print this as Array([...], dtype=int32))
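For your more complicated function, make certain you replace every relevant global variable (in your case both np and stats). You also need to recompile each helper that Black_Cox calls: a recompiled function still looks up names such as BL or C_BS in its copied namespace, and those originals keep using NumPy. The lambdas N and H would have to be rewritten as def functions first, because the snippet retrieves the result by func.__name__, which is '<lambda>' for a lambda.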
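When a function calls several helpers, every one of them has to end up in the same patched namespace. As an alternative to re-running the source, here is a minimal sketch that rebinds a group of functions at once with types.FunctionType. This is a different technique from replace_globals above, and rebind_globals, helper and outer are hypothetical names used only for illustration. It assumes all the functions are defined at module level and share one global namespace.

```python
import types

import numpy as np
import jax
import jax.numpy as jnp


def rebind_globals(funcs, replacements):
    """Copy each function in `funcs` (a dict of name -> function) onto one
    shared, patched global namespace, so the copies call each other
    rather than the originals."""
    first = next(iter(funcs.values()))
    namespace = dict(first.__globals__)
    namespace.update(replacements)
    for name, f in funcs.items():
        # Same bytecode, different globals; defaults and closure are preserved.
        g = types.FunctionType(f.__code__, namespace, f.__name__,
                               f.__defaults__, f.__closure__)
        g.__kwdefaults__ = f.__kwdefaults__
        namespace[name] = g
    return {name: namespace[name] for name in funcs}


# Hypothetical stand-ins: a helper and a lambda that calls it.
def helper(x):
    return np.sqrt(x)

outer = lambda x: x**2 + helper(x)

jax_funcs = rebind_globals({"helper": helper, "outer": outer}, {"np": jnp})
outer_jax = jax_funcs["outer"]

# Derivative at a single point.
single = jax.grad(outer_jax)(4.0)

# Derivative evaluated along many paths at once.
xs = jnp.array([1.0, 4.0, 9.0])
batched = jax.vmap(jax.grad(outer_jax))(xs)
```

The originals are untouched, so outer(4.0) still goes through NumPy. For the functions in the question, the replacements would presumably be {"np": jax.numpy, "stats": jax.scipy.stats}, with every function from N through Black_Cox passed in funcs; I have not run this on the full Black_Cox code.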