普通文本  |  497行  |  17.67 KB

/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// JNI wrapper for the TextClassifier.

#include "textclassifier_jni.h"

#include <jni.h>
#include <type_traits>
#include <vector>

#include "text-classifier.h"
#include "util/base/integral_types.h"
#include "util/java/scoped_local_ref.h"
#include "util/java/string_utils.h"
#include "util/memory/mmap.h"
#include "util/utf8/unilib.h"

using libtextclassifier2::AnnotatedSpan;
using libtextclassifier2::AnnotationOptions;
using libtextclassifier2::ClassificationOptions;
using libtextclassifier2::ClassificationResult;
using libtextclassifier2::CodepointSpan;
using libtextclassifier2::JStringToUtf8String;
using libtextclassifier2::Model;
using libtextclassifier2::ScopedLocalRef;
using libtextclassifier2::SelectionOptions;
using libtextclassifier2::TextClassifier;
#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
using libtextclassifier2::UniLib;
#endif

namespace libtextclassifier2 {

using libtextclassifier2::CodepointSpan;

namespace {

// Returns the contents of the Java string `str` as a std::string, as produced
// by JStringToUtf8String. Any result reported by that helper is not inspected,
// so the returned string is whatever the helper wrote into it.
std::string ToStlString(JNIEnv* env, const jstring& str) {
  std::string utf8;
  JStringToUtf8String(env, str, &utf8);
  return utf8;
}

// Builds a Java array of TC_CLASS_NAME$ClassificationResult objects, one per
// entry of `classification_result`, in the same order.
//
// Each Java object is constructed from the entry's collection name, its score
// (as a jfloat) and, when the entry's datetime_parse_result is set, a
// TC_CLASS_NAME$DatetimeResult built from its time_ms_utc and granularity
// (otherwise null is passed for that argument).
//
// Returns nullptr (after logging) if a required Java class or constructor
// cannot be resolved or if the array cannot be allocated.
jobjectArray ClassificationResultsToJObjectArray(
    JNIEnv* env,
    const std::vector<ClassificationResult>& classification_result) {
  const ScopedLocalRef<jclass> result_class(
      env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult"),
      env);
  if (!result_class) {
    TC_LOG(ERROR) << "Couldn't find ClassificationResult class.";
    return nullptr;
  }
  const ScopedLocalRef<jclass> datetime_parse_class(
      env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$DatetimeResult"), env);
  if (!datetime_parse_class) {
    TC_LOG(ERROR) << "Couldn't find DatetimeResult class.";
    return nullptr;
  }

  const jmethodID result_class_constructor =
      env->GetMethodID(result_class.get(), "<init>",
                       "(Ljava/lang/String;FL" TC_PACKAGE_PATH TC_CLASS_NAME_STR
                       "$DatetimeResult;)V");
  if (!result_class_constructor) {
    TC_LOG(ERROR) << "Couldn't find ClassificationResult constructor.";
    return nullptr;
  }
  const jmethodID datetime_parse_class_constructor =
      env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
  if (!datetime_parse_class_constructor) {
    TC_LOG(ERROR) << "Couldn't find DatetimeResult constructor.";
    return nullptr;
  }

  const jsize result_count = static_cast<jsize>(classification_result.size());
  const jobjectArray results =
      env->NewObjectArray(result_count, result_class.get(), nullptr);
  if (!results) {
    TC_LOG(ERROR) << "Couldn't allocate ClassificationResult array.";
    return nullptr;
  }

  for (jsize i = 0; i < result_count; ++i) {
    const ClassificationResult& entry = classification_result[i];

    jstring row_string = env->NewStringUTF(entry.collection.c_str());
    jobject row_datetime_parse = nullptr;
    if (entry.datetime_parse_result.IsSet()) {
      // The arguments travel through C varargs, so they are cast to the exact
      // JNI types named by the "(JI)V" constructor signature.
      row_datetime_parse = env->NewObject(
          datetime_parse_class.get(), datetime_parse_class_constructor,
          static_cast<jlong>(entry.datetime_parse_result.time_ms_utc),
          static_cast<jint>(entry.datetime_parse_result.granularity));
    }
    jobject result =
        env->NewObject(result_class.get(), result_class_constructor, row_string,
                       static_cast<jfloat>(entry.score), row_datetime_parse);
    env->SetObjectArrayElement(results, i, result);

    // Release every per-row local reference right away; otherwise the number
    // of live local references grows with the size of the input vector.
    env->DeleteLocalRef(result);
    if (row_datetime_parse) {
      env->DeleteLocalRef(row_datetime_parse);
    }
    if (row_string) {
      env->DeleteLocalRef(row_string);
    }
  }
  return results;
}

// Invokes a zero-argument Java instance method and reports whether the method
// could be resolved.
//
// `method_name` is looked up on `class_object` with the JNI signature
// "()" + `return_java_type`. If the lookup fails the result is {false, T()};
// otherwise `function` (a JNIEnv member function such as
// &JNIEnv::CallObjectMethod) is called on `object` and the result is
// {true, <value returned by that call>}.
template <typename T, typename F>
std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object,
                                  jclass class_object, F function,
                                  const std::string& method_name,
                                  const std::string& return_java_type) {
  const std::string signature = "()" + return_java_type;
  const jmethodID method_id =
      env->GetMethodID(class_object, method_name.c_str(), signature.c_str());
  if (method_id != nullptr) {
    return std::pair<bool, T>(true, (env->*function)(object, method_id));
  }
  return std::pair<bool, T>(false, T());
}

// Converts a Java TC_CLASS_NAME$SelectionOptions object into the native
// SelectionOptions struct.
//
// Only `locales` is populated, from the string returned by the Java
// getLocales() method. Default-constructed options are returned when
// `joptions` is null, when the Java class cannot be found, or when
// getLocales() cannot be resolved.
SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) {
  if (!joptions) {
    return {};
  }

  const ScopedLocalRef<jclass> options_class(
      env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$SelectionOptions"),
      env);
  // A failed class lookup must not reach GetMethodID with a null class; this
  // mirrors the check done for the classification/annotation options.
  if (!options_class) {
    return {};
  }

  const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
      env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
      "getLocales", "Ljava/lang/String;");
  if (!status_or_locales.first) {
    return {};
  }

