Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions src/base/softmax_infinilm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#ifndef INFINI_OPS_BASE_SOFTMAX_INFINILM_H_
#define INFINI_OPS_BASE_SOFTMAX_INFINILM_H_

#include <cassert>
#include <optional>

#include "operator.h"

namespace infini::ops {

// Device-agnostic base class of the `SoftmaxInfinilm` operator.
//
// The constructor snapshots the metadata of `input` and `out` (shapes,
// strides, dtypes, device index) and derives the reduction geometry:
// `dim_size_` is the extent of the softmax axis and `row_count_` is the number
// of independent slices along that axis. Backends implement `operator()`.
class SoftmaxInfinilm : public Operator<SoftmaxInfinilm> {
 public:
  // Records tensor metadata and validates the arguments with assertions.
  //
  // `dim` may be negative, in which case it counts from the last axis.
  // `dtype`, when it holds a value, must equal the dtype of `out`.
  SoftmaxInfinilm(const Tensor input, const int64_t dim,
                  const std::optional<DataType> dtype, Tensor out)
      : input_shape_{input.shape()},
        input_strides_{input.strides()},
        input_type_{input.dtype()},
        out_shape_{out.shape()},
        out_strides_{out.strides()},
        out_type_{out.dtype()},
        // Range-checked here because `out.size(dim_)` below runs before any
        // assertion in the constructor body.
        dim_{NormalizeDim(dim, static_cast<int64_t>(input.ndim()))},
        dtype_{dtype},
        ndim_{out.ndim()},
        dim_size_{out.size(dim_)},
        // A zero-extent softmax axis would otherwise divide by zero.
        row_count_{dim_size_ == 0 ? Tensor::Size{0}
                                  : out.numel() / dim_size_},
        device_index_{out.device().index()} {
    assert(input_shape_ == out_shape_ &&
           "`SoftmaxInfinilm` input and output shapes must match");
    assert(dim_ >= 0 && dim_ < static_cast<int64_t>(ndim_) &&
           "`SoftmaxInfinilm` dim out of range");
    assert((!dtype_.has_value() || dtype_.value() == out_type_) &&
           "`SoftmaxInfinilm` dtype, when given, must match the output dtype");
    assert(input_type_ == out_type_ &&
           "`SoftmaxInfinilm` input and output dtypes must match");
    assert(!out.HasBroadcastDim() &&
           "`SoftmaxInfinilm` output must not have broadcasted dimensions");
  }

  // Computes the softmax of `input` along `dim` into `out`; implemented by
  // the backend-specific subclasses.
  virtual void operator()(const Tensor input, const int64_t dim,
                          const std::optional<DataType> dtype,
                          Tensor out) const = 0;

 protected:
  // Shape of the input tensor.
  Tensor::Shape input_shape_;

  // Strides of the input tensor.
  Tensor::Strides input_strides_;

  // Element type of the input tensor.
  DataType input_type_;

  // Shape of the output tensor (asserted equal to `input_shape_`).
  Tensor::Shape out_shape_;

  // Strides of the output tensor.
  Tensor::Strides out_strides_;

  // Element type of the output tensor (asserted equal to `input_type_`).
  DataType out_type_;

  // Softmax axis after normalizing a negative `dim`.
  int64_t dim_{};

  // Optional dtype argument as passed by the caller.
  std::optional<DataType> dtype_{};

  // Number of dimensions of the output tensor.
  Tensor::Size ndim_{0};

  // Extent of the output tensor along `dim_`.
  Tensor::Size dim_size_{0};

  // Number of slices along `dim_`: `numel / dim_size_`, or 0 when the axis is
  // empty.
  Tensor::Size row_count_{0};

  // Index of the device that holds `out`.
  int device_index_{0};

