|
|
#include <torch/library.h> |
|
|
|
|
|
#include "registration.h" |
|
|
#include "torch_binding.h" |
|
|
|
|
|
// Operator registrations for this extension's library namespace.
// NOTE(review): TORCH_LIBRARY_EXPAND comes from "registration.h"; it presumably
// expands TORCH_EXTENSION_NAME before forwarding to TORCH_LIBRARY — confirm there.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Backend-independent schema: one tensor in, one tensor out.
  ops.def("relu(Tensor input) -> Tensor");
  // Bind the kernel for the MPS dispatch key only. No other backend kernel is
  // registered in this block. mps_relu is presumably declared in
  // "torch_binding.h".
  ops.impl("relu", torch::kMPS, &mps_relu);
}
|
|
|
|
|
// Extension-level registration hook provided by "registration.h".
// NOTE(review): presumably expands to the module entry point (e.g. a Python
// module init) for TORCH_EXTENSION_NAME; it is invoked without a trailing
// semicolon, so it likely expands to a full definition — confirm in registration.h.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|