  SelectionOptions options;
  options.locales =
      ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));

  return options;
}

// Shared conversion for Java option objects that expose getLocale(),
// getReferenceTimezone() and getReferenceTimeMsUtc().
//
// `class_name` is the JNI name of the Java class of `joptions`; `T` is the
// native options type to fill (it must have `locales`, `reference_timezone`
// and `reference_time_ms_utc` members). A value-initialized T is returned
// when `joptions` is null, when the class cannot be found, or when any of the
// three getters cannot be resolved.
template <typename T>
T FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
                          const std::string& class_name) {
  if (joptions == nullptr) {
    return {};
  }

  const ScopedLocalRef<jclass> java_class(env->FindClass(class_name.c_str()),
                                          env);
  if (!java_class) {
    return {};
  }

  // All three getters are invoked before any of the lookups is inspected.
  const auto locales_call = CallJniMethod0<jobject>(
      env, joptions, java_class.get(), &JNIEnv::CallObjectMethod, "getLocale",
      "Ljava/lang/String;");
  const auto timezone_call = CallJniMethod0<jobject>(
      env, joptions, java_class.get(), &JNIEnv::CallObjectMethod,
      "getReferenceTimezone", "Ljava/lang/String;");
  const auto time_ms_call = CallJniMethod0<int64>(
      env, joptions, java_class.get(), &JNIEnv::CallLongMethod,
      "getReferenceTimeMsUtc", "J");

  const bool all_getters_resolved =
      locales_call.first && timezone_call.first && time_ms_call.first;
  if (!all_getters_resolved) {
    return {};
  }

  T converted;
  converted.locales =
      ToStlString(env, reinterpret_cast<jstring>(locales_call.second));
  converted.reference_timezone =
      ToStlString(env, reinterpret_cast<jstring>(timezone_call.second));
  converted.reference_time_ms_utc = time_ms_call.second;
  return converted;
}