 private:
  // Maps a possibly negative `dim` to its non-negative equivalent for a tensor
  // with `ndim` dimensions and asserts that the result is a valid axis.
  static int64_t NormalizeDim(const int64_t dim, const int64_t ndim) {
    const int64_t normalized = dim < 0 ? dim + ndim : dim;
    assert(normalized >= 0 && normalized < ndim &&
           "`SoftmaxInfinilm` dim out of range");
    return normalized;
  }
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/iluvatar/ops/softmax_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_ILUVATAR_SOFTMAX_INFINILM_KERNEL_H_
#define INFINI_OPS_ILUVATAR_SOFTMAX_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/iluvatar/caster.cuh"
#include "native/cuda/iluvatar/runtime_.h"
#include "native/cuda/ops/softmax_infinilm/kernel.h"

namespace infini::ops {

// Binds `SoftmaxInfinilm` on `Device::Type::kIluvatar` to the shared CUDA
// implementation instantiated with the Iluvatar runtime; constructors are
// inherited unchanged.
template <>
class Operator<SoftmaxInfinilm, Device::Type::kIluvatar>
    : public CudaSoftmaxInfinilm<Runtime<Device::Type::kIluvatar>> {
  using Impl = CudaSoftmaxInfinilm<Runtime<Device::Type::kIluvatar>>;

 public:
  using Impl::Impl;
};

} // namespace infini::ops

#endif
21 changes: 21 additions & 0 deletions src/native/cuda/metax/ops/softmax_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef INFINI_OPS_METAX_SOFTMAX_INFINILM_KERNEL_H_
#define INFINI_OPS_METAX_SOFTMAX_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/metax/caster.cuh"
#include "native/cuda/metax/runtime_.h"
#include "native/cuda/ops/softmax_infinilm/kernel.h"

namespace infini::ops {

// Binds `SoftmaxInfinilm` on `Device::Type::kMetax` to the shared CUDA
// implementation instantiated with the Metax runtime; constructors are
// inherited unchanged.
template <>
class Operator<SoftmaxInfinilm, Device::Type::kMetax>
    : public CudaSoftmaxInfinilm<Runtime<Device::Type::kMetax>> {
  using Impl = CudaSoftmaxInfinilm<Runtime<Device::Type::kMetax>>;

 public:
  using Impl::Impl;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/moore/ops/softmax_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_MOORE_SOFTMAX_INFINILM_KERNEL_H_
#define INFINI_OPS_MOORE_SOFTMAX_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/moore/caster.cuh"
#include "native/cuda/moore/polyfills.cuh"
#include "native/cuda/moore/runtime_.h"
#include "native/cuda/ops/softmax_infinilm/kernel.h"

namespace infini::ops {

// Binds `SoftmaxInfinilm` on `Device::Type::kMoore` to the shared CUDA
// implementation instantiated with the Moore runtime; constructors are
// inherited unchanged.
template <>
class Operator<SoftmaxInfinilm, Device::Type::kMoore>
    : public CudaSoftmaxInfinilm<Runtime<Device::Type::kMoore>> {
  using Impl = CudaSoftmaxInfinilm<Runtime<Device::Type::kMoore>>;

 public:
  using Impl::Impl;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/nvidia/ops/softmax_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_NVIDIA_SOFTMAX_INFINILM_KERNEL_H_
#define INFINI_OPS_NVIDIA_SOFTMAX_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/nvidia/caster.cuh"
#include "native/cuda/nvidia/runtime_.h"
#include "native/cuda/ops/softmax_infinilm/kernel.h"

namespace infini::ops {

// Binds `SoftmaxInfinilm` on `Device::Type::kNvidia` to the shared CUDA
// implementation instantiated with the NVIDIA runtime; constructors are
// inherited unchanged.
template <>
class Operator<SoftmaxInfinilm, Device::Type::kNvidia>
    : public CudaSoftmaxInfinilm<Runtime<Device::Type::kNvidia>> {
  using Impl = CudaSoftmaxInfinilm<Runtime<Device::Type::kNvidia>>;

 public:
  using Impl::Impl;
};

} // namespace infini::ops

#endif
113 changes: 113 additions & 0 deletions src/native/cuda/ops/softmax_infinilm/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#ifndef INFINI_OPS_CUDA_SOFTMAX_INFINILM_KERNEL_CUH_
#define INFINI_OPS_CUDA_SOFTMAX_INFINILM_KERNEL_CUH_

#include <cfloat>
#include <cmath>
#include <cstddef>
#include <cub/block/block_reduce.cuh>

#include "native/cuda/caster.cuh"
#include "native/cuda/kernel_commons.cuh"

namespace infini::ops {

namespace {

// Binary functor that yields the first argument when it compares strictly
// greater than the second, and the second argument otherwise. Used as the
// reduction operator for the block-wide maximum.
struct SoftmaxInfinilmMaxOp {
  __device__ __forceinline__ float operator()(float a, float b) const {
    if (a > b) {
      return a;
    }
    return b;
  }
};

// Reduces `value` across a block of `block_size` threads with
// `SoftmaxInfinilmMaxOp`, using a shared-memory scratch area owned by this
// function, and returns what `cub::BlockReduce::Reduce` returns.
template <unsigned int block_size>
__device__ __forceinline__ float BlockMax(float value) {
  using Reducer = cub::BlockReduce<float, block_size>;
  __shared__ typename Reducer::TempStorage scratch;
  Reducer reducer(scratch);
  return reducer.Reduce(value, SoftmaxInfinilmMaxOp());
}

// Sums `value` across a block of `block_size` threads, using a shared-memory
// scratch area owned by this function, and returns what
// `cub::BlockReduce::Sum` returns.
template <unsigned int block_size>
__device__ __forceinline__ float BlockSum(float value) {
  using Reducer = cub::BlockReduce<float, block_size>;
  __shared__ typename Reducer::TempStorage scratch;
  Reducer reducer(scratch);
  return reducer.Sum(value);
}

// Converts the linear index `row` into an element offset.
//
// `row` is decomposed into one coordinate per axis, walking from the last axis
// to the first and skipping axis `dim`; each coordinate is multiplied by the
// matching entry of `strides` and the products are accumulated.
__device__ __forceinline__ size_t SoftmaxInfinilmRowOffset(
    size_t row, size_t ndim, size_t dim, const size_t* __restrict__ shape,
    const ptrdiff_t* __restrict__ strides) {
  size_t element_offset = 0;
  size_t remaining = row;
  for (size_t axis = ndim; axis-- > 0;) {
    if (axis != dim) {
      const size_t extent = shape[axis];
      element_offset += (remaining % extent) * strides[axis];
      remaining /= extent;
    }
  }
  return element_offset;
}

} // namespace

// Softmax along axis `dim`, one thread block per row (a "row" is one slice
// along `dim`).
//
// Parameters:
//   out, input     - output / input buffers with element type `T`.
//   shape          - `ndim` extents, read on the device.
//   out_strides    - `ndim` element strides of `out`, read on the device.
//   input_strides  - `ndim` element strides of `input`, read on the device.
//   row_count      - number of rows; blocks whose row index is beyond it exit.
//   dim_size       - number of elements per row (extent of axis `dim`).
//   ndim, dim      - rank of the tensors and the softmax axis.
//
// All arithmetic is done in `float`; elements are converted with
// `Caster<kDev>`. The loops stride by `block_size`, so this assumes the kernel
// is launched with exactly `block_size` threads per block - TODO confirm at
// every launch site.
template <unsigned int block_size, Device::Type kDev, typename T>
__global__ void SoftmaxInfinilmKernel(
T* __restrict__ out, const T* __restrict__ input,
const size_t* __restrict__ shape, const ptrdiff_t* __restrict__ out_strides,
const ptrdiff_t* __restrict__ input_strides, size_t row_count,
size_t dim_size, size_t ndim, size_t dim) {
// Flatten the 2-D grid into a row index; surplus blocks do nothing.
size_t row = blockIdx.x + blockIdx.y * gridDim.x;
if (row >= row_count) {
return;
}

// Offset of the row's first element in each buffer, plus the step between
// consecutive elements of the row.
size_t input_base =
SoftmaxInfinilmRowOffset(row, ndim, dim, shape, input_strides);
size_t out_base =
SoftmaxInfinilmRowOffset(row, ndim, dim, shape, out_strides);
ptrdiff_t input_dim_stride = input_strides[dim];
ptrdiff_t out_dim_stride = out_strides[dim];

// Pass 1: each thread scans its share of the row for a partial maximum.
float thread_max = -FLT_MAX;
for (size_t i = threadIdx.x; i < dim_size; i += block_size) {
float value = Caster<kDev>::template Cast<float>(
input[input_base + i * input_dim_stride]);
thread_max = thread_max > value ? thread_max : value;
}

// Thread 0 publishes the block-wide maximum through shared memory; the
// barrier makes it visible to every thread before it is read below.
// NOTE(review): presumably needed because the reduction result is only
// defined in thread 0 - confirm against the CUB documentation.
float block_max = BlockMax<block_size>(thread_max);
__shared__ float max_value;
if (threadIdx.x == 0) {
max_value = block_max;
}
__syncthreads();

// Pass 2: write exp(x - max) to `out` (converted to `T`) and accumulate the
// per-thread sum of the exponentials in float.
float thread_sum = 0.0f;
for (size_t i = threadIdx.x; i < dim_size; i += block_size) {
float value = Caster<kDev>::template Cast<float>(
input[input_base + i * input_dim_stride]);
float exp_value = expf(value - max_value);
thread_sum += exp_value;
out[out_base + i * out_dim_stride] =
Caster<kDev>::template Cast<T>(exp_value);
}

// Same publish-and-barrier pattern for the block-wide sum.
float block_sum = BlockSum<block_size>(thread_sum);
__shared__ float sum_value;
if (threadIdx.x == 0) {
sum_value = block_sum;
}
__syncthreads();

// Pass 3: read the stored exponentials back from `out` and normalize them.
// Each thread revisits exactly the indices it wrote in pass 2. The
// intermediate values round-trip through `T` before the division.
for (size_t i = threadIdx.x; i < dim_size; i += block_size) {
float value =
Caster<kDev>::template Cast<float>(out[out_base + i * out_dim_stride]);
out[out_base + i * out_dim_stride] =
Caster<kDev>::template Cast<T>(value / sum_value);
}
}

} // namespace infini::ops

#endif
93 changes: 93 additions & 0 deletions src/native/cuda/ops/softmax_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#ifndef INFINI_OPS_CUDA_SOFTMAX_INFINILM_KERNEL_H_
#define INFINI_OPS_CUDA_SOFTMAX_INFINILM_KERNEL_H_

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>

#include "base/softmax_infinilm.h"
#include "common/generic_utils.h"
#include "data_type.h"
#include "dispatcher.h"
#include "native/cuda/kernel_commons.cuh"
#include "native/cuda/ops/softmax_infinilm/kernel.cuh"
#include "native/cuda/runtime_utils.h"

namespace infini::ops {

template <typename Backend>
class CudaSoftmaxInfinilm : public SoftmaxInfinilm {
public:
CudaSoftmaxInfinilm(const Tensor input, const int64_t dim,
const std::optional<DataType> dtype, Tensor out)
: SoftmaxInfinilm{input, dim, dtype, out} {
size_t shape_size = ndim_ * sizeof(*d_shape_);
size_t strides_size = ndim_ * sizeof(*d_input_strides_);
const size_t metadata_size = shape_size + 2 * strides_size;
std::vector<std::byte> metadata(metadata_size);

Backend::Malloc((void**)&d_metadata_, metadata_size);

size_t offset = 0;
d_shape_ = reinterpret_cast<Tensor::Size*>(d_metadata_ + offset);
std::memcpy(metadata.data() + offset, out_shape_.data(), shape_size);
offset += shape_size;

d_input_strides_ = reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
std::memcpy(metadata.data() + offset, input_strides_.data(), strides_size);
offset += strides_size;

d_out_strides_ = reinterpret_cast<Tensor::Stride*>(d_metadata_ + offset);
std::memcpy(metadata.data() + offset, out_strides_.data(), strides_size);

Backend::Memcpy(d_metadata_, metadata.data(), metadata_size,
Backend::MemcpyHostToDevice);
}

~CudaSoftmaxInfinilm() { Backend::Free(d_metadata_); }

void operator()(const Tensor input, const int64_t dim,
const std::optional<DataType> dtype,
Tensor out) const override {
(void)dim;
(void)dtype;
auto cuda_stream =
static_cast<typename Backend::Stream>(stream_ ? stream_ : 0);
int block_size = std::min(
RuntimeUtils<Backend::kDeviceType>::GetOptimalBlockSize(), 1024);

DispatchFunc<AllFloatTypes, List<128, 256, 512, 1024>>(
{static_cast<int64_t>(out_type_), block_size},
[&](auto list_tag) {
using T = TypeMapType<Backend::kDeviceType, ListGet<0>(list_tag)>;
constexpr int kBlockSize = ListGet<1>(list_tag);

const unsigned grid_x =
static_cast<unsigned>(std::min<Tensor::Size>(row_count_, 65535));
const unsigned grid_y = static_cast<unsigned>(
utils::CeilDiv(row_count_, static_cast<Tensor::Size>(grid_x)));

SoftmaxInfinilmKernel<kBlockSize, Backend::kDeviceType, T>
<<<dim3(grid_x, grid_y), kBlockSize, 0, cuda_stream>>>(
reinterpret_cast<T*>(out.data()),
reinterpret_cast<const T*>(input.data()), d_shape_,
d_out_strides_, d_input_strides_, row_count_, dim_size_,
ndim_, dim_);
},
"CudaSoftmaxInfinilm::operator()");
}

private:
std::byte* d_metadata_{nullptr};

Tensor::Size* d_shape_{nullptr};

Tensor::Stride* d_input_strides_{nullptr};

Tensor::Stride* d_out_strides_{nullptr};
};

} // namespace infini::ops

#endif
Loading
Loading