Why is PyTorch slower on a 10-core M1 Pro than on a Linux CPU?

I ran the following benchmark from here.

#!/usr/bin/env python3

import torch


def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to ``bmm``'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)


# Input for benchmarking
x = torch.randn(10000, 64)

# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

import timeit

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')

On the Mac I got:

mul_sum(x, x):  1065.9 us
bmm(x, x):      134.5 us

and on the Linux CPU I got:

mul_sum(x, x):   52.3 us
bmm(x, x):      120.1 us

I'm seeing a huge performance difference. Is this expected?

I first noticed this difference on a more serious program, and am trying to replicate it here.

piedpiper asked Feb 02 '26 19:02


1 Answer

John Zavialov's answer goes over the general issues. I'll briefly summarize them here and then get into how to speed things up.

1. Optimization Progress: PyTorch's support for the Apple Silicon architecture is still being refined and is not as mature as its support on Linux.

2. Architecture-Specific Tuning: PyTorch is tuned for specific architectures, so it may not perform equally well on every system.

3. Instruction Set Variation: The instruction set architecture significantly affects how efficiently various operations execute. ARM-based systems (such as the M1 Pro) are distinct from x86_64 (the Linux machine), which can lead to big differences in performance.
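Before comparing numbers, it can help to confirm what each machine is actually running. The following is a minimal sketch, not part of the original benchmark, that prints the CPU architecture, the PyTorch version, the intra-op thread count, and the build configuration. The helper name `describe_environment` is my own; the calls it uses (`platform.machine`, `torch.get_num_threads`, `torch.backends.mps.is_built`, `torch.__config__.show`) are standard, but I'm assuming a reasonably recent PyTorch (1.12 or later) so that `torch.backends.mps` exists.

```python
import platform

import torch


def describe_environment():
    """Collect a few facts that commonly differ between the two machines."""
    return {
        # e.g. 'arm64' on Apple Silicon, 'x86_64' on a typical Linux box
        "machine": platform.machine(),
        "torch_version": torch.__version__,
        # Number of threads PyTorch uses for intra-op parallelism on the CPU
        "num_threads": torch.get_num_threads(),
        # Whether this PyTorch build was compiled with MPS support at all
        "mps_built": torch.backends.mps.is_built(),
    }


info = describe_environment()
for key, value in info.items():
    print(f"{key}: {value}")

# Build details (BLAS library, compiler flags, and so on) as one string
print(torch.__config__.show())
```

Running this on both machines and diffing the output is a cheap way to see whether the two builds differ in thread count or in the math libraries they link against.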

To speed this up, Apple provides Metal Performance Shaders (MPS), which PyTorch can use as a backend to run on the Mac's GPU. If you enable it for PyTorch on the Mac, you should see a performance boost. Installation instructions are in the link; here is a code example to test and activate it:


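import torch

# Check that this PyTorch build can use the MPS (Metal) backend on this machine
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    # Create a tensor directly on the MPS device to confirm it works
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    print("MPS device not found.")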
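Once MPS is available, the `mul_sum` benchmark from the question can be rerun on that device. Below is a rough sketch, under a few assumptions of mine: the helpers `pick_device`, `sync`, and `time_mul_sum` are hypothetical names, the code falls back to the CPU when MPS is missing, and it needs PyTorch 2.0 or later for `torch.mps.synchronize()`. MPS work is queued asynchronously, so the sketch synchronizes before reading the clock; without that, the timing would mostly measure how fast kernels are queued.

```python
import time

import torch


def batched_dot_mul_sum(a, b):
    """Same function as in the question: batched dot via multiply and sum."""
    return a.mul(b).sum(-1)


def pick_device():
    """Return the MPS device when available, otherwise fall back to the CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def sync(device):
    """Wait for queued MPS work to finish; a no-op on the CPU."""
    if device.type == "mps":
        torch.mps.synchronize()


def time_mul_sum(device, repeats=100, warmup=10):
    """Return the mean time per call in microseconds on the given device."""
    x = torch.randn(10000, 64, device=device)
    for _ in range(warmup):  # warm-up runs are excluded from the timing
        batched_dot_mul_sum(x, x)
    sync(device)
    start = time.perf_counter()
    for _ in range(repeats):
        batched_dot_mul_sum(x, x)
    sync(device)
    return (time.perf_counter() - start) / repeats * 1e6


device = pick_device()
elapsed_us = time_mul_sum(device)
print(f"mul_sum(x, x) on {device}: {elapsed_us:.1f} us")
```

For a tensor this small, the GPU's launch overhead may eat into any gain, so treat the result as something to measure on your own machine rather than a guaranteed speedup.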

Lucas Hendren answered Feb 04 '26 15:02