// Converts a Java TC_CLASS_NAME$ClassificationOptions object into the native
// ClassificationOptions struct via FromJavaOptionsInternal.
ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
                                                    jobject joptions) {
  const std::string java_class_name =
      TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationOptions";
  return FromJavaOptionsInternal<ClassificationOptions>(env, joptions,
                                                        java_class_name);
}

// Converts a Java TC_CLASS_NAME$AnnotationOptions object into the native
// AnnotationOptions struct via FromJavaOptionsInternal.
AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) {
  const std::string java_class_name =
      TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotationOptions";
  return FromJavaOptionsInternal<AnnotationOptions>(env, joptions,
                                                    java_class_name);
}

// Translates a span between two ways of indexing into `utf8_str`:
//   * the "unicode" index, which advances by one per code point, and
//   * the "BMP" index, which advances by one per code point plus one extra
//     for every code point above 0xFFFF (presumably matching Java's UTF-16
//     string indexing -- confirm against callers).
//
// When `from_utf8` is true, `orig_indices` are unicode indices and the result
// holds BMP indices; when false, the direction is reversed. Both the start of
// every code point and the end-of-string position are valid match points. A
// component of `orig_indices` that matches none of them is left at -1 in the
// result.
CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
                                    CodepointSpan orig_indices,
                                    bool from_utf8) {
  const libtextclassifier2::UnicodeText unicode_str =
      libtextclassifier2::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);

  int unicode_index = 0;
  int bmp_index = 0;

  CodepointSpan result{-1, -1};
  // A plain lambda is enough here: it never leaves this function, so there is
  // no need for a type-erased std::function (or for <functional>).
  const auto assign_indices_fn = [&]() {
    const int source_index = from_utf8 ? unicode_index : bmp_index;
    const int target_index = from_utf8 ? bmp_index : unicode_index;

    if (orig_indices.first == source_index) {
      result.first = target_index;
    }

    if (orig_indices.second == source_index) {
      result.second = target_index;
    }
  };

  for (auto it = unicode_str.begin(); it != unicode_str.end();
       ++it, ++unicode_index, ++bmp_index) {
    assign_indices_fn();

    // There is 1 extra character in the input for each UTF8 character > 0xFFFF.
    if (*it > 0xFFFF) {
      ++bmp_index;
    }
  }
  // Also match the one-past-the-end position (e.g. a span ending at the end of
  // the string).
  assign_indices_fn();

  return result;
}

}  // namespace

// Maps a span given in BMP indices of `utf8_str` to the corresponding span in
// code point indices. Components that cannot be mapped come back as -1.
CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
                                      CodepointSpan bmp_indices) {
  constexpr bool kSourceIsUtf8 = false;
  return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, kSourceIsUtf8);
}

// Maps a span given in code point indices of `utf8_str` to the corresponding
// span in BMP indices. Components that cannot be mapped come back as -1.
CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
                                      CodepointSpan utf8_indices) {
  constexpr bool kSourceIsUtf8 = true;
  return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, kSourceIsUtf8);
}

// Extracts the integer file descriptor from a Java
// android.content.res.AssetFileDescriptor.
//
// Calls afd.getFileDescriptor() and reads the int field `descriptor` of the
// returned java.io.FileDescriptor.
//
// Returns -1 (after logging) when `afd` is null, when any class, method or
// field lookup fails, or when getFileDescriptor() yields null. -1 is used
// rather than 0 because 0 is itself a valid descriptor number and would be
// indistinguishable from success.
jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) {
  constexpr jint kInvalidFd = -1;

  if (afd == nullptr) {
    TC_LOG(ERROR) << "AssetFileDescriptor is null.";
    return kInvalidFd;
  }

  // Get system-level file descriptor from AssetFileDescriptor.
  ScopedLocalRef<jclass> afd_class(
      env->FindClass("android/content/res/AssetFileDescriptor"), env);
  if (afd_class == nullptr) {
    TC_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
    return kInvalidFd;
  }
  jmethodID afd_class_getFileDescriptor = env->GetMethodID(
      afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
  if (afd_class_getFileDescriptor == nullptr) {
    TC_LOG(ERROR) << "Couldn't find getFileDescriptor.";
    return kInvalidFd;
  }

  ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"),
                                  env);
  if (fd_class == nullptr) {
    TC_LOG(ERROR) << "Couldn't find FileDescriptor.";
    return kInvalidFd;
  }
  jfieldID fd_class_descriptor =
      env->GetFieldID(fd_class.get(), "descriptor", "I");
  if (fd_class_descriptor == nullptr) {
    TC_LOG(ERROR) << "Couldn't find descriptor.";
    return kInvalidFd;
  }

  jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor);
  if (bundle_jfd == nullptr) {
    // Reading a field through a null object reference is not allowed.
    TC_LOG(ERROR) << "Couldn't get FileDescriptor from AssetFileDescriptor.";
    return kInvalidFd;
  }
  const jint fd = env->GetIntField(bundle_jfd, fd_class_descriptor);
  env->DeleteLocalRef(bundle_jfd);
  return fd;
}

