Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions src/base/random_sample_infinilm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#ifndef INFINI_OPS_BASE_RANDOM_SAMPLE_INFINILM_H_
#define INFINI_OPS_BASE_RANDOM_SAMPLE_INFINILM_H_

#include <cassert>
#include <cstdint>
#include <limits>

#include "data_type.h"
#include "operator.h"
#include "tensor.h"

namespace infini::ops {

class RandomSampleInfinilm : public Operator<RandomSampleInfinilm> {
public:
RandomSampleInfinilm(const Tensor logits, float random_val, float topp,
int64_t topk, float temperature, Tensor out)
: dtype_{logits.dtype()},
out_dtype_{out.dtype()},
n_{logits.size(0)},
logits_stride_{logits.stride(0)},
random_val_{random_val},
topp_{topp},
topk_{topk},
temperature_{temperature} {
assert(logits.ndim() == 1 && "`RandomSampleInfinilm` requires 1D logits");
assert(n_ > 0 && "`RandomSampleInfinilm` requires non-empty logits");
assert(logits.stride(0) == 1 &&
"`RandomSampleInfinilm` requires contiguous logits");
assert(out.numel() == 1 && "`RandomSampleInfinilm` requires scalar output");
assert(IsFloatDtype(dtype_) &&
"`RandomSampleInfinilm` requires floating-point logits");
assert(IsIntDtype(out_dtype_) &&
"`RandomSampleInfinilm` requires integer output");
assert(topk > 0 && "`RandomSampleInfinilm` requires `topk > 0`");
assert(topk <= std::numeric_limits<int>::max() &&
"`RandomSampleInfinilm` requires `topk` to fit in int");
}

virtual void operator()(const Tensor logits, float random_val, float topp,
int64_t topk, float temperature,
Tensor out) const = 0;

protected:
static bool IsFloatDtype(DataType dtype) {
return dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 ||
dtype == DataType::kFloat32 || dtype == DataType::kFloat64;
}

static bool IsIntDtype(DataType dtype) {
return dtype == DataType::kInt8 || dtype == DataType::kInt16 ||
dtype == DataType::kInt32 || dtype == DataType::kInt64 ||
dtype == DataType::kUInt8 || dtype == DataType::kUInt16 ||
dtype == DataType::kUInt32 || dtype == DataType::kUInt64;
}

DataType dtype_;

DataType out_dtype_;

Tensor::Size n_{0};

Tensor::Stride logits_stride_{1};

float random_val_{0.0f};

float topp_{0.0f};

int64_t topk_{1};

float temperature_{1.0f};
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/iluvatar/ops/random_sample_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_ILUVATAR_RANDOM_SAMPLE_INFINILM_KERNEL_H_
#define INFINI_OPS_ILUVATAR_RANDOM_SAMPLE_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/iluvatar/caster.cuh"
#include "native/cuda/iluvatar/runtime_.h"
#include "native/cuda/ops/random_sample_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<RandomSampleInfinilm, Device::Type::kIluvatar>
: public CudaRandomSampleInfinilm<Runtime<Device::Type::kIluvatar>> {
public:
using CudaRandomSampleInfinilm<
Runtime<Device::Type::kIluvatar>>::CudaRandomSampleInfinilm;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/metax/ops/random_sample_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_METAX_RANDOM_SAMPLE_INFINILM_KERNEL_H_
#define INFINI_OPS_METAX_RANDOM_SAMPLE_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/metax/caster.cuh"
#include "native/cuda/metax/runtime_.h"
#include "native/cuda/ops/random_sample_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<RandomSampleInfinilm, Device::Type::kMetax>
: public CudaRandomSampleInfinilm<Runtime<Device::Type::kMetax>> {
public:
using CudaRandomSampleInfinilm<
Runtime<Device::Type::kMetax>>::CudaRandomSampleInfinilm;
};

} // namespace infini::ops

#endif
30 changes: 30 additions & 0 deletions src/native/cuda/moore/ops/random_sample_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef INFINI_OPS_MOORE_RANDOM_SAMPLE_INFINILM_KERNEL_H_
#define INFINI_OPS_MOORE_RANDOM_SAMPLE_INFINILM_KERNEL_H_

#include <utility>

// clang-format off
#include <musa_runtime.h>
// clang-format on

#include "native/cuda/moore/caster.cuh"
#include "native/cuda/moore/runtime_.h"
#include "native/cuda/ops/random_sample_infinilm/kernel.h"

namespace infini::ops {

struct MooreRandomSampleInfinilmBackend : Runtime<Device::Type::kMoore> {
static constexpr int max_block_size = 1024;
};

template <>
class Operator<RandomSampleInfinilm, Device::Type::kMoore>
: public CudaRandomSampleInfinilm<MooreRandomSampleInfinilmBackend> {
public:
using CudaRandomSampleInfinilm<
MooreRandomSampleInfinilmBackend>::CudaRandomSampleInfinilm;
};

} // namespace infini::ops

#endif
22 changes: 22 additions & 0 deletions src/native/cuda/nvidia/ops/random_sample_infinilm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef INFINI_OPS_NVIDIA_RANDOM_SAMPLE_INFINILM_KERNEL_H_
#define INFINI_OPS_NVIDIA_RANDOM_SAMPLE_INFINILM_KERNEL_H_

#include <utility>

#include "native/cuda/nvidia/caster.cuh"
#include "native/cuda/nvidia/runtime_.h"
#include "native/cuda/ops/random_sample_infinilm/kernel.h"

namespace infini::ops {

template <>
class Operator<RandomSampleInfinilm, Device::Type::kNvidia>
: public CudaRandomSampleInfinilm<Runtime<Device::Type::kNvidia>> {
public:
using CudaRandomSampleInfinilm<
Runtime<Device::Type::kNvidia>>::CudaRandomSampleInfinilm;
};

} // namespace infini::ops

#endif
Loading
Loading