#ifndef MLLM_CPUABC_H
#define MLLM_CPUABC_H
#include "Op.hpp"
#include "CPUBackend.hpp"
namespace mllm {
/**
 * CPU implementation of the "Abc" operator.
 *
 * Declares the Op hooks this operator overrides and stores the operator's
 * parameters. The member function bodies and the constructor are defined
 * outside this header.
 */
class CPUAbc final : public Op {
public:
    /**
     * @param bn          Backend passed in by the creator.
     * @param opName      Name of this op.
     * @param param1      First op parameter.
     * @param param2      Second op parameter.
     * @param threadCount Thread count requested for this op.
     */
    CPUAbc(Backend *bn, string opName, int param1, bool param2, int threadCount);
    virtual ~CPUAbc() = default;

    /// Compute the tensor sizes of `outputs`.
    ErrorCode reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;

    /// Perform the computation and assign values to `outputs`.
    ErrorCode execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;

    /// Load weight parameters (overriding this is optional).
    ErrorCode load(AbstructLoader &loader) override;

    /// Release weight parameters; paired with load() (overriding this is optional).
    ErrorCode free(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;

    /// Manage memory for `outputs` (`inputs`) (overriding this is optional).
    ErrorCode setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;

private:
    // Parameters of the Op.
    int param1_;
    bool param2_;
    // Defaults to 4 unless the out-of-view constructor assigns another value.
    int thread_count = 4;
};
/**
 * Creator registered with the CPU backend that builds CPUAbc ops.
 */
class CPUAbcCreator : public CPUBackend::Creator {
public:
    /**
     * Build a CPUAbc from the supplied op parameters.
     *
     * @param op_param    Op parameters; "param1" and "param2" are read from it.
     * @param bn          Backend forwarded to the CPUAbc constructor.
     * @param name        Op name forwarded to the CPUAbc constructor.
     * @param threadCount Thread count forwarded to the CPUAbc constructor.
     * @return Pointer to a CPUAbc allocated with `new`
     *         (NOTE(review): ownership presumably passes to the caller - confirm).
     */
    virtual Op *create(OpParam op_param, Backend *bn, string name, int threadCount) const {
        // Read CPUAbc's parameters here.
        // NOTE(review): the original comment described OpParam as vector<float>,
        // but it is indexed by string key below - confirm against the OpParam definition.
        const int param1 = static_cast<int>(op_param["param1"]);
        const bool param2 = static_cast<bool>(op_param["param2"]);
        return new CPUAbc(bn, name, param1, param2, threadCount);
    }
};
} // namespace mllm
#endif // MLLM_CPUABC_H