// Returns the `locales` string stored in the model mapped by `mmap` as a new
// Java string. An empty Java string is returned when the mapping is not ok,
// when ViewModel() yields no model, or when the model has no locales entry.
jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
  const char* locales_utf8 = "";
  if (mmap->handle().ok()) {
    const Model* mapped_model = libtextclassifier2::ViewModel(
        mmap->handle().start(), mmap->handle().num_bytes());
    if (mapped_model && mapped_model->locales()) {
      locales_utf8 = mapped_model->locales()->c_str();
    }
  }
  return env->NewStringUTF(locales_utf8);
}

// Returns the version stored in the model mapped by `mmap`, or 0 when the
// mapping is not ok or ViewModel() yields no model.
jint GetVersionFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
  if (mmap->handle().ok()) {
    const Model* mapped_model = libtextclassifier2::ViewModel(
        mmap->handle().start(), mmap->handle().num_bytes());
    if (mapped_model != nullptr) {
      return mapped_model->version();
    }
  }
  return 0;
}

// Returns the `name` string stored in the model mapped by `mmap` as a new
// Java string. An empty Java string is returned when the mapping is not ok,
// when ViewModel() yields no model, or when the model has no name entry.
jstring GetNameFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
  const char* name_utf8 = "";
  if (mmap->handle().ok()) {
    const Model* mapped_model = libtextclassifier2::ViewModel(
        mmap->handle().start(), mmap->handle().num_bytes());
    if (mapped_model && mapped_model->name()) {
      name_utf8 = mapped_model->name()->c_str();
    }
  }
  return env->NewStringUTF(name_utf8);
}

}  // namespace libtextclassifier2

using libtextclassifier2::ClassificationResultsToJObjectArray;
using libtextclassifier2::ConvertIndicesBMPToUTF8;
using libtextclassifier2::ConvertIndicesUTF8ToBMP;
using libtextclassifier2::FromJavaAnnotationOptions;
using libtextclassifier2::FromJavaClassificationOptions;
using libtextclassifier2::FromJavaSelectionOptions;
using libtextclassifier2::ToStlString;

// Creates a TextClassifier from the model behind file descriptor `fd` and
// returns the released pointer as a jlong handle. Ownership passes to the
// caller, who releases it through nativeClose. The handle is presumably 0
// when the factory yields no classifier -- confirm against
// TextClassifier::FromFileDescriptor.
//
// When LIBTEXTCLASSIFIER_UNILIB_JAVAICU is defined, a UniLib built from `env`
// is handed to the factory, as in the other nativeNew* entry points.
JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew)
(JNIEnv* env, jobject thiz, jint fd) {
#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
  // The UniLib must be an argument of the factory call. Placed outside the
  // call it forms a comma expression, so the UniLib pointer would be returned
  // as the handle and the classifier would be leaked.
  return reinterpret_cast<jlong>(
      TextClassifier::FromFileDescriptor(fd, new UniLib(env)).release());
#else
  return reinterpret_cast<jlong>(
      TextClassifier::FromFileDescriptor(fd).release());
#endif
}

// Creates a TextClassifier from the model file at the Java string `path` and
// returns the released pointer as a jlong handle. When
// LIBTEXTCLASSIFIER_UNILIB_JAVAICU is defined, a UniLib built from `env` is
// passed to the factory.
JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
  const std::string model_path = ToStlString(env, path);
