 

Is there any method to generate a piecewise function for tensors in PyTorch?

Tags:

pytorch

[Image: graph of the desired periodic, trapezoid-shaped piecewise function]

I want to implement a piecewise function like this for PyTorch tensors, but I don't know how to define it. I tried the rather clumsy approach below, but it doesn't seem to work in my code.

    import torch
    from math import pi

    def trapezoid(self, X):
        # zeros_like gives Y the same dtype and device as X
        Y = torch.zeros_like(X)
        Y[X % (2 * pi) < (0.5 * pi)] = (X[X % (2 * pi) < (0.5 * pi)] % (2 * pi)) * 2 / pi
        Y[(X % (2 * pi) >= (0.5 * pi)) & (X % (2 * pi) < 1.5 * pi)] = 1.0
        Y[X % (2 * pi) >= (1.5 * pi)] = (X[X % (2 * pi) >= (1.5 * pi)] % (2 * pi)) * (-2 / pi) + 4
        return Y

Could you help me design the trapezoid function so that, for a tensor X, I can get the result directly by calling trapezoid(X)?

向宇桂 asked Nov 19 '25 12:11


1 Answer

Since your function has period 2π, we can focus on [0, 2π]. Because it is piecewise linear there, it can be expressed as a small ReLU network on [0, 2π]:

trapezoid(x) = 1 - relu(x - 1.5π)/(0.5π) - relu(0.5π - x)/(0.5π)
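Evaluating the two ReLU terms on each of the three subintervals of [0, 2π) confirms that this expression reproduces the three pieces from the question's code:

```latex
\begin{aligned}
0 \le x < 0.5\pi: &\quad 1 - 0 - \frac{0.5\pi - x}{0.5\pi} = \frac{2x}{\pi},\\
0.5\pi \le x < 1.5\pi: &\quad 1 - 0 - 0 = 1,\\
1.5\pi \le x < 2\pi: &\quad 1 - \frac{x - 1.5\pi}{0.5\pi} - 0 = 4 - \frac{2x}{\pi}.
\end{aligned}
```

The pieces agree at 0.5π and 1.5π (both give 1), and the function equals 0 at both x = 0 and x → 2π, so the periodic extension has no jumps.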

Thus, we can code the whole function in PyTorch like so:

import torch
import torch.nn.functional as F
from math import pi

def trapezoid(X):
    # Left corner position, right corner position, height
    a, b, h = 0.5 * pi, 1.5 * pi, 1.0

    # Take remainder mod 2*pi for periodicity
    X = torch.remainder(X, 2 * pi)

    # Both ramps have width 0.5*pi (equal to a, and to 2*pi - b), hence the division by a
    return h - F.relu(X - b) / a - F.relu(a - X) / a
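If you would rather keep the explicit three-piece structure from the question, torch.where can select between the pieces element-wise. The sketch below is an alternative to the ReLU trick, not the answer's method, and the name trapezoid_where is hypothetical. Because it builds the result from X instead of writing into a fresh torch.zeros(X.shape), the output keeps X's dtype and device and remains differentiable.

```python
import math

import torch


def trapezoid_where(X: torch.Tensor) -> torch.Tensor:
    """Periodic trapezoid built by selecting between three linear pieces."""
    pi = math.pi
    # Reduce to one period; torch.remainder returns values in [0, 2*pi)
    R = torch.remainder(X, 2 * pi)
    rising = R * 2 / pi            # piece used on [0, 0.5*pi)
    falling = R * (-2 / pi) + 4    # piece used on [1.5*pi, 2*pi)
    flat = torch.ones_like(R)      # piece used on [0.5*pi, 1.5*pi)
    # Outer where picks the rising ramp; inner where picks falling vs. flat
    return torch.where(R < 0.5 * pi, rising,
                       torch.where(R >= 1.5 * pi, falling, flat))


X = torch.tensor([0.0, math.pi / 4, math.pi, 7 * math.pi / 4, -math.pi / 4])
print(trapezoid_where(X))
```

torch.where computes all three candidate pieces for every element and then picks one, which costs a little extra arithmetic but avoids boolean-mask indexing and in-place writes.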

Plotting trapezoid(X) as a double check produces the correct picture:

import matplotlib.pyplot as plt

X = torch.linspace(-10, 10, 1000)
Y = trapezoid(X)
plt.plot(X, Y)
plt.title('Pytorch Trapezoid Function')
plt.show()

[Image: resulting plot titled "Pytorch Trapezoid Function" over X in [-10, 10]]
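Beyond eyeballing the plot, the ReLU formulation can be compared numerically with the question's mask-based definition, and its gradient inspected. This is a sketch under a few assumptions: trapezoid_relu restates the answer's function with plain-float constants, trapezoid_masks is a hypothetical standalone rewrite of the question's logic (using torch.zeros_like so it matches X's dtype and device), and the points 1.0, 3.0 and 5.0 were picked to land on the rising ramp, the flat top and the falling ramp.

```python
import math

import torch
import torch.nn.functional as F

pi = math.pi


def trapezoid_relu(X):
    # ReLU form from the answer: h - relu(X - b)/a - relu(a - X)/a on [0, 2*pi)
    a, b, h = 0.5 * pi, 1.5 * pi, 1.0
    X = torch.remainder(X, 2 * pi)
    return h - F.relu(X - b) / a - F.relu(a - X) / a


def trapezoid_masks(X):
    # The question's mask-based definition, written as a standalone function
    R = X % (2 * pi)
    Y = torch.zeros_like(X)
    lo, hi = R < 0.5 * pi, R >= 1.5 * pi
    Y[lo] = R[lo] * 2 / pi
    Y[~lo & ~hi] = 1.0
    Y[hi] = R[hi] * (-2 / pi) + 4
    return Y


X = torch.linspace(-10, 10, 1000)
print(torch.allclose(trapezoid_relu(X), trapezoid_masks(X), atol=1e-5))

# The ReLU version is differentiable; the slope is 2/pi, 0 or -2/pi depending on the piece
Xg = torch.tensor([1.0, 3.0, 5.0], requires_grad=True)
trapezoid_relu(Xg).sum().backward()
print(Xg.grad)
```

The first print should report True, and the gradient should be about 2/π, 0 and -2/π at the three test points.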

Christian Bueno answered Nov 22 '25 04:11


