Custom TorchScript Operators Example
November 29, 2018
This repository contains examples for writing, compiling and using custom TorchScript operators. See here for the accompanying tutorial.
Contents
There are a few monuments in this repository you can visit. They are described in context in the tutorial, which you are encouraged to read. These monuments are:
- example_app/warp_perspective/op.cpp: The custom operator implementation,
- example_app/main.cpp: An example application that loads and executes a serialized TorchScript model, which uses the custom operator, in C++,
- script.py: Example of using the custom operator in a scripted model,
- trace.py: Example of using the custom operator in a traced model,
- eager.py: Example of using the custom operator in vanilla eager PyTorch,
- load.py: Example of using torch.utils.cpp_extension.load to build the custom operator,
- load.py: Example of using torch.utils.cpp_extension.load_inline to build the custom operator,
- setup.py: Example of using setuptools to build the custom operator,
- test_setup.py: Example of using the custom operator built using setup.py.
To execute the C++ application, first run script.py to serialize a TorchScript
model to a file called example.pt, then pass that file to the
example_app/build/example_app binary.
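The two steps above can be sketched as a small Python driver. This is only an illustration, not part of the repository: the helper names `build_commands` and `run_all` are hypothetical, and the sketch assumes that script.py writes example.pt into the directory it is run from and that the binary has already been built.

```python
import subprocess
import sys
from pathlib import Path


def build_commands(repo_root="."):
    """Return the two commands described above, in the order they must run."""
    root = Path(repo_root).resolve()
    # Step 1: serialize the TorchScript model (assumed to produce example.pt
    # in the working directory).
    serialize = [sys.executable, str(root / "script.py")]
    # Step 2: pass the serialized model to the C++ application.
    run_app = [
        str(root / "example_app" / "build" / "example_app"),
        str(root / "example.pt"),
    ]
    return [serialize, run_app]


def run_all(repo_root="."):
    """Run both steps from the repository root, stopping at the first failure."""
    root = Path(repo_root).resolve()
    for command in build_commands(root):
        subprocess.run(command, cwd=root, check=True)


if __name__ == "__main__":
    commands = build_commands()
    for command in commands:
        print(" ".join(command))
    # Only execute when the C++ binary has actually been built.
    if Path(commands[1][0]).is_file():
        run_all()
```

Passing `check=True` makes the driver stop if script.py fails, so the binary is never run against a missing or stale example.pt.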
Setup
For the smoothest experience when trying out these examples, we recommend
building a docker container from this repository's Dockerfile. This will give
you a clean, isolated Ubuntu Linux environment in which we guarantee that
everything works. These steps should get you started:
$ git clone https://github.com/pytorch/extension-script
$ cd extension-script
$ docker build -t extension-script .
$ docker run -v $PWD:/home -it extension-script
root@2f00feefe46a:/home# source /activate # Activate the Conda environment
$ cd example_app && mkdir build && cd build
$ cmake -DCMAKE_PREFIX_PATH=/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Found torch: /libtorch/lib/libtorch.so
-- Configuring done
-- Generating done
-- Build files have been written to: /home/example_app/build
$ make -j
Scanning dependencies of target warp_perspective
[ 25%] Building CXX object warp_perspective/CMakeFiles/warp_perspective.dir/op.cpp.o
[ 50%] Linking CXX shared library libwarp_perspective.so
[ 50%] Built target warp_perspective
Scanning dependencies of target example_app
[ 75%] Building CXX object CMakeFiles/example_app.dir/main.cpp.o
[100%] Linking CXX executable example_app
[100%] Built target example_app
This will create a shared library under
/home/example_app/build/warp_perspective/libwarp_perspective.so containing the
custom operator defined in example_app/warp_perspective/op.cpp. Then, you can
run the examples, e.g.:
(base) root@2f00feefe46a:/home# python script.py
graph(%x.1 : Dynamic
      %y : Dynamic) {
  %20 : int = prim::Constant[value=1]()
  %16 : int[] = prim::Constant[value=[0, -1]]()
  %14 : int = prim::Constant[value=6]()
  %2 : int = prim::Constant[value=0]()
  %7 : int = prim::Constant[value=42]()
  %z.1 : int = prim::Constant[value=5]()
  %z.2 : int = prim::Constant[value=10]()
  %13 : int = prim::Constant[value=3]()
  %4 : Dynamic = aten::select(%x.1, %2, %2)
  %6 : Dynamic = aten::select(%4, %2, %2)
  %8 : Dynamic = aten::eq(%6, %7)
  %9 : bool = prim::TensorToBool(%8)
  %z : int = prim::If(%9)
    block0() {
      -> (%z.1)
    }
    block1() {
      -> (%z.2)
    }
  %17 : Dynamic = aten::eye(%13, %14, %2, %16)
  %x : Dynamic = my_ops::warp_perspective(%x.1, %17)
  %19 : Dynamic = aten::matmul(%x, %y)
  %21 : Dynamic = aten::add(%19, %z, %20)
  return (%21);
}
tensor([[11.6196, 12.0056, 11.6122, 12.9298,  7.0649],
        [ 8.5063,  9.0621,  9.9925,  6.3741,  8.9668],
        [12.5898,  6.5872,  8.1511, 10.0806, 11.9829],
        [ 4.9142, 11.6614, 15.7161, 17.0538, 11.7243],
        [10.0000, 10.0000, 10.0000, 10.0000, 10.0000],
        [10.0000, 10.0000, 10.0000, 10.0000, 10.0000],
        [10.0000, 10.0000, 10.0000, 10.0000, 10.0000],
        [10.0000, 10.0000, 10.0000, 10.0000, 10.0000]])
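Once the shared library from the build step exists, the operator can also be loaded and called directly from Python. The following is a minimal sketch rather than a copy of eager.py. The helper names `library_path` and `call_warp_perspective` are hypothetical, and the input shape is an assumption, since the graph above only shows that `my_ops::warp_perspective` takes two tensors, the second being a 3x3 identity matrix. `torch` is imported lazily so the path helper works even where PyTorch is not installed.

```python
from pathlib import Path

# Location reported by the build step, relative to the repository root.
LIBRARY_RELATIVE_PATH = Path("example_app/build/warp_perspective/libwarp_perspective.so")


def library_path(repo_root="."):
    """Return the expected location of the built operator library."""
    return Path(repo_root) / LIBRARY_RELATIVE_PATH


def call_warp_perspective(repo_root="."):
    """Load the operator library and call the op once in eager mode.

    Returns None if the library has not been built yet.
    """
    lib = library_path(repo_root)
    if not lib.is_file():
        return None

    import torch  # lazy import: only needed once the library is available

    # Loading the shared library registers my_ops::warp_perspective
    # with the PyTorch runtime.
    torch.ops.load_library(str(lib))

    image = torch.randn(4, 8)  # assumed input shape, for illustration only
    warp = torch.eye(3)        # 3x3 identity, as in the graph above
    return torch.ops.my_ops.warp_perspective(image, warp)


if __name__ == "__main__":
    result = call_warp_perspective()
    if result is None:
        print("Library not found; build example_app first.")
    else:
        print(result)
```

Inside the Docker container described above, the repository is mounted at /home, so `call_warp_perspective("/home")` would look for the library at the path printed by the build.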