#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
  auto classifier = TextClassifier::FromPath(model_path, new UniLib(env));
#else
  auto classifier = TextClassifier::FromPath(model_path);
#endif
  return reinterpret_cast<jlong>(classifier.release());
}

// Creates a TextClassifier from the region [`offset`, `offset` + `size`) of
// the file behind the Java AssetFileDescriptor `afd` and returns the released
// pointer as a jlong handle. When LIBTEXTCLASSIFIER_UNILIB_JAVAICU is defined,
// a UniLib built from `env` is passed to the factory.
JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint model_fd =
      libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
  auto classifier = TextClassifier::FromFileDescriptor(model_fd, offset, size,
                                                       new UniLib(env));
#else
  auto classifier = TextClassifier::FromFileDescriptor(model_fd, offset, size);
#endif
  return reinterpret_cast<jlong>(classifier.release());
}

// Asks the classifier behind `ptr` for a suggested selection around
// [`selection_begin`, `selection_end`) in `context`.
//
// The incoming indices are converted from BMP indices to code point indices
// before the call, and the suggested span is converted back afterwards.
//
// Returns a two-element Java int array {begin, end}, or nullptr when `ptr` is
// 0 or the result array cannot be allocated.
JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
 jint selection_end, jobject options) {
  if (!ptr) {
    return nullptr;
  }

  TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);

  const std::string context_utf8 = ToStlString(env, context);
  CodepointSpan input_indices =
      ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
  CodepointSpan selection = model->SuggestSelection(
      context_utf8, input_indices, FromJavaSelectionOptions(env, options));
  selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);

  jintArray result = env->NewIntArray(2);
  if (!result) {
    // Allocation failed; writing into a null array is not allowed.
    TC_LOG(ERROR) << "Couldn't allocate selection result array.";
    return nullptr;
  }
  // Copy both endpoints with a single region write from a jint buffer.
  const jint selection_bmp[2] = {static_cast<jint>(selection.first),
                                 static_cast<jint>(selection.second)};
  env->SetIntArrayRegion(result, 0, 2, selection_bmp);
  return result;
}

// Classifies the span [`selection_begin`, `selection_end`) of `context` with
// the classifier behind `ptr` and returns the results as a Java array of
// ClassificationResult objects. The span is converted from BMP indices to
// code point indices before the call. Returns nullptr when `ptr` is 0.
JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
 jint selection_end, jobject options) {
  if (ptr == 0) {
    return nullptr;
  }

  TextClassifier* classifier = reinterpret_cast<TextClassifier*>(ptr);
  const std::string text_utf8 = ToStlString(env, context);
  const CodepointSpan span_utf8 =
      ConvertIndicesBMPToUTF8(text_utf8, {selection_begin, selection_end});

  return ClassificationResultsToJObjectArray(
      env, classifier->ClassifyText(
               text_utf8, span_utf8,
               FromJavaClassificationOptions(env, options)));
}

// Annotates `context` with the classifier behind `ptr` and returns a Java
// array of TC_CLASS_NAME$AnnotatedSpan objects, one per annotation.
//
// Each AnnotatedSpan is constructed from the annotation's span, converted
// from code point indices to BMP indices, and a ClassificationResult array
// built from the annotation's classification.
//
// Returns nullptr when `ptr` is 0, when the AnnotatedSpan class or its
// constructor cannot be resolved, or when the result array cannot be
// allocated.
JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) {
  if (!ptr) {
    return nullptr;
  }
  TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
  const std::string context_utf8 = ToStlString(env, context);
  const std::vector<AnnotatedSpan> annotations =
      model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options));

  // Scoped so the class reference is released on every return path.
  const ScopedLocalRef<jclass> result_class(
      env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan"), env);
  if (!result_class) {
    TC_LOG(ERROR) << "Couldn't find result class: "
                  << TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan";
    return nullptr;
  }

  const jmethodID result_class_constructor = env->GetMethodID(
      result_class.get(), "<init>",
      "(II[L" TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult;)V");
  if (!result_class_constructor) {
    TC_LOG(ERROR) << "Couldn't find AnnotatedSpan constructor.";
    return nullptr;
  }

  const jsize annotation_count = static_cast<jsize>(annotations.size());
  jobjectArray results =
      env->NewObjectArray(annotation_count, result_class.get(), nullptr);
  if (!results) {
    TC_LOG(ERROR) << "Couldn't allocate AnnotatedSpan array.";
    return nullptr;
  }

  for (jsize i = 0; i < annotation_count; ++i) {
    const CodepointSpan span_bmp =
        ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
    jobjectArray classification = ClassificationResultsToJObjectArray(
        env, annotations[i].classification);
    jobject result = env->NewObject(
        result_class.get(), result_class_constructor,
        static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second),
        classification);
    env->SetObjectArrayElement(results, i, result);

    // Release the per-annotation local references immediately so their count
    // does not grow with the number of annotations.
    env->DeleteLocalRef(result);
    if (classification) {
      env->DeleteLocalRef(classification);
    }
  }
  return results;
}

// Destroys the TextClassifier whose address is stored in `ptr`. A `ptr` of 0
// deletes a null pointer, which does nothing.
JNI_METHOD(void, TC_CLASS_NAME, nativeClose)
(JNIEnv* env, jobject thiz, jlong ptr) {
  delete reinterpret_cast<TextClassifier*>(ptr);
}

// Deprecated entry point: logs a warning and forwards to nativeGetLocales
// with the same arguments.
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage)
(JNIEnv* env, jobject clazz, jint fd) {
  TC_LOG(WARNING) << "Using deprecated getLanguage().";
  const jstring locales =
      JNI_METHOD_NAME(TC_CLASS_NAME, nativeGetLocales)(env, clazz, fd);
  return locales;
}

// Maps the model file behind `fd` for the duration of the call and returns
// its locales string (see GetLocalesFromMmap for the empty-string cases).
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocales)
(JNIEnv* env, jobject clazz, jint fd) {
  // The mapping only needs to live until this function returns.
  libtextclassifier2::ScopedMmap model_mmap(fd);
  return GetLocalesFromMmap(env, &model_mmap);
}

// Maps `size` bytes at `offset` of the file behind the Java
// AssetFileDescriptor `afd` for the duration of the call and returns the
// model's locales string (see GetLocalesFromMmap for the empty-string cases).
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocalesFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint model_fd =
      libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
  // The mapping only needs to live until this function returns.
  libtextclassifier2::ScopedMmap model_mmap(model_fd, offset, size);
  return GetLocalesFromMmap(env, &model_mmap);
}

// Maps the model file behind `fd` for the duration of the call and returns
// its version (see GetVersionFromMmap for the cases that yield 0).
JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd) {
  // The mapping only needs to live until this function returns.
  libtextclassifier2::ScopedMmap model_mmap(fd);
  return GetVersionFromMmap(env, &model_mmap);
}

// Maps `size` bytes at `offset` of the file behind the Java
// AssetFileDescriptor `afd` for the duration of the call and returns the
// model's version (see GetVersionFromMmap for the cases that yield 0).
JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersionFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint model_fd =
      libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
  // The mapping only needs to live until this function returns.
  libtextclassifier2::ScopedMmap model_mmap(model_fd, offset, size);
  return GetVersionFromMmap(env, &model_mmap);
}

// Maps the model file behind `fd` for the duration of the call and returns
// its name string (see GetNameFromMmap for the empty-string cases).
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd) {
  // The mapping only needs to live until this function returns.
  libtextclassifier2::ScopedMmap model_mmap(fd);
  return GetNameFromMmap(env, &model_mmap);
}

// Maps `size` bytes at `offset` of the file behind the Java
// AssetFileDescriptor `afd` for the duration of the call and returns the
// model's name string (see GetNameFromMmap for the empty-string cases).
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetNameFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint model_fd =
      libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
  // The mapping only needs to live until this function returns.
  libtextclassifier2::ScopedMmap model_mmap(model_fd, offset, size);
  return GetNameFromMmap(env, &model_mmap);
}